SSS

Be honest with yourself.


  • Home

  • About

  • Tags

  • Categories

  • Archives

pytorch load weight

Posted on 2019-08-28 In deep learning

一些时候我们需要一些基本的网络结构,并且希望它是训练好的,里面的weights都已经训练的差不多,我们只需要继续我们自己的训练就好了,这中情况下,就要学习调用torchvision下训练好的网络结构了。

torchvision.models相关信息

torchvision涵盖一下几种网络

Classification

  • AlexNet
  • VGG
  • ResNet
  • SqueezeNet
  • DenseNet
  • Inception v3
  • GoogLeNet
  • ShuffleNet v2
  • MobileNet v2
  • ResNeXt
  • Wide ResNet
  • MNASNet

Semantic Segmentation

  • FCN ResNet101
  • DeepLabV3 ResNet101

Object Detection, Instance Segmentation and Person Keypoint Detection

  • Faster R-CNN ResNet-50 FPN
  • Mask R-CNN ResNet-50 FPN

Saving and loading models
保存和读取模型

save and load from pytorch website

基本功能

  1. torch.save: Saves a serialized object to disk. This function uses Python’s pickle utility for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved using this function.
    储存序列化对象到本地。 模型,tensors值或者字典都可以用这个函数保存。

  2. torch.load: Uses pickle’s unpickling facilities to deserialize pickled object files to memory. This function also facilitates the device to load the data into (see Saving & Loading Model Across Devices).
    将用save保存的数据反序列化到内存中。

  3. torch.nn.Module.load_state_dict: Loads a model’s parameter dictionary using a deserialized state_dict. For more information on state_dict, see What is a state_dict?.
    这个函数用来读取储存好的字典。

Load Weights Demo

需求

首先你要有一个自己的网络结构,并且类似以上任意一个网络或者两个网络的结合,这里以resnet举例。

我们一般修改网络结构一般会去修改末尾的几层网络,以达到自己的训练目标。

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from torchvison import models
# 首先提取torchvision.models下的网络。
resnet = models.resnet50(pretrain=True)
# 网络的权重weights是储存在一个python dictionary字典里的
weight_state_dict = resnet.state_dict()
# 此时我们就获取了这个含有weights的字典

#假设我们自己的网络是
resnet_self = resnet50_self()
# 同样的我们获取自己网络的weights字典
weight_self_state_dict = resnet_self.state_dict()

# 这是我们就可以根据字典的keys来匹配对应的weights赋值了
for key in weight_state_dict.keys():
print(key) # 检查key的名称
# 这里检查torchvision.models下权重的key有没有存在于自己网络里
if key in weight_self_state_dict.keys():
weight_self_state_dict[key] = weight_state_dict[key]
print('load weight finished')

# load对应的权重字典
resnet.load_state_dict(weight_self_state_dict)

以上就是一个简单的例子了,同样可以把这种加载方式写成一个函数

1
2
3
4
5
6
7
8
9
def load_pretrain_weights(pretrain_dict, selfnet_dict):
for key in pretrain_dict.keys():
print(key) # 检查key的名称
# 这里检查torchvision.models下权重的key有没有存在于自己网络里
if key in selfnet_dict.keys():
selfnet_dict[key] = pretrain_dict[key]
print('load weight finished')

return selfnet_dict
# pytorch
数学拉丁字母符号读法,写法
  • Table of Contents
  • Overview
haoyu

haoyu

Make it come true.
20 posts
6 categories
10 tags
GitHub
  1. 1. torchvision涵盖一下几种网络
    1. 1.1. Classification
    2. 1.2. Semantic Segmentation
    3. 1.3. Object Detection, Instance Segmentation and Person Keypoint Detection
  2. 2. Saving and loading models 保存和读取模型
    1. 2.1. 基本功能
  3. 3. Load Weights Demo
    1. 3.1. 需求
    2. 3.2. Code
© 2018 – 2019 Haoyu
Powered by Hexo v3.7.0
|
Theme – NexT.Gemini v7.3.0