import pickle
import os
import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from torchvision import transforms
from torchsummary import summary
from hwdb import HWDB
from model import ConvNet
def valid(epoch, net, test_loarder, writer):
print("epoch %d 开始验证..." % epoch)
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loarder:
images, labels = images.cuda(), labels.cuda()
outputs = net(images)
# 取得分最高的那个类
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('correct number: ', correct)
print('totol number:', total)
acc = 100 * correct / total
print('第%d个epoch的识别准确率为:%d%%' % (epoch, acc))
writer.add_scalar('valid_acc', acc, global_step=epoch)
def train(epoch, net, criterion, optimizer, train_loader, writer, save_iter=100):
print("epoch %d 开始训练..." % epoch)
net.train()
sum_loss = 0.0
total = 0
correct = 0
# 数据读取
for i, (inputs, labels) in enumerate(train_loader):
# 梯度清零
optimizer.zero_grad()
if torch.cuda.is_available():
inputs = inputs.cuda()
labels = labels.cuda()
outputs = net(inputs)
loss = criterion(outputs, labels)
# 取得分最高的那个类
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
loss.backward()
optimizer.step()
# 每训练100个batch打印一次平均loss与acc
sum_loss += loss.item()
if (i + 1) % save_iter == 0:
batch_loss = sum_loss / save_iter
# 每跑完一次epoch测试一下准确率
acc = 100 * correct / total
print('epoch: %d, batch: %d loss: %.03f, acc: %.04f'
% (epoch, i + 1, batch_loss, acc))
writer.add_scalar('train_loss', batch_loss, global_step=i + len(train_loader) * epoch)
writer.add_scalar('train_acc', acc, global_step=i + len(train_loader) * epoch)
for name, layer in net.named_parameters():
writer.add_histogram(name + '_grad', layer.grad.cpu().data.numpy(),
global_step=i + len(train_loader) * epoch)
writer.add_histogram(name + '_data', layer.cpu().data.numpy(),
global_step=i + len(train_loader) * epoch)
total = 0
correct = 0
sum_loss = 0.0
if __name__ == "__main__":
# 超参数
epochs = 20
batch_size = 100
lr = 0.01
data_path = r'data'
log_path = r'logs/batch_{}_lr_{}'.format(batch_size, lr)
save_path = r'checkpoints/'
if not os.path.exists(save_path):
os.mkdir(save_path)
# 读取分类类别
with open('char_dict', 'rb') as f:
class_dict = pickle.load(f)
num_classes = len(class_dict)
# 读取数据
transform = transforms.Compose([
transforms.Resize((64, 64)),
transforms.ToTensor(),
])
dataset = HWDB(path=data_path, transform=transform)
print("训练集数据:", dataset.train_size)
print("测试集数据:", dataset.test_size)
trainloader, testloader = dataset.get_loader(batch_size)
net = ConvNet(num_classes)
if torch.cuda.is_available():
net = net.cuda()
# net.load_state_dict(torch.load('checkpoints/handwriting_iter_004.pth'))
print('网络结构:\n')
summary(net, input_size=(3, 64, 64), device='cuda')
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=lr)
writer = SummaryWriter(log_path)
for epoch in range(epochs):
train(epoch, net, criterion, optimizer, trainloader, writer=writer)
valid(epoch, net, testloader, writer=writer)
print("epoch%d 结束, 正在保存模型..." % epoch)
torch.save(net.state_dict(), save_path + 'handwriting_iter_%03d.pth' % epoch)
盈梓的博客
- 粉丝: 9785
- 资源: 2675
最新资源
- Apparat 餐厅,使用 C# 和 SlimDX 制作的开源游戏模拟引擎。.zip
- ANTLR C# 语法
- AppleScript 超薄版,一个超级精简的库,允许你从 mono 项目(从非 MonoMac 项目)执行 AppleScript。.zip
- IMDb 应用程序接口,C# 类,用于从 IMDb 网站获取数据。.zip
- ARSoft.MultiRulePolicyDaemon 反垃圾邮件守护程序
- 应用程序管理库,应用程序管理使您的应用程序生活更轻松。它将自动进行内存管理,处理和记录未处理的异常,分析您的函数,使您的应用程序成为单个实例,并提供 util 函数来获取系统信息。.zip
- Flask图书信息管理系统(python+mysql)源码+数据库(高分项目)
- 已经升级,市面上最多的,7,,8,9,10伺服口罩机通用程序架构,程序已经升级,程序高度模块化,可轻易拓展十几二十多个轴,已经很成功的运用到大量口罩机机器上面去了,plc是目前性价比最高的方案,采用信
- 基于Flask图书信息管理系统(python+mysql)源码+数据库(高分项目)
- 微网双层优化模型matlab 采用yalmip编写三个微网的分层优化模型,考虑电价的负荷响应,综合配电网运营商收益和用户购电成本,程序运行稳定
- 汇编与逆向工程课程相关工具包
- 异步机无感算法解析 提供推导文档,模型,代码…… md500
- 智能大棚农业监测系统的多传感器集成及深度学习分析与远程监控实现
- 智慧旅游期末考察:涵盖智慧旅游架构、数据统计与景区评论挖掘
- 无锡某大厂成熟Foc电机控制 代码,有原理图,用于很多电动车含高端电动自行车厂在用 直接可用,不是一般的普通代码可比的 有上位机用于调试和显示波形,直观调试 代码基于Stm32F030,国产很多
- 知识图谱多领域应用与构建实验课题探讨
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈