MindSpore从大型模型中导出模块以及权重处理
introduction
在学习Mindformers上的大模型时我们总会出现一些需求:
- 我们可能只需要其中的部分模块。
- 我们不想要原始封装方式中的那么多冗余的函数之类。
- 我们自己构建大模型但一些模块我们并不想改变,或者说我们想要将多个大模型中的模块进行结合。
那么我们就需要进行模型权重提取。
MindSpore对权重的操作
查看哪些参数是可训练
执行如下代码即可:
1 2 3 4 5
| import Transformer model = Transformer() print(model.trainable_params())
Parameter (name=visual.visual.class_embedding, shape=(768,), dtype=Float32, requires_grad=True)
|
根据上述打印结果我们可以发现每个可训练权重由如下几个部分组成:
- name
- shape
- dtype
- requires_grad
因此我们可以通过如下方式对权重数据进行查看:
1 2 3 4 5
| for param in net.trainable_params(): print(param.name) print(param.shape) print(param.dtype) print(param.requires_grad)
|
将权重冻结
- 冻结某些参数,例如第一层卷积层的权重
1 2 3
| for param in net.trainable_params(): if param.name == "conv1.weight": param.requires_grad = False
|
- 冻结所有参数,例如第一层卷积层的权重
1 2
| for param in net.trainable_params(): param.requires_grad = False
|
储存权重文件
MindSpore权重文件的存储遵从以下原则,其本身必须是nn.cell封装好的,也就是说要有construct函数等,而我们调用的大部分网络其实已经封装完毕正常存贮即可,而Parameter参数由于只是一个可学习的参数没有用nn.cell封装因此要是单独存贮必须要改变其存储方式,使用方法3即可。
- 正常储存:
1 2
| from mindspore.train.serialization import save_checkpoint save_checkpoint(visual_t, "clip_visual.ckpt")
|
- 改变权重名称储存:(根据上面的查看结果我们其实可以发现权重文件本质就是一个字典因此只需新构建一个字典即可)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
| import mindspore as ms import mindspore.nn as nn
class MyModel(nn.Cell): def __init__(self): super(MyModel, self).__init__() self.fc1 = nn.Dense(10, 20) self.fc2 = nn.Dense(20, 30)
model = MyModel()
new_state_dict = {} for param in model.trainable_params(): name = param.name if 'fc1' in name: new_name = name.replace('fc1', 'new_fc1') else: new_name = name new_state_dict[new_name] = param
ms.save_checkpoint(new_state_dict, 'renamed_weights.ckpt')
|
- Parameter权重的保存:
1 2 3 4 5
| import mindspore as ms import numpy as np
weight_np = model.weight.asnumpy() np.save('weights.npy', weight_np)
|
加载权重到模型中
- 正常加载:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
| import mindspore.nn as nn from mindspore import load_checkpoint, load_param_into_net
class YourModel(nn.Cell): def __init__(self): super(YourModel, self).__init__()
def construct(self, x): pass
model = YourModel()
param_dict = load_checkpoint("your_model.ckpt")
load_param_into_net(model, param_dict)
|
加载完成会返回两个list
- param_not_load (List),网络中没有被加载的参数。
- ckpt_not_load (List),checkpoint文件中没有被加载的参数。
也就是说如果正常的话应该是两个空list
- 但是我们在提取模块权重并将其加载到新网络中时很容易出现权重名称不对加载失败的情况。这个时候就结合上面提到的直接对权重名称进行修改即可正常使用。
- Parameter存储的npy文件的加载
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
| import mindspore as ms import numpy as np
class MyModel(ms.nn.Cell): def __init__(self): super(MyModel, self).__init__() self.weight = ms.Parameter(ms.Tensor([0.0, 0.0, 0.0]))
model = MyModel()
weight_np = np.load('weights.npy')
model.weight.set_data(ms.Tensor(weight_np))
print(model.weight)
|