一些时候我们需要一些基本的网络结构,并且希望它是训练好的,里面的weights都已经训练的差不多,我们只需要继续我们自己的训练就好了,这中情况下,就要学习调用torchvision下训练好的网络结构了。
torchvision涵盖一下几种网络
Classification
- AlexNet
- VGG
- ResNet
- SqueezeNet
- DenseNet
- Inception v3
- GoogLeNet
- ShuffleNet v2
- MobileNet v2
- ResNeXt
- Wide ResNet
- MNASNet
Semantic Segmentation
Object Detection, Instance Segmentation and Person Keypoint Detection
Saving and loading models
保存和读取模型
save and load from pytorch website
基本功能
torch.save: Saves a serialized object to disk. This function uses Python’s pickle utility for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved using this function.
储存序列化对象到本地。 模型,tensors值或者字典都可以用这个函数保存。torch.load: Uses pickle’s unpickling facilities to deserialize pickled object files to memory. This function also facilitates the device to load the data into (see Saving & Loading Model Across Devices).
将用save保存的数据反序列化到内存中。- torch.nn.Module.load_state_dict: Loads a model’s parameter dictionary using a deserialized state_dict. For more information on state_dict, see What is a state_dict?.
这个函数用来读取储存好的字典。
Load Weights Demo
需求
首先你要有一个自己的网络结构,并且类似以上任意一个网络或者两个网络的结合,这里以resnet举例。
我们一般修改网络结构一般会去修改末尾的几层网络,以达到自己的训练目标。
Code
1 | from torchvison import models |
以上就是一个简单的例子了,同样可以把这种加载方式写成一个函数
1 | def load_pretrain_weights(pretrain_dict, selfnet_dict): |