import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
from pathlib import Path
import visdom
from constants import settings
from constants.enum_keys import HK
from models.pose_estimation_model import PoseEstimationModel
from models.pafs_network import PAFsLoss
from aichallenger import AicNorm
from imgaug.augmentables.heatmaps import HeatmapsOnImage
class Trainer:
    """Trains the PAFs pose-estimation model on the AI Challenger keypoint
    dataset, periodically saving checkpoints and pushing preview images and
    heatmaps to a local visdom server.
    """

    def __init__(self, batch_size: int, is_unittest: bool = False):
        """Build the model, optimizer, loss and training DataLoader.

        Args:
            batch_size: minibatch size for the training DataLoader.
            is_unittest: when True, the loops in train()/run_epoch() exit
                after a single iteration so unit tests finish quickly.
        """
        # NOTE: set settings.num_workers to 0 when debugging in PyCharm;
        # the debugger won't work with DataLoader worker processes (multithreading).
        self.is_unittest = is_unittest
        self.epochs = 5        # total passes over the training set
        self.val_step = 500    # steps between checkpoint saves / visdom previews
        self.batch_size = batch_size
        self.vis = visdom.Visdom()  # requires a running `python -m visdom.server`
        # Dataset dict keys used throughout training.
        self.img_key = HK.NORM_IMAGE
        self.pcm_key = HK.PCM_ALL   # part confidence maps (ground truth)
        self.paf_key = HK.PAF_ALL   # part affinity fields (ground truth)
        if torch.cuda.is_available():
            print("GPU available.")
        else:
            print("GPU not available. running with CPU.")
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # NOTE(review): the model is never moved to self.device here, although the
        # inputs are (see run_epoch). Presumably PoseEstimationModel or its
        # load_ckpt() performs device placement internally — confirm; otherwise
        # training on CUDA would fail with a device mismatch.
        self.model_pose = PoseEstimationModel()
        self.model_optimizer = optim.Adam(self.model_pose.parameters(), lr=1e-3)
        self.loss = PAFsLoss()
        train_dataset = AicNorm(Path.home() / "AI_challenger_keypoint", is_train=True,
                                resize_img_size=(512, 512), heat_size=(64, 64))
        self.train_loader = DataLoader(train_dataset, self.batch_size, shuffle=True, num_workers=settings.num_workers,
                                       pin_memory=True, drop_last=True)
        # Validation is currently disabled — see the commented-out val() below.
        # test_dataset = AicNorm(Path.home() / "AI_challenger_keypoint", is_train=False,
        #                        resize_img_size=(512, 512), heat_size=(64, 64))
        # self.val_loader = DataLoader(test_dataset, self.batch_size, shuffle=False, num_workers=workers, pin_memory=True,
        #                              drop_last=True)
        # self.val_iter = iter(self.val_loader)

    def set_train(self):
        """Convert models to training mode.
        """
        self.model_pose.train()

    def set_eval(self):
        """Convert models to testing/evaluation mode.
        """
        self.model_pose.eval()

    def train(self):
        """Run the full training schedule: resume from the latest checkpoint,
        then run run_epoch() for self.epochs epochs (once if is_unittest)."""
        self.epoch = 0
        self.step = 0  # global step counter, advanced inside run_epoch()
        self.model_pose.load_ckpt()  # resume weights from the latest checkpoint
        for self.epoch in range(self.epochs):
            print("Epoch:{}".format(self.epoch))
            self.run_epoch()
            if self.is_unittest:
                break

    def run_epoch(self):
        """Train on every minibatch of self.train_loader once, advancing
        self.step, saving checkpoints every val_step steps and pushing
        preview images/heatmaps to visdom."""
        self.set_train()
        for batch_idx, inputs in enumerate(self.train_loader):
            # Move this minibatch's tensors to the training device.
            inputs[self.img_key] = inputs[self.img_key].to(self.device, dtype=torch.float32)
            inputs[self.pcm_key] = inputs[self.pcm_key].to(self.device, dtype=torch.float32)
            inputs[self.paf_key] = inputs[self.paf_key].to(self.device, dtype=torch.float32)
            loss, b1_out, b2_out = self.process_batch(inputs)
            # Standard optimization step.
            self.model_optimizer.zero_grad()
            loss.backward()
            self.model_optimizer.step()
            # Clear visdom environment at the very first step of a run.
            if self.step == 0:
                self.vis.close()
            # Periodically save a checkpoint (the val() call itself is disabled).
            if self.step % self.val_step == 0 and self.step != 0:
                # self.val()
                self.model_pose.save_ckpt()
            if self.step % 50 == 0:
                print("step {}; Loss {}".format(self.step, loss.item()))
            # Show training materials: visualize the first sample of the batch.
            if self.step % self.val_step == 0:
                pcm_CHW = b1_out[0].cpu().detach().numpy()  # predicted confidence maps, CHW
                paf_CHW = b2_out[0].cpu().detach().numpy()  # predicted affinity fields, CHW
                # [::-1, ...] reverses the channel axis — presumably BGR->RGB for
                # visdom display; confirm against AicNorm's image channel order.
                img_CHW = inputs[self.img_key][0].cpu().detach().numpy()[::-1, ...]
                # Collapse the channel axis with a max so each map fits one 2-D heatmap.
                pred_pcm_amax = np.amax(pcm_CHW, axis=0)  # HW
                gt_pcm_amax = np.amax(inputs[self.pcm_key][0].cpu().detach().numpy(), axis=0)
                pred_paf_amax = np.amax(paf_CHW, axis=0)
                gt_paf_amax = np.amax(inputs[self.paf_key][0].cpu().detach().numpy(), axis=0)
                # flipud: visdom.heatmap draws row 0 at the bottom, so flip for upright display.
                self.vis.image(img_CHW, win="Input", opts={'title': "Input"})
                self.vis.heatmap(np.flipud(pred_pcm_amax), win="Pred-PCM", opts={'title': "Pred-PCM"})
                self.vis.heatmap(np.flipud(gt_pcm_amax), win="GT-PCM", opts={'title': "GT-PCM"})
                self.vis.heatmap(np.flipud(pred_paf_amax), win="Pred-PAF", opts={'title': "Pred-PAF"})
                self.vis.heatmap(np.flipud(gt_paf_amax), win="GT-PAF", opts={'title': "GT-PAF"})
                self.vis.line(X=np.array([self.step]), Y=loss.cpu().detach().numpy()[np.newaxis], win='Loss', update='append')
            if self.is_unittest:
                break
            self.step += 1

    # def val(self, ):
    #     """Validate the model on a single minibatch
    #     """
    #     self.set_eval()
    #
    #     try:
    #         inputs = next(self.val_iter)
    #     except StopIteration:
    #         self.val_iter = iter(self.val_loader)
    #         inputs = next(self.val_iter)
    #
    #     inputs_gpu = {self.img_key: inputs[self.img_key].to(self.device, dtype=torch.float32),
    #                   self.pcm_key: inputs[self.pcm_key].to(self.device, dtype=torch.float32),
    #                   self.paf_key: inputs[self.paf_key].to(self.device, dtype=torch.float32)}
    #     with torch.no_grad():
    #         loss, b1_out, b2_out = self.process_batch(inputs_gpu)
    #     pred_pcm_amax = np.amax(b1_out[0].cpu().numpy(), axis=0)  # HW
    #     gt_pcm_amax = np.amax(inputs[self.pcm_key][0].cpu().numpy(), axis=0)
    #     pred_paf_amax = np.amax(b2_out[0].cpu().numpy(), axis=0)
    #     gt_paf_amax = np.amax(inputs[self.paf_key][0].cpu().numpy(), axis=0)
    #     # Image augmentation disabled due to pred phase
    #     self.vis.image(inputs[self.img_key][0].cpu().numpy()[::-1, ...], win="Input", opts={'title': "Input"})
    #     self.vis.heatmap(np.flipud(pred_pcm_amax), win="Pred-PCM", opts={'title': "Pred-PCM"})
    #     self.vis.heatmap(np.flipud(gt_pcm_amax), win="GT-PCM", opts={'title': "GT-PCM"})
    #     self.vis.heatmap(np.flipud(pred_paf_amax), win="Pred-PAF", opts={'title': "Pred-PAF"})
    #     self.vis.heatmap(np.flipud(gt_paf_amax), win="GT-PAF", opts={'title': "GT-PAF"})
    #     self.vis.line(X=np.array([self.step]), Y=loss.cpu().numpy()[np.newaxis], win='Loss', update='append')
    #     self.set_train()

    def process_batch(self, inputs):
        """Pass a minibatch through the network and generate images and losses.

        Args:
            inputs: dict containing at least img_key (network input image batch),
                pcm_key and paf_key ground-truth tensors, already on self.device.

        Returns:
            (loss, b1_out, b2_out): the scalar loss tensor plus the model's
            B1 (PCM) and B2 (PAF) outputs, used upstream for visualization.
        """
        res = self.model_pose(inputs[self.img_key])
        gt_pcm = inputs[self.pcm_key]  # ["heatmap"]: {"vis_or_not": NJHW, "visible": NJHW}
        # unsqueeze(1) adds a stage axis so the (N, 1, C, H, W) ground truth
        # broadcasts against the stacked (N, Stage, C, H, W) predictions below.
        gt_pcm = gt_pcm.unsqueeze(1)
        gt_paf = inputs[self.paf_key]
        gt_paf = gt_paf.unsqueeze(1)
        b1_stack = torch.stack(res[HK.B1_SUPERVISION], dim=1)  # Shape (N, Stage, C, H, W)
        b2_stack = torch.stack(res[HK.B2_SUPERVISION], dim=1)
        loss = self.loss(b1_stack, b2_stack, gt_pcm, gt_paf)
        return loss, res[HK.B1_OUT], res[HK.B2_OUT]
没有合适的资源?快使用搜索试试~ 我知道了~
温馨提示
基于深度学习pytorch框架的交通警察指挥手势识别项目源码+训练好的模型+数据集+项目操作说明.zip 识别8种中国交通警察指挥手势的Pytorch深度学习项目 带训练好的模型以及数据集 下载模型参数文件checkpoint和生成的骨架generated 放置在: ctpgr-pytorch/checkpoints ctpgr-pytorch/generated 下载交警手势数据集(必选) 交警手势数据集下载: 放置在: (用户文件夹)/PoliceGestureLong (用户文件夹)/AI_challenger_keypoint # 用户文件夹 在 Windows下是'C:\Users\(用户名)',在Linux下是 '/home/(用户名)' 安装Pytorch和其它依赖: # Python 3.8.5 conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch conda install ujson pip install visdom opencv-python imgaug 数据集和
资源推荐
资源详情
资源评论
收起资源包目录
基于深度学习pytorch框架的交通警察指挥手势识别项目源码+训练好的模型+数据集+项目操作说明.zip (33个子文件)
basic_tests
basic_tests.py 3KB
train
train_police_gesture_model.py 3KB
train_keypoint_model.py 7KB
pred
evaluation.py 4KB
human_keypoint_pred.py 2KB
play_gesture_results.py 4KB
play_keypoint_results.py 1KB
gesture_pred.py 2KB
prepare_skeleton_from_video.py 587B
docs
intro.gif 4.46MB
aichallenger
s1_resize.py 4KB
__init__.py 215B
s4_affinity_field.py 3KB
defines.py 297B
s3_gaussian.py 3KB
s2_augment.py 3KB
visual_debug.py 1KB
s5_norm.py 1KB
s0_native.py 3KB
pgdataset
s0_label.py 1KB
__init__.py 0B
s1_skeleton.py 4KB
s3_handcraft.py 3KB
s2_truncate.py 1KB
项目操作说明.md 1KB
ctpgr.py 3KB
models
gesture_recognition_model.py 2KB
pafs_resnet.py 1KB
pafs_network.py 7KB
pose_estimation_model.py 1KB
constants
enum_keys.py 2KB
settings.py 220B
keypoints.py 900B
共 33 条
- 1
onnx
- 粉丝: 1w+
- 资源: 5627
上传资源 快速赚钱
- 我的内容管理 展开
- 我的资源 快来上传第一个资源
- 我的收益 登录查看自己的收益
- 我的积分 登录查看自己的积分
- 我的C币 登录后查看C币余额
- 我的收藏
- 我的下载
- 下载帮助
最新资源
- 案例分析:研发人员绩效和薪酬管理的困境.doc
- 企业中薪酬管理存在的问题分析及对策.doc
- 员工年度薪酬收入结构分析报告.doc
- 薪酬分析报告.docx
- 西门子S7-1200控制四轴伺服程序案例: 1.内容涵盖伺服,步进点动,回原,相对定位,绝对定位,速度模式控制 特别适合学习伺服和步进的朋友们 PTO伺服轴脉冲定位控制+速度模式控制+扭矩模式; 2
- 企业公司薪酬保密协议.doc
- 薪酬保密制度 (1).docx
- 薪酬保密管理规定制度.doc
- 薪酬保密制度.docx
- 薪酬保密协议书.docx
- 薪酬保密承诺书.docx
- 薪酬管理制度.doc
- 员工工资薪酬保密协议.docx
- 员工工资保密暂行管理条例.docx
- 员工薪酬保密协议.doc
- 1Redis基础认识与安装.html
资源上传下载、课程学习等过程中有任何疑问或建议,欢迎提出宝贵意见哦~我们会及时处理!
点击此处反馈
安全验证
文档复制为VIP权益,开通VIP直接复制
信息提交成功
- 1
- 2
前往页