2021SC@SDUSC
以下是PaddleDetection训练部分的代码,这部分代码比较重要,相对来说也比较难,我在反复阅读后还是对他有所了解且在代码的关键部分和难理解的部分加上了备注:
train.py流程解析
从程序入口开始(if name == ‘main’:)
1.直接进入main函数
初始化训练参数:
- ①.parser = ArgsParser() #读取命令行传递参数,加载yaml文件参数
- ②.将参数整合在一起,检查参数配置是否正确
- ③.是否使用GPU加速
- ④.查看paddledet版本是否正确
- ⑤.进入run()函数
配置阶段
- a.系统变量配置、初始化、得到使用GPU数量等
- b.创建数据读取类
- c.创建网络结构类
- d.创建学习率类
- e.创建优化器类
- f.初始化模型权重,加载预训练模型、模型与优化器整合,
- g.是否是多卡,实例多模型并行训练
- 开启训练
- g.遍历数据,开始循环训练,根据时间戳计算一系列时间(剩余时间,平均训练时间)
- h.模型前向推理,反向传播,(多卡模型并行,loss合并)
- j.每个iter结束后输出日志
- k.定期打印log,定期保存模型和优化器参数,(eval开启:最优 and 定时)
- 直到迭代结束
#train.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os, sys
# add python path of PadleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
if parent_path not in sys.path:
sys.path.append(parent_path)
# ignore numba warning
import warnings
warnings.filterwarnings('ignore')
import random
import datetime
import time
import numpy as np
import paddle
from paddle.distributed import ParallelEnv
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight, save_model
import ppdet.utils.cli as cli
import ppdet.utils.check as check
import ppdet.utils.stats as stats
from ppdet.utils.logger import setup_logger
logger = setup_logger('train')
#运行配置参数解析函数
def parse_args():
parser = cli.ArgsParser()
parser.add_argument(
"--eval",
action='store_true',
default=False,
help="Whether to perform evaluation in train")
parser.add_argument(
"-r", "--resume", default=None, help="weights path for resume")
parser.add_argument(
"--slim_config",
default=None,
type=str,
help="Configuration file of slim method.")
parser.add_argument(
"--enable_ce",
type=bool,
default=False,
help="If set True, enable continuous evaluation job."
"This flag is only used for internal test.")
parser.add_argument(
"--fp16",
action='store_true',
default=False,
help="Enable mixed precision training.")
parser.add_argument(
"--fleet", action='store_true', default=False, help="Use fleet or not")
parser.add_argument(
"--use_vdl",
type=bool,
default=False,
help="whether to record the data to VisualDL.")
parser.add_argument(
'--vdl_log_dir',
type=str,
default="vdl_log_dir/scalar",
help='VisualDL logging directory for scalar.')
parser.add_argument(
'--save_prediction_only',
action='store_true',
default=False,
help='Whether to save the evaluation results only')
args = parser.parse_args()
return args
#run函数,detection套件执行的核心部分
def run(FLAGS, cfg):
# init fleet environment #初始化环境
if cfg.fleet:
init_fleet_env()
else:
# init parallel environment if nranks > 1 #是否采用模型并行(多卡)
init_parallel_env()
if FLAGS.enable_ce: #随机参数
set_random_seed(0)
# build trainer #建立模型
trainer = Trainer(cfg, mode='train')
# load weights #加载预训练模型参数
if FLAGS.resume is not None:
trainer.resume_weights(FLAGS.resume)
elif 'pretrain_weights' in cfg and cfg.pretrain_weights:
trainer.load_weights(cfg.pretrain_weights)
# training #执行训练
trainer.train(FLAGS.eval)
#主函数定义
def main():
FLAGS = parse_args() #加载运行参数
cfg = load_config(FLAGS.config) #加载yaml配置
cfg['fp16'] = FLAGS.fp16 #是否采用半精度
cfg['fleet'] = FLAGS.fleet
cfg['use_vdl'] = FLAGS.use_vdl #是否采用训练可视化
cfg['vdl_log_dir'] = FLAGS.vdl_log_dir #可视化文件路径
cfg['save_prediction_only'] = FLAGS.save_prediction_only #只保存预测结果
merge_config(FLAGS.opt) #合并配置
# 选择执行环境
place = paddle.set_device('gpu' if cfg.use_gpu else 'cpu')
# 是否采用同步BN
if 'norm_type' in cfg and cfg['norm_type'] == 'sync_bn' and not cfg.use_gpu:
cfg['norm_type'] = 'bn'
# slim配置
if FLAGS.slim_config:
cfg = build_slim_model(cfg, FLAGS.slim_config)
#检测配置文件
check.check_config(cfg)
check.check_gpu(cfg.use_gpu)
check.check_version()
#执行run函数
run(FLAGS, cfg)
#程序入口
if __name__ == "__main__":
main()#主函数入口
更多推荐
PaddleDetection代码解析之训练部分解析
发布评论