
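This article is part of the "PyTorch Official Tutorials, Chinese Edition" series. Table of contents: [Translation] PyTorch Official Tutorials, Chinese Edition: Table of Contents
It is translated from the official PyTorch website; the original page is: Save and Load the Model.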
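Save and Load the Model
This article shows how to persist model state by saving and loading a model, and how to run the model for prediction.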
import torch
import torchvision.models as models
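Saving and Loading Model Weights
PyTorch models store their learned parameters in an internal state dictionary called state_dict. These parameters can be persisted with the torch.save method: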
model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')
The code above prints:
Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /var/lib/jenkins/.cache/torch/hub/checkpoints/vgg16-397923af.pth
0%| | 0.00/528M [00:00<?, ?B/s]
4%|3 | 18.7M/528M [00:00<00:02, 196MB/s]
8%|7 | 39.7M/528M [00:00<00:02, 210MB/s]
12%|#1 | 60.8M/528M [00:00<00:02, 215MB/s]
15%|#5 | 81.7M/528M [00:00<00:02, 217MB/s]
20%|#9 | 103M/528M [00:00<00:02, 219MB/s]
25%|##4 | 131M/528M [00:00<00:01, 246MB/s]
31%|### | 164M/528M [00:00<00:01, 276MB/s]
37%|###7 | 196M/528M [00:00<00:01, 296MB/s]
42%|####2 | 224M/528M [00:00<00:01, 274MB/s]
47%|####7 | 251M/528M [00:01<00:01, 255MB/s]
52%|#####2 | 275M/528M [00:01<00:01, 244MB/s]
57%|#####6 | 299M/528M [00:01<00:01, 231MB/s]
61%|###### | 321M/528M [00:01<00:00, 228MB/s]
66%|######5 | 346M/528M [00:01<00:00, 237MB/s]
72%|#######1 | 378M/528M [00:01<00:00, 264MB/s]
77%|#######6 | 406M/528M [00:01<00:00, 273MB/s]
82%|########1 | 432M/528M [00:01<00:00, 256MB/s]
87%|########6 | 457M/528M [00:01<00:00, 245MB/s]
91%|#########1| 481M/528M [00:02<00:00, 220MB/s]
95%|#########5| 502M/528M [00:02<00:00, 220MB/s]
99%|#########9| 524M/528M [00:02<00:00, 217MB/s]
100%|##########| 528M/528M [00:02<00:00, 240MB/s]
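To make the contents of state_dict more concrete, here is a minimal sketch that lists the entries of a state dictionary. The small nn.Sequential model and the name tiny_model are assumptions made for illustration only, chosen so the example runs without downloading the VGG16 weights; they are not part of the original tutorial.

```python
import torch
from torch import nn

# Hypothetical tiny model (not from the tutorial), used so that the
# example runs without downloading any pretrained weights.
tiny_model = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Linear(8, 2),
)

# state_dict() returns an ordered mapping from parameter names to tensors.
# Layers without parameters (here: ReLU) contribute no entries.
state = tiny_model.state_dict()
shapes = {name: tuple(tensor.shape) for name, tensor in state.items()}

for name, shape in shapes.items():
    print(name, shape)
```

Each key combines the submodule name with the parameter name, which is why a matching model instance is needed when the weights are loaded back.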
To load model weights, first create an instance of the same model, then load the parameters with the load_state_dict() method.
model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))  # weights_only=True limits unpickling to tensors and plain containers
model.eval()
The code above prints:
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
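The save and load steps for weights can also be checked end to end on a small model. The sketch below is an illustration under stated assumptions: the helper make_model, the file name tiny_weights.pth, and the temporary directory are hypothetical and stand in for the VGG16 example so that nothing has to be downloaded. It also assumes a PyTorch version in which torch.load accepts the weights_only argument.

```python
import os
import tempfile

import torch
from torch import nn


def make_model():
    # Hypothetical small architecture standing in for VGG16.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


source_model = make_model()
restored_model = make_model()  # same structure, different random weights

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tiny_weights.pth")
    # Save only the learned parameters, not the model object itself.
    torch.save(source_model.state_dict(), path)
    # weights_only=True restricts unpickling to tensors and plain containers.
    state = torch.load(path, weights_only=True)

# load_state_dict reports any keys that did not line up between the
# saved dictionary and the freshly created model.
result = restored_model.load_state_dict(state)

weights_match = all(
    torch.equal(saved, loaded)
    for saved, loaded in zip(
        source_model.state_dict().values(),
        restored_model.state_dict().values(),
    )
)
print(result.missing_keys, result.unexpected_keys, weights_match)
```

If the structure of the new instance differed from the saved one, the key lists returned by load_state_dict would be the first place to look.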
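Note: be sure to call model.eval() before running inference. It switches layers such as dropout and batch normalization to evaluation mode; skipping it leads to inconsistent inference results.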
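The effect of evaluation mode is easiest to see on a single dropout layer. The following is a minimal sketch, not part of the original tutorial; the names layer and x and the input size of 1000 are arbitrary choices made for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

layer = nn.Dropout(p=0.5)
x = torch.ones(1000)

# Training mode: each call draws a new random mask, so two passes
# over the same input are expected to differ.
layer.train()
train_out_a = layer(x)
train_out_b = layer(x)

# Evaluation mode: dropout is disabled and the layer passes its
# input through unchanged, so repeated calls agree.
layer.eval()
eval_out_a = layer(x)
eval_out_b = layer(x)

print("train passes identical:", torch.equal(train_out_a, train_out_b))
print("eval passes identical:", torch.equal(eval_out_a, eval_out_b))
```

Calling model.eval() on a full model applies the same switch to every submodule, including the Dropout layers in the VGG16 classifier shown above.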
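Saving and Loading the Model Structure
When loading only the model weights, you need to know the model's class and create an instance of it first, because the class defines the structure of the network. To save the structure together with the weights, pass the model itself (rather than model.state_dict()) to the save function: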
torch.save(model, 'model.pth')
The model can then be loaded like this:
model = torch.load('model.pth', weights_only=False)  # a full model object is pickled; PyTorch 2.6+ defaults to weights_only=True
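Note: this approach uses Python's pickle module to serialize the model, so it depends on the actual class definition (translator's note: the class that defines the model structure must be available when the model is loaded). Because unpickling can execute arbitrary code, only load files from a source you trust.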
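A whole-model round trip can be sketched as follows. This is an illustration under assumptions rather than the tutorial's own code: the small nn.Sequential model, the file name tiny_model.pth, and the temporary directory are hypothetical. The model is built only from layers that ship with PyTorch, so their class definitions are importable at load time; a custom model class would have to be importable under the same module path when torch.load runs. The example also assumes a PyTorch version in which torch.load accepts weights_only.

```python
import os
import tempfile

import torch
from torch import nn

# Hypothetical small model built from built-in layers only, so that
# pickle can find every class definition again when loading.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tiny_model.pth")
    # Pickle the whole model object: structure and weights together.
    torch.save(model, path)
    # weights_only=False allows arbitrary objects to be unpickled,
    # so it should only be used with files from a trusted source.
    loaded = torch.load(path, weights_only=False)

loaded.eval()

x = torch.ones(1, 4)
with torch.no_grad():
    same_output = torch.equal(model(x), loaded(x))

print(type(loaded).__name__, same_output)
```

No model instance has to be created before loading here, which is the practical difference from the state_dict approach.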
First published on 芸芸小站; read the original at: http://xiaoyunyun.net/index.php/archives/325.html