[Translation] PyTorch Official Tutorials (Chinese Edition): Transforms

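This article is part of the series "PyTorch Official Tutorials (Chinese Edition)". Table of contents: [Translation] PyTorch Official Tutorials (Chinese Edition): Table of Contents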

This article is translated from the official PyTorch website. Original page: TRANSFORMS

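Data does not always arrive in the form that machine learning algorithms require. We use transforms to manipulate the data and make it suitable for training.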

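All TorchVision datasets take two parameters: transform, which modifies the sample features, and target_transform, which modifies the labels. Both accept callables that contain the transformation logic (translator's note: a callable object behaves like a function and can be invoked like one). The torchvision.transforms module provides several commonly used transforms out of the box.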
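
To make the translator's note about callables concrete, here is a minimal sketch in plain Python. The names Scale and add_one are hypothetical and are not part of torchvision; they only illustrate that an object with a __call__ method and an ordinary function can both be invoked the same way, which is all that transform and target_transform need.

```python
class Scale:
    """Hypothetical transform: multiplies its input by a fixed factor."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        # Defining __call__ lets an instance be invoked like a function.
        return x * self.factor


def add_one(y):
    """Hypothetical target transform written as a plain function."""
    return y + 1


transform = Scale(0.5)       # a callable object
target_transform = add_one   # a plain function is callable too

print(callable(transform), callable(target_transform))
print(transform(4), target_transform(4))
```

A class is convenient when the transform needs configuration (here, the factor), while a function or lambda is enough for a stateless one.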

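The FashionMNIST samples store their features as PIL images and their labels as integers. For training, we need the features as normalized tensors and the labels as one-hot encoded tensors. To perform these transformations, we use ToTensor and Lambda.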

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

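The code above prints the following when the dataset is downloaded for the first time (progress bars omitted):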

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

ToTensor()

ToTensor converts a PIL image or NumPy ndarray into a FloatTensor and scales the image's pixel intensity values to the range [0.0, 1.0].

Lambda Transforms

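A Lambda transform applies a user-defined lambda function. Here, we define a lambda function that turns an integer label into a one-hot encoded tensor. It first creates a zero tensor of size 10 (the number of label classes in the dataset) and then calls scatter_, which writes the value 1 at the index given by the label y.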

target_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

Further Reading

torchvision.transforms API


First published on 芸芸小站.

Last edited: August 11, 2023. © All rights reserved; reposts must retain a link to the original.
