保姆级教程:用PySKL的STGCN++训练Weizmann动作数据集,从视频到模型部署全流程
保姆级教程用PySKL的STGCN训练Weizmann动作数据集从视频到模型部署全流程当你第一次接触骨骼动作识别时可能会被各种复杂的模型和流程吓到。别担心这篇教程将手把手带你完成从视频数据准备到模型部署的全过程。我们选用PySKL框架中的STGCN模型因为它对小型数据集友好且在Weizmann这类经典动作数据集上表现优异。1. 环境准备与数据整理在开始之前确保你的开发环境满足以下基础要求Python 3.7PyTorch 1.8CUDA 11.1如果使用GPU加速至少16GB内存处理视频数据较吃内存1.1 安装PySKL及相关依赖pip install pyskl pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.8.0/index.html常见问题如果遇到mmcv安装失败可以尝试指定版本pip install mmcv-full1.4.01.2 数据集目录结构规范Weizmann数据集通常包含10类动作每类约10个视频。建议按以下结构组织Weizmann/ ├── videos/ │ ├── bend/ │ │ ├── daria_bend.avi │ │ └── ... │ ├── jack/ │ │ ├── denis_jack.avi │ │ └── ... │ └── ... ├── annotations/ └── splits/关键点视频文件名建议包含执行者名称如daria_bend.avi同一类别的视频放在同一子目录下避免文件名中包含空格或特殊字符2. 数据预处理与标注生成STGCN需要骨骼关键点数据作为输入我们需要先将视频转换为模型可识别的格式。2.1 生成骨骼关键点数据使用OpenPose或MMPose提取2D关键点python tools/data/skeleton/pose_estimation.py \ --video-root Weizmann/videos \ --output-root Weizmann/annotations \ --use-mmpose参数说明--video-root视频存放根目录--output-root输出标注文件目录--use-mmpose使用MMPose而非OpenPose推荐2.2 创建数据集划分文件在splits目录下创建train.txt和val.txt示例内容bend/daria_bend jack/denis_jack ...格式要求每行一个样本不带扩展名相对videos目录的路径建议按8:2比例划分训练集和验证集2.3 生成最终数据集文件运行以下命令生成JSON格式的标注文件python tools/data/skeleton/custom_2d_skeleton.py \ --data-path Weizmann/annotations \ --split-path Weizmann/splits \ --out-path Weizmann/annotations/weizmann_2d.pkl3. 模型配置与训练3.1 修改配置文件复制默认配置文件并修改关键参数from mmcv import Config cfg Config.fromfile(configs/skeleton/stgcn/stgcn_2d.py) # 修改关键参数 cfg.data.train.ann_file Weizmann/annotations/weizmann_2d.pkl cfg.data.val.ann_file Weizmann/annotations/weizmann_2d.pkl cfg.data.test.ann_file Weizmann/annotations/weizmann_2d.pkl cfg.model.cls_head.num_classes 10 # Weizmann有10个类别 cfg.total_epochs 50 # 小型数据集可适当增加epoch3.2 开始训练python tools/train.py configs/skeleton/stgcn/stgcn_2d.py \ --work-dir work_dirs/stgcn_weizmann \ --validate训练监控使用TensorBoard查看训练过程tensorboard --logdir work_dirs/stgcn_weizmann关键指标top1_acc验证集准确率3.3 小数据集调优技巧数据增强cfg.data.train.pipeline[3].scale_factor 0.5 # 缩放增强 cfg.data.train.pipeline[4].rot_factor 45 # 旋转增强学习率调整cfg.optimizer.lr 0.01 # 初始学习率 cfg.lr_config.step [20, 40] # 在第20和40epoch时降低学习率早停机制cfg.early_stop dict( monitorval_top1_acc, patience5, modemax )4. 模型测试与部署4.1 模型测试使用最佳模型进行测试python tools/test.py \ configs/skeleton/stgcn/stgcn_2d.py \ work_dirs/stgcn_weizmann/best_top1_acc_epoch_40.pth \ --eval top_k_accuracy4.2 模型导出为ONNXpython tools/deployment/pytorch2onnx.py \ configs/skeleton/stgcn/stgcn_2d.py \ work_dirs/stgcn_weizmann/best_top1_acc_epoch_40.pth \ --output-file stgcn_weizmann.onnx \ --verify4.3 简易部署方案使用Flask创建API服务from flask import Flask, request, jsonify import torch from pyskl.apis import init_recognizer app Flask(__name__) model init_recognizer(configs/skeleton/stgcn/stgcn_2d.py, work_dirs/stgcn_weizmann/best_top1_acc_epoch_40.pth) app.route(/predict, methods[POST]) def predict(): video request.files[video] # 预处理视频并提取关键点 result inference_recognizer(model, video) return jsonify({prediction: result}) if __name__ __main__: app.run(host0.0.0.0, port5000)5. 常见问题排查Q1训练时出现KeyError: total_frames解决方案 确保标注文件生成正确检查视频文件是否可正常读取import cv2 cap cv2.VideoCapture(Weizmann/videos/bend/daria_bend.avi) print(cap.isOpened()) # 应该返回TrueQ2模型准确率始终很低检查清单确认num_classes设置正确检查数据标注质量尝试减小学习率增加数据增强强度Q3GPU内存不足优化方案cfg.data.videos_per_gpu 4 # 减小batch size cfg.optimizer_config.grad_clip dict(max_norm40, norm_type2) # 梯度裁剪在实际项目中我发现最容易出错的环节是数据预处理阶段。特别是当视频格式不统一时会导致关键点提取失败。建议在正式训练前先用少量样本测试整个流程是否畅通。