
This article is part of the "PyTorch Official Tutorials (Chinese Edition)" series; table of contents: [Translation] PyTorch Official Tutorials (Chinese Edition): Table of Contents.
This article is translated from the PyTorch official website; original page: TRANSFORMS.
Data does not always come in the final processed form required by machine learning algorithms. We use transforms to perform some manipulation of the data and make it suitable for training.
All TorchVision datasets have two parameters: transform, to modify the sample features, and target_transform, to modify the sample labels. Both accept callables containing the transformation logic (translator's note: a callable object behaves like a function and can be invoked as one). The torchvision.transforms module offers several commonly-used transforms out of the box.
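As a minimal sketch of the callable protocol mentioned above: any object implementing __call__ can be passed as transform or target_transform. The AddNoise class below is a hypothetical example for illustration, not part of torchvision.

```python
import torch

class AddNoise:
    """Hypothetical transform: adds Gaussian noise to a tensor.

    Illustrates that a transform is just a callable -- an object
    that can be invoked like a function.
    """
    def __init__(self, std=0.1):
        self.std = std

    def __call__(self, x):
        # randn_like draws noise with the same shape/dtype as x
        return x + torch.randn_like(x) * self.std

t = AddNoise(std=0.05)
noisy = t(torch.zeros(3))  # the instance is called like a function
```

Plain lambda functions work the same way, which is exactly what the Lambda transform below wraps.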
The FashionMNIST features are in PIL Image format, and the labels are integers. For training, we need the features as normalized tensors and the labels as one-hot encoded tensors. To make these transformations, we use ToTensor and Lambda.
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
The code above produces the following output:
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
100%|##########| 26421880/26421880 [00:02<00:00, 13192758.02it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
100%|##########| 29515/29515 [00:00<00:00, 328001.24it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
100%|##########| 4422102/4422102 [00:00<00:00, 6003434.24it/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
100%|##########| 5148/5148 [00:00<00:00, 41443909.77it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
ToTensor()
ToTensor converts a PIL image or NumPy ndarray into a FloatTensor and scales the image's pixel values into the range [0.0, 1.0].
Lambda Transforms
Lambda transforms apply a user-defined lambda function. Here, we define a function that turns the integer label into a one-hot encoded tensor: it first creates a zero tensor of size 10 (the number of label classes) and then calls scatter_, which assigns value=1 at the index given by the label y.
target_transform = Lambda(lambda y: torch.zeros(
10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
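The scatter_ call can be tried in isolation; y = 5 below is an arbitrary example label, not taken from the dataset:

```python
import torch

y = 5  # example class label

# Start from a zero vector of length 10, then write 1 in place at index y.
# scatter_ modifies the tensor in place (trailing underscore) and returns it.
one_hot = torch.zeros(10, dtype=torch.float).scatter_(
    dim=0, index=torch.tensor(y), value=1)

print(one_hot)  # tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.])
```

Exactly one position is set to 1, so the result sums to 1 -- the defining property of a one-hot vector.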
Further Reading
First published on Yunyun's site; read the original at: http://xiaoyunyun.net/index.php/archives/282.html