PyTorch provides two data primitives, torch.utils.data.DataLoader and torch.utils.data.Dataset, that allow you to use pre-loaded datasets as well as your own data. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset to enable easy access to the samples.
Loading a Dataset
Here we load the Fashion-MNIST dataset from TorchVision. The torchvision package consists of popular datasets, model architectures, and common image transformations for computer vision.
Iterating and Visualizing the Dataset
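A map-style Dataset can be indexed like a list, so dataset[index] returns one (sample, label) pair. The sketch below uses a random TensorDataset as a stand-in for Fashion-MNIST (an assumption made so that it runs without a download). The shapes match Fashion-MNIST images, but the pixel values are noise.

```python
import torch
from torch.utils.data import TensorDataset

# Stand-in for Fashion-MNIST (assumption): 100 random 1x28x28 "images"
# with integer labels in [0, 10).
images = torch.rand(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(images, labels)

# Index a single sample, list-style.
img, label = dataset[0]
print(img.shape, int(label))

# Draw a few random samples, as one would before plotting them.
for _ in range(3):
    sample_idx = torch.randint(len(dataset), size=(1,)).item()
    img, label = dataset[sample_idx]
    print(sample_idx, tuple(img.shape), int(label))
```

To visualize a sample, img.squeeze() drops the channel dimension and gives a 28x28 tensor that matplotlib's imshow can display with cmap="gray".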
Training with DataLoaders
While training a model, we use DataLoader to pass samples in “minibatches”, reshuffle the data at every epoch to reduce model overfitting, and use Python’s multiprocessing to speed up data retrieval.
To use DataLoader, we need to set the following parameters:
- dataset: the dataset from which to load the data
- batch_size: how many samples per batch to load
- shuffle: set to True to have the data reshuffled at every epoch (default: False)
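The three parameters above map directly onto DataLoader's constructor. This is a minimal sketch that again assumes a random TensorDataset in place of the real training data. num_workers is a further DataLoader argument, shown here because it controls the multiprocessing mentioned earlier.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in data (assumption): 256 random 1x28x28 tensors with labels 0-9.
training_data = TensorDataset(
    torch.rand(256, 1, 28, 28),
    torch.randint(0, 10, (256,)),
)

train_dataloader = DataLoader(
    dataset=training_data,  # where the samples come from
    batch_size=64,          # samples per batch
    shuffle=True,           # reshuffle at the start of every epoch
    num_workers=0,          # 0 = load in the main process; >0 uses worker processes
)

# 256 samples / 64 per batch = 4 batches per epoch.
print(len(train_dataloader))
```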
We have loaded the dataset into the DataLoader and can iterate through it as needed. Each iteration returns a batch of train_features and train_labels (containing batch_size=64 features and labels respectively). Because we specified shuffle=True, the data is reshuffled after we have iterated over all batches (for finer-grained control over the data loading order, see DataLoader's sampler argument).
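Pulling one batch and then running full epochs might look like the sketch below. The random stand-in data and the seed are assumptions made only so that the example is self-contained and repeatable.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)  # only to make the sketch repeatable

# Stand-in data (assumption): 256 random 1x28x28 tensors with labels 0-9.
training_data = TensorDataset(
    torch.rand(256, 1, 28, 28),
    torch.randint(0, 10, (256,)),
)
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)

# Pull a single batch from a fresh iterator.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")

# One full pass over the loader is one epoch. Each `for` loop creates a
# new iterator, which is when the reshuffling happens.
for epoch in range(2):
    seen = 0
    for features, labels in train_dataloader:
        seen += features.shape[0]
    print(f"epoch {epoch}: {seen} samples")
```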
iter(object[, sentinel])
Creates an iterator. When sentinel is omitted, the object argument must support iteration. next() returns the iterator's next item.
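A short illustration of both built-ins in plain Python. The list contents are made up, and the second form shows what the optional sentinel argument does.

```python
# Form 1: iter(object). The object must be iterable.
batches = ["batch-0", "batch-1", "batch-2"]
iterator = iter(batches)
first = next(iterator)    # "batch-0"
second = next(iterator)   # "batch-1"
third = next(iterator)    # "batch-2"

# The iterator is now exhausted. next() would raise StopIteration,
# unless a default value is supplied as the second argument.
exhausted = next(iterator, None)   # None

# Form 2: iter(callable, sentinel). The callable is invoked repeatedly
# until it returns the sentinel value, which is not included.
queue = [3, 2, 1, 0, 99]
drained = list(iter(lambda: queue.pop(0), 0))   # [3, 2, 1]
```

This is why next(iter(train_dataloader)) yields exactly one batch: iter() builds a fresh iterator over the loader and next() takes its first item.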
Normalization
Normalization is a common data pre-processing technique that is applied to scale or transform the data to make sure there’s an equal learning contribution from each feature.
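Two common ways to put features on a comparable scale are standardization (zero mean, unit variance per feature) and min-max scaling (each feature mapped to [0, 1]). The notes do not say which variant is meant, so the sketch below shows both on made-up values.

```python
import torch

# Two features on very different scales (made-up values):
# column 0 is in the thousands, column 1 is a fraction.
data = torch.tensor([
    [1000.0, 0.1],
    [2000.0, 0.2],
    [3000.0, 0.3],
    [4000.0, 0.4],
])

# Standardization: subtract the per-feature mean, divide by the
# per-feature standard deviation.
mean = data.mean(dim=0)
std = data.std(dim=0)
standardized = (data - mean) / std

# Min-max scaling: map each feature linearly onto [0, 1].
col_min = data.min(dim=0).values
col_max = data.max(dim=0).values
scaled = (data - col_min) / (col_max - col_min)

print(standardized)
print(scaled)
```

After either transformation the two columns hold nearly identical values, so neither feature dominates purely because of its units.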
Transforms
We use transforms to perform some manipulation of the data and make it suitable for training. TorchVision datasets accept two parameters: transform to modify the features and target_transform to modify the labels.
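To show where the two hooks are applied, here is a sketch of a small custom dataset that mimics them. TinyImageDataset, scale_to_unit, and one_hot are hypothetical names invented for this example, and the data is random. TorchVision's own datasets take the same two keyword arguments.

```python
import torch
from torch.utils.data import Dataset


class TinyImageDataset(Dataset):
    """Hypothetical dataset exposing transform / target_transform hooks."""

    def __init__(self, images, labels, transform=None, target_transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image, label = self.images[idx], self.labels[idx]
        if self.transform is not None:
            image = self.transform(image)          # modifies the features
        if self.target_transform is not None:
            label = self.target_transform(label)   # modifies the label
        return image, label


def scale_to_unit(image):
    """Feature transform: uint8 pixels -> floats in [0, 1]."""
    return image.float() / 255.0


def one_hot(label, num_classes=10):
    """Label transform: integer class -> one-hot float vector."""
    return torch.zeros(num_classes).scatter_(0, torch.tensor(label), value=1)


# Made-up data: four 28x28 uint8 images with integer labels.
images = torch.randint(0, 256, (4, 28, 28), dtype=torch.uint8)
labels = [3, 0, 9, 1]

dataset = TinyImageDataset(
    images, labels, transform=scale_to_unit, target_transform=one_hot
)
image, target = dataset[0]
print(image.dtype, target)
```

Both hooks run inside __getitem__, so the stored data stays untouched and the transforms are applied each time a sample is fetched.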
ToTensor converts a PIL image or NumPy ndarray into a FloatTensor and scales the image’s pixel intensity values to the range [0., 1.].