ConvNeXtv2 pytorch预训练权重转paddle
- 直接上代码
- 使用方法
- 参考链接
直接上代码
import torch
import paddle.fluid as fluid
from collections import OrderedDict
import paddle
import argparse
import os
parser = argparse.ArgumentParser()
parser.add_argument('--torch_weight', type=str, default='convnextv2_atto_1k_224_ema.pt', help='torch_weight path')
parser.add_argument('--paddle_model', type=str, default='convnextv2_atto', help='paddle_model')
parser.add_argument('--paddle_weight_dir', type=str, default='paddle_model', help='paddle_model')
args = parser.parse_args()
if not os.path.exists(args.paddle_weight_dir):
os.makedirs(args.paddle_weight_dir)
def main(args):
# 读取torch的权重文件,我们这里默认读取的是.pt文件
torch_weight = torch.load(args.torch_weight, map_location=torch.device('cpu'))
weight = []
# 对于.pt文件就是这么读取权重
for torch_key in torch_weight['model'].keys():
weight.append([torch_key,torch_weight['model'][torch_key].detach().numpy()])
# print(torch_key)
# print(weight[0])
with fluid.dygraph.guard():
# 加载网络结构
if args.paddle_model == 'convnextv2_atto':
from convnextv2_paddle import convnextv2_atto
paddle_model = convnextv2_atto()
if args.paddle_model == 'convnextv2_femto':
from convnextv2_paddle import convnextv2_femto
paddle_model = convnextv2_femto()
if args.paddle_model == 'convnext_pico':
from convnextv2_paddle import convnext_pico
paddle_model = convnext_pico()
if args.paddle_model == 'convnextv2_nano':
from convnextv2_paddle import convnextv2_nano
paddle_model = convnextv2_nano()
if args.paddle_model == 'convnextv2_tiny':
from convnextv2_paddle import convnextv2_tiny
paddle_model = convnextv2_tiny()
if args.paddle_model == 'convnextv2_base':
from convnextv2_paddle import convnextv2_base
paddle_model = convnextv2_base()
if args.paddle_model == 'convnextv2_large':
from convnextv2_paddle import convnextv2_large
paddle_model = convnextv2_large()
if args.paddle_model == 'convnextv2_huge':
from convnextv2_paddle import convnextv2_huge
paddle_model = convnextv2_huge()
# print(paddle_model)
# 读取paddle网络结构的参数列表
paddle_weight = paddle_model.state_dict()
# # 检查是否paddle中的key在torch的dict中能找到
# for paddle_key in paddle_weight:
# if paddle_key in torch_weight['model'].keys():
# print("Oh Yeah")
# else:
# print("No!!!")
# 进行模型参数转换
new_weight_dict = OrderedDict()
# i = 0
for paddle_key in paddle_weight.keys():
# 首先要确保torch的权重里面有这个key,这样就可以避免DIY模型中一些小模块影响权重转换
if paddle_key in torch_weight['model'].keys():
# pytorch权重和paddle模型的权重为2维时需要转置,其余情况不需要
if len(torch_weight['model'][paddle_key].detach().numpy().shape) == 2:
# print(paddle_key)
new_weight_dict[paddle_key] = torch_weight['model'][paddle_key].detach().numpy().T
else:
new_weight_dict[paddle_key] = torch_weight['model'][paddle_key].detach().numpy()
# i += 1
paddle_model.set_dict(new_weight_dict)
fluid.dygraph.save_dygraph(paddle_model.state_dict(),os.path.join(args.paddle_weight_dir,args.paddle_model))
print('Paddle version: ',paddle.__version__)
print('Torch version: ',torch.__version__)
print(f"You have converted {args.torch_weight} to {os.path.join(args.paddle_weight_dir,args.paddle_model)}.pdparams")
if __name__ == '__main__':
main(args)
# 验证载入的权重
# paddle_weight = paddle.load('conv_ne_xt_v2_0pdparams.pdparams')
# paddle_model2 = convnextv2_atto()
# paddle_model2.set_dict(paddle_weight)
# for i in range(100):
# print(paddle_model.parameters()[i]==paddle_model2.parameters()[i])
使用方法
本代码运行的环境为:
Paddle version: 2.2.2
Torch version: 2.0.0+cpu
考虑到paddle社区一些最新的模型没有预训练权重,结合网上查的资料自己动手写了个pytorch预训练权重转paddle权重的代码。目前在ConvNextv2上进行了实践,使用还是很方便的,只需要导入paddle模型和对应的torch模型的预训练权重。这里convnextv2的paddle代码是根据pytorch代码对齐转换的,所以权重文件中的key都是相同的。之所以要保证key相同,是因为实践过程中,本人发现paddle模型的权重dict中元素的顺序与torch的权重文件并不一致,直接使用有序字典导入会一个都导不进去。然后paddle模型权重中的2维数据,即shape像<160,60>这种,在torch的权重文件中对应的shape应该是<60,160>。因此,代码中加入了判断,来确保权重shape的匹配。
参考链接
paddle复现pytorch踩坑(十一):转换pytorch预训练模型
更多推荐
ConvNeXtv2 pytorch预训练权重转paddle
发布评论