# 3D-UNet model.
# x: 128x128 resolution for 32 frames.
# https://github.com/huangzhii/FCN-3D-pytorch/blob/master/main3d.py
import torch
import torch.nn as nn
import os
import numpy as np
from collections import OrderedDict
def passthrough(x, **kwargs):
return x
def ELUCons(elu, nchan):
if elu:
return nn.ELU(inplace=True)
else:
return nn.PReLU(nchan)
class LUConv(nn.Module):
def __init__(self, nchan, elu):
super(LUConv, self).__init__()
self.relu1 = ELUCons(elu, nchan)
self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2)
self.bn1 = torch.nn.BatchNorm3d(nchan)
def forward(self, x):
out = self.relu1(self.bn1(self.conv1(x)))
return out
def _make_nConv(nchan, depth, elu):
layers = []
for _ in range(depth):
layers.append(LUConv(nchan, elu))
return nn.Sequential(*layers)
class InputTransition(nn.Module):
def __init__(self, in_channels, elu):
super(InputTransition, self).__init__()
self.num_features = 16
self.in_channels = in_channels
self.conv1 = nn.Conv3d(self.in_channels, self.num_features, kernel_size=5, padding=2)
self.bn1 = torch.nn.BatchNorm3d(self.num_features)
self.relu1 = ELUCons(elu, self.num_features)
def forward(self, x):
out = self.conv1(x)
repeat_rate = int(self.num_features / self.in_channels)
out = self.bn1(out)
x16 = x.repeat(1, repeat_rate, 1, 1, 1)
return self.relu1(torch.add(out, x16))
class DownTransition(nn.Module):
def __init__(self, inChans, nConvs, elu, dropout=False):
super(DownTransition, self).__init__()
outChans = 2 * inChans
self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=2, stride=2)
self.bn1 = torch.nn.BatchNorm3d(outChans)
self.do1 = passthrough
self.relu1 = ELUCons(elu, outChans)
self.relu2 = ELUCons(elu, outChans)
if dropout:
self.do1 = nn.Dropout3d()
self.ops = _make_nConv(outChans, nConvs, elu)
def forward(self, x):
down = self.relu1(self.bn1(self.down_conv(x)))
out = self.do1(down)
out = self.ops(out)
out = self.relu2(torch.add(out, down))
return out
class UpTransition(nn.Module):
def __init__(self, inChans, outChans, nConvs, elu, dropout=False):
super(UpTransition, self).__init__()
self.up_conv = nn.ConvTranspose3d(inChans, outChans // 2, kernel_size=2, stride=2)
self.bn1 = torch.nn.BatchNorm3d(outChans // 2)
self.do1 = passthrough
self.do2 = nn.Dropout3d()
self.relu1 = ELUCons(elu, outChans // 2)
self.relu2 = ELUCons(elu, outChans)
if dropout:
self.do1 = nn.Dropout3d()
self.ops = _make_nConv(outChans, nConvs, elu)
def forward(self, x, skipx):
out = self.do1(x)
skipxdo = self.do2(skipx)
out = self.relu1(self.bn1(self.up_conv(out)))
xcat = torch.cat((out, skipxdo), 1)
out = self.ops(xcat)
out = self.relu2(torch.add(out, xcat))
return out
class OutputTransition(nn.Module):
def __init__(self, in_channels, classes, elu):
super(OutputTransition, self).__init__()
self.classes = classes
self.conv1 = nn.Conv3d(in_channels, classes, kernel_size=5, padding=2)
self.bn1 = torch.nn.BatchNorm3d(classes)
self.conv2 = nn.Conv3d(classes, classes, kernel_size=1)
self.relu1 = ELUCons(elu, classes)
def forward(self, x):
# convolve 32 down to channels as the desired classes
out = self.relu1(self.bn1(self.conv1(x)))
out = self.conv2(out)
return out
class VNet(nn.Module):
"""
Implementations based on the Vnet paper: https://arxiv.org/abs/1606.04797
"""
def __init__(self, elu=True, in_channels=1, classes=1):
super(VNet, self).__init__()
self.classes = classes
self.in_channels = in_channels
self.in_tr = InputTransition(in_channels, elu=elu)
self.down_tr32 = DownTransition(16, 1, elu)
self.down_tr64 = DownTransition(32, 2, elu)
self.down_tr128 = DownTransition(64, 3, elu, dropout=False)
self.down_tr256 = DownTransition(128, 2, elu, dropout=False)
self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=False)
self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=False)
self.up_tr64 = UpTransition(128, 64, 1, elu)
self.up_tr32 = UpTransition(64, 32, 1, elu)
self.out_tr = OutputTransition(32, classes, elu)
def forward(self, x):
out16 = self.in_tr(x)
out32 = self.down_tr32(out16)
out64 = self.down_tr64(out32)
out128 = self.down_tr128(out64)
out256 = self.down_tr256(out128)
out = self.up_tr256(out256, out128)
out = self.up_tr128(out, out64)
out = self.up_tr64(out, out32)
out = self.up_tr32(out, out16)
out = self.out_tr(out)
return out
if __name__ == '__main__':
x = torch.randn(1, 1, 64, 64, 64)
net = VNet(in_channels=1, classes=6)
y = net(x)
print(y.shape())
onnx
- 粉丝: 1w+
- 资源: 5627
最新资源
- 永磁同步电机(pmsm)模型预测控制(MPC)matla b simulink仿真模型,有PI矢量控制,直接预测控制(有限集模型预测控制)(这个其中包括做了单矢量和双矢量或者可以成为三矢量的有限集预测
- Google Chrome浏览器ChromeDriver驱动下载(Chrome版本:132.0.6834.84)win64
- Google Chrome浏览器ChromeDriver驱动下载(Chrome版本:132.0.6834.84)win32
- 从0到1搭建推荐系统 - 数据驱动的算法与架构设计(带数据集)
- 汇川H3U标准程序,程序有本体脉冲控制的三轴定位,有总线控制的汇川伺服定位,轴点动,回零,相对定位绝对定位,程序结构清晰,分模块控制,是工控者学习的好案例
- 从0到1搭建推荐系统 - 数据驱动的算法与架构设计(带数据集)
- S7-200Smart 恒压供水程序样例+485通讯样例
- 基于simulink三自由度汽车操纵模型(侧向,侧倾,横摆)带数据参数,有详细公式文档 具有特殊性,发出不 哦(高于或等于MATLAB 2016a版本的都可打开模型)
- C++编写,qt框架,windows串口调试助手,多线程运行,性能好,效率高,不丢数据,保证代码质量
- 从0到1搭建推荐系统 - 数据驱动的算法与架构设计(带数据集)
- . NET C# WPF图书管理系统源码 .net C# WPF图书管理系统源码 自己开发,纯源码 主要技术:C#、基于wpf开发、sql server数据库的增删改查 源码特点:代码完整规范,采
- 西门子Smart200和台达ⅤFD一M系列变频器通讯程序 Smart和三菱E700或D7O0变频器通讯程序,程序带注释,包括接线图纸,变频器参数设置,全都有,拿到即可以用,节约开发时间
- ofdm 水声通信 qpsk fpga
- COMSOL手性超材料文献模拟模型 计算左右旋圆偏振下的吸收、反射、透射率(材料参数未与文献一致 趋势吻合)
- 昆仑通态MCGS与3台力士乐VFC-x610变频器通讯程序 实现昆仑通态触摸屏与3台力士乐VFC-x610变频器通讯,程序稳定可靠 器件:昆仑通态TPC7062KD触摸屏,3台力士乐VFC-x610变
- Unity3d 基于UniStorm插件和xx天气API实现实时天气系统源码工程
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈