import torch
import argparse
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch import nn, optim
from torchvision.transforms import transforms
from unet import Unet
from dataset import LiverDataset
from common_tools import transform_invert
def makedir(dir):
    """Create directory *dir* if it does not already exist.

    BUGFIX: the original used os.path.exists + os.mkdir, which fails for
    nested paths (mkdir cannot create parents) and is racy between the
    check and the create.  os.makedirs(..., exist_ok=True) handles both.
    """
    os.makedirs(dir, exist_ok=True)
# Run a validation pass every `val_interval` epochs.
val_interval = 1
# Use CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Inputs and labels are grayscale images, so converting to tensor is enough.
x_transforms = transforms.ToTensor()
y_transforms = transforms.ToTensor()
# Loss histories for plotting: per-iteration train loss, per-epoch valid loss.
train_curve = list()
valid_curve = list()
def train_model(model, criterion, optimizer, dataload, num_epochs=100):
    """Train `model` on `dataload`, validating every `val_interval` epochs.

    Resumes from ./model/weights_20.pth when present, checkpoints every 50
    epochs, and plots the train/valid loss curves at the end.

    Args:
        model: network to train (already moved to the global `device`).
        criterion: loss taking raw model outputs (e.g. BCEWithLogitsLoss).
        optimizer: optimizer over model.parameters().
        dataload: DataLoader over the training set.
        num_epochs: total number of epochs.

    Returns:
        The trained model.
    """
    makedir('./model')
    model_path = "./model/weights_20.pth"
    if os.path.exists(model_path):
        model.load_state_dict(torch.load(model_path, map_location=device))
        start_epoch = 20
        print('加载成功!')
    else:
        start_epoch = 0
        print('无保存模型,将从头开始训练!')

    # BUGFIX: build the validation dataset/loader once, not inside every
    # epoch iteration as the original did.
    valid_dataset = LiverDataset("data/val", transform=x_transforms, target_transform=y_transforms)
    valid_loader = DataLoader(valid_dataset, batch_size=4, shuffle=True)

    # NOTE(review): range(start_epoch + 1, ...) skips epoch `start_epoch`
    # and runs num_epochs - start_epoch - 1 epochs; kept as-is to preserve
    # the original resume/checkpoint numbering — confirm intended.
    for epoch in range(start_epoch + 1, num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-' * 10)
        # BUGFIX: validation below flips the model to eval mode; restore
        # train mode at the start of each epoch so dropout/batch-norm
        # layers behave correctly during training.
        model.train()
        dt_size = len(dataload.dataset)
        epoch_loss = 0
        step = 0
        for x, y in dataload:
            step += 1
            inputs = x.to(device)
            labels = y.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward pass — raw logits, criterion applies the sigmoid
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            train_curve.append(loss.item())
            print("%d/%d,train_loss:%0.3f" % (step, (dt_size - 1) // dataload.batch_size + 1, loss.item()))
        print("epoch %d loss:%0.3f" % (epoch, epoch_loss / step))
        if (epoch + 1) % 50 == 0:
            torch.save(model.state_dict(), './model/weights_%d.pth' % (epoch + 1))

        # Validate the model.
        # BUGFIX: condition was (epoch + 2) % val_interval, inconsistent
        # with the (epoch + 1) convention used for checkpointing.
        if (epoch + 1) % val_interval == 0:
            loss_val = 0.
            model.eval()
            with torch.no_grad():
                step_val = 0
                for x, y in valid_loader:
                    step_val += 1
                    x = x.type(torch.FloatTensor)
                    inputs = x.to(device)
                    labels = y.to(device)
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    loss_val += loss.item()
                # BUGFIX: record the mean epoch loss (what is printed),
                # not the raw sum, so the plotted valid curve is on the
                # same scale as the train curve.
                valid_curve.append(loss_val / step_val)
                print("epoch %d valid_loss:%0.3f" % (epoch, loss_val / step_val))

    # Plot losses; valid points are per-epoch, so map them onto the
    # iteration axis used by the train curve.
    train_x = range(len(train_curve))
    train_y = train_curve
    train_iters = len(dataload)
    valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval
    valid_y = valid_curve
    plt.plot(train_x, train_y, label='Train')
    plt.plot(valid_x, valid_y, label='Valid')
    plt.legend(loc='upper right')
    plt.ylabel('loss value')
    plt.xlabel('Iteration')
    plt.show()
    return model
# Entry point for the "train" action.
def train(args):
    """Assemble the U-Net, loss, optimizer and training data, then train."""
    model = Unet(1, 1).to(device)
    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters())
    train_set = LiverDataset("./data/train", transform=x_transforms, target_transform=y_transforms)
    loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4)
    train_model(model, criterion, optimizer, loader)
# Run inference on the validation set and save the results as images.
def test(args):
    """Predict masks for every validation sample and save them to disk.

    For sample i writes into ./data/predict:
        predict_<i>_g.png  -- ground-truth mask
        predict_<i>_o.png  -- predicted probability mask scaled to [0, 255]
    """
    model = Unet(1, 1).to(device)
    # BUGFIX: map_location was hard-coded to 'cuda', which crashes on a
    # CPU-only machine; use the globally selected device instead.
    model.load_state_dict(torch.load(args.ckpt, map_location=device))
    liver_dataset = LiverDataset("data/val", transform=x_transforms, target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=1)
    save_root = './data/predict'
    # BUGFIX: the output directory may not exist yet.
    makedir(save_root)
    model.eval()
    index = 0
    with torch.no_grad():
        for x, ground in dataloaders:
            x = x.type(torch.FloatTensor)
            logits = model(x.to(device))
            # BUGFIX: the net is trained with BCEWithLogitsLoss, so it
            # outputs raw logits; apply sigmoid to get a [0, 1] mask
            # before scaling to 8-bit for imwrite.
            y = torch.sigmoid(logits)
            ground = torch.squeeze(ground)
            ground = ground.unsqueeze(0)
            img_ground = transform_invert(ground, y_transforms)
            img_y = torch.squeeze(y).cpu().numpy()
            save_path = os.path.join(save_root, "predict_%d_o.png" % index)
            ground_path = os.path.join(save_root, "predict_%d_g.png" % index)
            img_ground.save(ground_path)
            cv2.imwrite(save_path, img_y * 255)
            index = index + 1
# Compute the Dice coefficient over the saved predictions.
def dice_calc(args):
    """Print per-image Dice scores and their mean for ./data/predict.

    Compares predict_<i>_g.png (ground truth) against predict_<i>_o.png
    (prediction), treating pixel intensity / 255 as a soft mask.
    """
    root = './data/predict'
    # BUGFIX: the original used len(listdir) // 3, but test() only writes
    # two files per sample (the _s source save is disabled); count the
    # ground-truth files directly instead of guessing a divisor.
    nums = len([f for f in os.listdir(root) if f.endswith("_g.png")])
    dice = list()
    for i in range(nums):
        ground_path = os.path.join(root, "predict_%d_g.png" % i)
        predict_path = os.path.join(root, "predict_%d_o.png" % i)
        img_ground = cv2.imread(ground_path)
        img_predict = cv2.imread(predict_path)
        # BUGFIX: the original iterated a hard-coded 256x256 grid pixel by
        # pixel; use the actual image (channel 1, as before) and vectorize
        # with numpy — handles any image size and is orders of magnitude
        # faster.
        g = img_ground[:, :, 1].astype(np.float64) / 255.0
        p = img_predict[:, :, 1].astype(np.float64) / 255.0
        intersec = float((g * p).sum())
        denom = float(g.sum() + p.sum())
        if denom == 0:
            current_dice = 1  # both masks empty: perfect agreement
        else:
            current_dice = round(2 * intersec / denom, 3)
        dice.append(current_dice)
    # BUGFIX: guard against an empty prediction directory, which previously
    # raised ZeroDivisionError.
    dice_mean = sum(dice) / len(dice) if dice else 0.0
    print(dice)
    print(round(dice_mean, 3))
if __name__ == '__main__':
    # Parse command-line arguments.
    parse = argparse.ArgumentParser()
    parse.add_argument("--action", type=str, help="train, test or dice", default="train")
    parse.add_argument("--batch_size", type=int, default=4)
    parse.add_argument("--ckpt", type=str, help="the path of model weight file", default="./model/weights_20.pth")
    # parse.add_argument("--ckpt", type=str, help="the path of model weight file")
    args = parse.parse_args()
    # Dispatch on the requested action; unknown actions are silently ignored.
    if args.action == "train":
        train(args)
    elif args.action == "test":
        test(args)
    elif args.action == "dice":
        dice_calc(args)
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
基于pytorch+Unet进行MRI肝脏图像分割源码+数据集.zip 【环境配置】 Python >= 3.7 opencv-python Pillow == 7.0.0 torch == 1.4.0 torchsummary == 1.5.1 torchvision == 0.4.2
资源推荐
资源详情
资源评论
收起资源包目录
基于pytorch+Unet进行MRI肝脏图像分割源码+数据集.zip (1070个子文件)
项目说明.md 172B
Aug_No_485.png 61KB
Aug_No_491.png 61KB
Aug_No_484.png 59KB
Aug_No_493.png 58KB
Aug_No_488.png 58KB
Aug_No_477.png 58KB
Aug_No_486.png 57KB
Aug_No_473.png 57KB
Aug_No_483.png 57KB
Aug_No_481.png 56KB
Aug_No_476.png 56KB
Aug_No_471.png 55KB
Aug_No_490.png 55KB
Aug_No_73.png 54KB
Aug_No_228.png 54KB
Aug_No_343.png 54KB
Aug_No_237.png 54KB
Aug_No_225.png 54KB
Aug_No_339.png 53KB
Aug_No_205.png 53KB
Aug_No_296.png 53KB
Aug_No_212.png 53KB
Aug_No_489.png 53KB
P8_T1_00062.png 53KB
Aug_No_327.png 53KB
Aug_No_246.png 53KB
Aug_No_338.png 53KB
Aug_No_475.png 53KB
Aug_No_482.png 53KB
Aug_No_227.png 53KB
Aug_No_78.png 53KB
Aug_No_478.png 52KB
Aug_No_487.png 52KB
Aug_No_494.png 52KB
Aug_No_342.png 52KB
Aug_No_328.png 52KB
Aug_No_425.png 52KB
Aug_No_347.png 52KB
P8_T1_00070.png 52KB
Aug_No_335.png 52KB
Aug_No_207.png 52KB
Aug_No_224.png 51KB
Aug_No_420.png 51KB
Aug_No_76.png 51KB
Aug_No_492.png 51KB
Aug_No_334.png 51KB
Aug_No_480.png 51KB
Aug_No_77.png 51KB
Aug_No_355.png 51KB
Aug_No_231.png 51KB
Aug_No_240.png 51KB
Aug_No_433.png 51KB
Aug_No_344.png 51KB
Aug_No_226.png 51KB
Aug_No_417.png 51KB
Aug_No_238.png 51KB
Aug_No_249.png 51KB
Aug_No_216.png 51KB
Aug_No_209.png 51KB
Aug_No_333.png 50KB
Aug_No_236.png 50KB
Aug_No_472.png 50KB
Aug_No_269.png 50KB
Aug_No_215.png 50KB
Aug_No_479.png 50KB
Aug_No_235.png 50KB
Aug_No_243.png 50KB
Aug_No_331.png 50KB
Aug_No_256.png 50KB
Aug_No_284.png 50KB
Aug_No_234.png 50KB
Aug_No_429.png 50KB
Aug_No_286.png 50KB
Aug_No_441.png 50KB
Aug_No_210.png 50KB
Aug_No_252.png 50KB
Aug_No_253.png 50KB
P33_T1_00042.png 50KB
Aug_No_496.png 50KB
Aug_No_267.png 50KB
Aug_No_274.png 49KB
Aug_No_332.png 49KB
Aug_No_232.png 49KB
Aug_No_422.png 49KB
Aug_No_301.png 49KB
Aug_No_366.png 49KB
Aug_No_244.png 49KB
Aug_No_341.png 49KB
Aug_No_272.png 49KB
Aug_No_213.png 49KB
Aug_No_474.png 49KB
Aug_No_102.png 49KB
P33_T1_00034.png 49KB
Aug_No_345.png 49KB
Aug_No_348.png 49KB
Aug_No_329.png 49KB
P2_T1_00018.png 49KB
Aug_No_330.png 49KB
Aug_No_497.png 48KB
共 1070 条
- 1
- 2
- 3
- 4
- 5
- 6
- 11
onnx
- 粉丝: 1w+
- 资源: 5627
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- 19 工资发放明细表-可视化图表.xlsx
- 27 员工工资表(图表分析).xlsx
- 23 财务报告工资数据图表模板.xlsx
- 22 财务报告工资数据图表模板.xlsx
- 24 工资表-年度薪资可视化图表.xlsx
- 26 财务分析部门工资支出图表.xlsx
- Python爬虫技术详解:从基础到实战.zip
- 25 工资费用支出表-可视化图表.xlsx
- 30公司各部门工资支出数据图表1.xlsx
- 29 员工月度工资支出数据图表.xlsx
- 28 工资表(自动计算,图表显示).xlsx
- 31 财务分析工资年度开支图表.xlsx
- 33 年度工资预算表(可视化看板).xlsx
- 32 公司年度工资成本数据图表.xlsx
- 34 年度工资汇总-数据可视化看板.xlsx
- 36 财务报表新年度部门工资预算表.xlsx
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功
- 1
- 2
前往页