import itertools
import os
import random
import argparse
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
print(" -- 使用GPU进行训练 -- ")
def parseArgs():
parser = argparse.ArgumentParser(description='manual to this script')
parser.add_argument("--save_path", type=str, default="./result/")
parser.add_argument("--data_path", type=str, default="../Data/")
parser.add_argument("--origin", type=str, default="GT")
parser.add_argument("--hazy", type=str, default="hazy")
parser.add_argument("--batch_size", type=int, default=4)
args = parser.parse_args()
return args
## 生成器 U-Net(输入照片为256*256) ##
class Generator(nn.Module):
def __init__(self, in_ch, out_ch, ngf=64):
"""
定义生成器的网络结构
:param in_ch: 输入数据的通道数
:param out_ch: 输出数据的通道数
:param ngf: 第一层卷积的通道数 number of generator's first conv filters
"""
super(Generator, self).__init__()
# 下面的激活函数都放在下一个模块的第一步 是为了skip-connect方便
# 左半部分 U-Net encoder
# 每层输入大小折半,从输入图片大小256开始
# 256 * 256(输入)
self.en1 = nn.Sequential(
nn.Conv2d(in_ch, ngf, kernel_size=4, stride=2, padding=1),
# 输入图片已正则化 不需BatchNorm
)
# 128 * 128
self.en2 = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ngf, ngf * 2, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 2)
)
# 64 * 64
self.en3 = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ngf * 2, ngf * 4, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 4)
)
# 32 * 32
self.en4 = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ngf * 4, ngf * 8, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 8)
)
# 16 * 16
self.en5 = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 8)
)
# 8 * 8
self.en6 = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 8)
)
# 4 * 4
self.en7 = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 8)
)
# 2 * 2
self.en8 = nn.Sequential(
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1)
# Encoder输出不用BatchNorm
)
# 右半部分 U-Net decoder
# skip-connect: 前一层的输出+对称的卷积层
# 1 * 1(输入)
self.de1 = nn.Sequential(
nn.ReLU(inplace=True),
nn.ConvTranspose2d(ngf * 8, ngf * 8, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 8),
nn.Dropout(p=0.5)
)
# 2 * 2
self.de2 = nn.Sequential(
nn.ReLU(inplace=True),
# skip-connect 所以输入管道数是之前输出的2倍
nn.ConvTranspose2d(ngf * 8 * 2, ngf * 8, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 8),
nn.Dropout(p=0.5)
)
# 4 * 4
self.de3 = nn.Sequential(
nn.ReLU(inplace=True),
nn.ConvTranspose2d(ngf * 8 * 2, ngf * 8, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 8),
nn.Dropout(p=0.5)
)
# 8 * 8
self.de4 = nn.Sequential(
nn.ReLU(inplace=True),
nn.ConvTranspose2d(ngf * 8 * 2, ngf * 8, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 8),
nn.Dropout(p=0.5)
)
# 16 * 16
self.de5 = nn.Sequential(
nn.ReLU(inplace=True),
nn.ConvTranspose2d(ngf * 8 * 2, ngf * 4, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 4),
nn.Dropout(p=0.5)
)
# 32 * 32
self.de6 = nn.Sequential(
nn.ReLU(inplace=True),
nn.ConvTranspose2d(ngf * 4 * 2, ngf * 2, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf * 2),
nn.Dropout(p=0.5)
)
# 64 * 64
self.de7 = nn.Sequential(
nn.ReLU(inplace=True),
nn.ConvTranspose2d(ngf * 2 * 2, ngf, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ngf),
nn.Dropout(p=0.5)
)
# 128 * 128
self.de8 = nn.Sequential(
nn.ReLU(inplace=True),
nn.ConvTranspose2d(ngf * 2, out_ch, kernel_size=4, stride=2, padding=1),
# Encoder输出不用BatchNorm
nn.Tanh()
)
def forward(self, X):
"""
生成器模块前向传播
:param X: 输入生成器的数据
:return: 生成器的输出
"""
# Encoder
en1_out = self.en1(X)
en2_out = self.en2(en1_out)
en3_out = self.en3(en2_out)
en4_out = self.en4(en3_out)
en5_out = self.en5(en4_out)
en6_out = self.en6(en5_out)
en7_out = self.en7(en6_out)
en8_out = self.en8(en7_out)
# Decoder
de1_out = self.de1(en8_out)
de1_cat = torch.cat([de1_out, en7_out], dim=1) # cat by channel
de2_out = self.de2(de1_cat)
de2_cat = torch.cat([de2_out, en6_out], 1)
de3_out = self.de3(de2_cat)
de3_cat = torch.cat([de3_out, en5_out], 1)
de4_out = self.de4(de3_cat)
de4_cat = torch.cat([de4_out, en4_out], 1)
de5_out = self.de5(de4_cat)
de5_cat = torch.cat([de5_out, en3_out], 1)
de6_out = self.de6(de5_cat)
de6_cat = torch.cat([de6_out, en2_out], 1)
de7_out = self.de7(de6_cat)
de7_cat = torch.cat([de7_out, en1_out], 1)
de8_out = self.de8(de7_cat)
return de8_out
## 辨别器 PatchGAN(其实就是卷积网络而已) ##
class Discriminator(nn.Module):
def __init__(self, in_ch, ndf=64):
"""
定义判别器的网络结构
:param in_ch: 输入数据的通道数
:param ndf: 第一层卷积的通道数 number of discriminator's first conv filters
"""
super(Discriminator, self).__init__()
# 不是输出一个表示真假概率的实数,而是一个N*N的Patch矩阵(此处为30*30),其中每一块对应输入数据的一小块
# in_ch + out_ch 是为将对应真假数据同时输入
# 256 * 256(输入)
self.layer1 = nn.Sequential(
nn.Conv2d(in_ch, ndf, kernel_size=4, stride=2, padding=1),
# 输入图片已正则化 不需BatchNorm
nn.LeakyReLU(0.2, inplace=True)
)
# 128 * 128
self.layer2 = nn.Sequential(
nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.2, inplace=True)
)
# 64 * 64
self.layer3 = nn.Sequential(
nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True)
)
# 32 * 32
self.layer4 = nn.Sequential(
nn.Conv2d(ndf * 4, ndf * 8
manylinux
- 粉丝: 4657
- 资源: 2492
最新资源
- Fanuc数控系统数据采集与刀具管理源码C#:多线程数据采集、分析记录及虚拟调试支持,focas2 fanuc发那科数控系统数据采集刀具DNC程序管理源代码c# 即送fanuc开发机,无需真机调试
- 永磁同步电机超螺旋滑模控制算法仿真模型:鲁棒性增强,抖振减小,与常规滑模控制对比展示,基于Matlab Simulink搭建的学习参考模型,永磁同步电机超螺旋滑模控制算法仿真模型,有很强的鲁棒性,减小
- 中颖SH79F323系列电动自行车代码方案:霍尔FOC算法详解与应用参考,包含原理图、PCB和代码库资源,坐标变换开源库支持 ,中颖SH79F3231电动自行车代码方案,包含代码,原理图,Pcb,说明
- WinCC嵌入式报表系统:高效数据读取与处理,多种报表模式与实时显示功能,wincc嵌入式报表 一、功能介绍 该报表系统能够读取WINCC中历史归档数据,产生出EXCEL报表文件,同时在画面中EXCE
- 关于TC275、S12X及S32K144的CANoe UDS诊断数据库CDD文件及CAPL Boot上下位机程序移植说明文档,tc275以及s12x以及s32k144的基于canoe的uds诊断数据库
- 永磁同步电机滑模观测器与无传感器控制算法的高效应用研究,永磁同步电机滑模观测器,无传感器控制算法 ,核心关键词:永磁同步电机; 滑模观测器; 无传感器控制算法; 电机控制 ,"无传感器控制算法下的永
- 三菱Q系列PLC大型设备成熟程序案例:可靠、高价值参考与注释详尽的行业典范 #Mitsubishi 三菱电气,三菱Q系PLC大型设备程序 此程序已经实际设备上批量应用,程序成熟可靠,借鉴价值高,程
- 基于工艺参数的LC VCO电感电容压控振荡器设计与介绍,附带中心频率和相位噪声技术指标 ,LC VCO电感电容压控振荡器 LC振荡器 1.有电路文件,带工艺库PDK 2.有设计文档,PDF,原理和仿真
- 基于A星算法融合DWA技术的路径规划:静态与动态避障Matlab源码详解,A星融合DWA的路径规划算法,可实现静态避障碍及动态避障,代码注释详细,matlab源码 ,A星融合DWA; 路径规划算法;
- 风机变桨控制FAST与MATLAB SIMULINK联合仿真模型研究:非线性风力发电机的PID独立与统一变桨控制策略对比,基于NREL 5MW风机参数建模,涉及载荷数据对比分析,风机变桨控制FAST与
- 基于Matlab的CNN卷积神经网络回归预测算法实战教程(MATLAB版本要求高于2018b),CNN 卷积神经网络回归预测算法(基于Matlab实现) 特殊要求:Matlab版本应高于2018b
- 基于Simulink仿真的同步电机死区自适应补偿策略,提高系统性能并优化零点电流噪声(MATLAB 2018版),simulink仿真模型,同步电机死区补偿,自适应补偿,图一前面开了补偿,后面关了补偿
- 基于MATLAB+CPLEX平台的电转气协同碳捕集垃圾焚烧虚拟电厂优化调度程序复现与运行结果展示,MATLAB代码:计及电转气协同的含碳捕集与垃圾焚烧电厂优化调度 关键词:碳捕集 电厂 需求响应 优
- 基于TMS32F2808的50kw组串式三相光伏并网逆变器方案,50kw组串式 三相光伏并网逆变器方案 主控TMS32F2808,提供pcb,原理图,代码,如下: 1)主控DSP板, 负责逆变器的
- 精确时间同步系统的以太网数据包处理模块及cocotb测试平台代码实现与验证,用于1G,10G和25G数据包处理的以太网以及IP,UDP,ARP的模块以及实现需要精确时间同步系统的各种PTP组件,包含c
- IcsRade.Lot物联网框架:毫秒级控制,非轮询远程数据采集与高效控制管理源码解库,IcsRade.Lot物联网框架C#源码 该框架自带集成mqtt服务器,Modbus Rtu及Modbus T
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
- 1
- 2
- 3
前往页