import torch
from torch import nn
from torch import optim
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import os
import cv2
import NET
import tqdm
import sys
import unet_2
# 创建一个变换对象,用于将图像缩放为指定大小并居中裁剪
resize_crop = transforms.Compose([
transforms.Resize(256), # 缩放图像,使其短边不小于256像素
transforms.CenterCrop(224) # 居中裁剪图像,使其长和宽都等于224像素
])
# 创建一个变换对象,用于将图像转换为Tensor格式并进行归一化处理
to_tensor_normalize = transforms.Compose([
transforms.ToTensor(), # 将图像转换为Tensor格式
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 归一化处理
])
to_tensor_mask = transforms.Compose([
transforms.Resize(256), # 缩放图像,使其短边不小于256像素
transforms.CenterCrop(224), # 居中裁剪图像,使其长和宽都等于224像素
transforms.ToTensor() # 将图像转换为Tensor格式
])
class MyDataset(Dataset):
def __init__(self,root_dir):
self.root_dir=root_dir
self.img_root=root_dir + r"\last"
self.mask_root = root_dir + r"\last_msk"
self.img_names=os.listdir(self.img_root)
def __len__(self):
return len(self.img_names)
def __getitem__(self, idx):
img_path=os.path.join(self.img_root,self.img_names[idx])
mask_path = os.path.join(self.mask_root, self.img_names[idx])
img=Image.open(img_path)
img = resize_crop(img)
img_tensor = to_tensor_normalize(img)
mask=Image.open(mask_path)
mask_tensor=to_tensor_mask(mask).squeeze()
return img_tensor,mask_tensor
def test_showdir(self):
print(self.root_dir)
def test_showitem(self,idx):
img_path = os.path.join(self.img_root, self.img_names[idx])
mat=cv2.imread(img_path,1)
cv2.imshow("test_showitem",mat)
# label = self.img_names[idx][:-4]
# print("label:",label)
print(self.__getitem__(idx)[0].shape)
print(self.__getitem__(idx)[1].shape)
cv2.waitKey(0)
def train(epochs):
batch_size = 32
nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers
print('正在使用{}个线程加载数据集'.format(nw))
train_dataset=MyDataset(r"D:\ML_DATA\handbag\train")
train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=batch_size, shuffle=True,
num_workers=nw)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
device_count = torch.cuda.device_count()
print(f"找到 {device_count} 一个CUDA 设备:", end='')
for i in range(device_count):
print(torch.cuda.get_device_name(i))
else:
print("没有找到CUDA设备")
print("device:", device)
MyNet = unet_2.Unet()
MyNet.to(device)
# 将数据移动至显卡(cuda)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(MyNet.parameters(), lr=0.001)
save_path = r'.\UNET.pth'
print("权重文件保存路径:", os.path.join(os.getcwd()))
# epochs = 100
best_acc = 0.0
train_steps = len(train_loader)
MyNet.train()
for epoch in range(epochs):
print(epoch)
running_loss = 0
# train_bar = tqdm(train_loader, file=sys.stdout)
# 此行代码用于将可迭代对象的迭代过程转化为进度条并输出到控制台
for data in train_loader:
images, mask = data
# print(images.shape, mask.shape)
optimizer.zero_grad()
outputs = MyNet(images.to(device)).squeeze()
# print(outputs.shape,mask.shape)
# exit(2)
# mask=mask.squeeze()
loss = loss_function(outputs, mask.to(device))
# 将数据移动至显卡(cuda)
loss.backward()
optimizer.step()
running_loss += loss.item()
# train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
# epochs,
# loss)
torch.save(MyNet.state_dict(), "end.pth")
print(running_loss)
print('Finished Training')
if __name__ == '__main__':
train(1000)
# testdataset=MyDataset(r"D:\ML_DATA\handbag\train")
# testdataset.test_showitem(1)
# img = Image.open(r"D:\ML_DATA\handbag\train\last\0.jpg")
# img = resize_crop(img)
# img = to_tensor_normalize(img)
#
# # 将Tensor图像转换为NumPy数组,并移动数据到CPU上
# img = img.cpu().numpy().transpose((1, 2, 0))
#
# # 显示处理后的图像
# plt.imshow(img)
# plt.show()
.whl
- 粉丝: 3960
- 资源: 4908
最新资源
- 博途S7-1200主站与S7-200从站实现RS485通讯程序 S7-200可以当作一个仪表
- C#、C++分别开发的OPC DA CLIENT软件. 1、枚举服务器名称; 2、连接服务器以后枚举出TAG; 3、根据TAG名称自动读取服务器数据; 4、图片内有OPC SERVER和CLIENT实
- python-workspace.zip.005
- 龙门上下料样本程序,四轴 用台达AS228T和台达触摸屏编写 注意软件是用台达新款软件ISPSOFT ,借鉴价值高,程序有注释
- 一款window下的串口监视抓包工具
- 欧姆龙CP1H与3台力士乐VFC-x610变频器通讯程序 功能:原创程序,可直接用于现场程序 欧姆龙CP1H的CIF11通讯板,实现对3台力士乐VFC-x610变频器 设定频率,控制正反转,读取实际
- dp111113333
- CV-密集人群图像数据集(5800张图片).rar
- 福特汽车主观评价规范,性能开发参考,英文原版直译,评价条目、规则描述非常细致 包含平顺舒适性,转向,操稳,NVH,制动,加速感,驾驶性等等性能,并详细描述了评价的准备工作 评价条目细分至第四级,共
- 三菱FX3S两轴标准程序,XZ两轴,包含轴点动,回零,相对与绝对定位,只要弄明白这个程序,就可以非常了解整个项目的程序如何去编写,从哪里开始下手,可提供程序问题解答,程序流程清晰明了,注释完整
- MATLAB代码:考虑P2G与碳捕集机组的多能微网低碳经济调度 关键词:碳交易 阶梯碳交易 碳捕集 多能微网 低碳调度 仿真平台:MATLAB+yalmip+cplex 主要内容:代码主要做的是一个
- 本程序采用matlab编写,主要是实现电流注入型牛拉法 除此之外,本人还编写了很多种关于潮流计算的程序,主要有牛拉法,前推回代法,以还有相和三相潮流计算程序
- 智能门锁架构图,供大家参考
- 三菱FX3U六轴标准程序,程序包含本体3轴控制,扩展3个1PG定位模块,一共六轴 程序有轴点动控制,回零控制,相对定位,绝对定位 另有气缸数个,一个大是DD马达控制的转盘,整个是转盘多工位流水作业
- 批量登录到远程Linux服务器检查服务器时间差的shell
- MATLAB电动车七自由度整车模型 MATLAB Simulink电动车转弯制动abs模型asr转弯制动防抱死abs模型+模糊控制算法+七自由度整车模型+纵向运动+侧向运动+横摆运动+四轮魔术公式+四
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈