DataLoader and Datasets

PyTorch provides helper classes and functions to load, shuffle, and augment data. In this section we will learn more about them.

Data loading in PyTorch can be separated into two parts:

  • The data must be wrapped in a Dataset subclass that overrides the methods __getitem__ and __len__. Note that at this point the data is not loaded into memory; PyTorch only loads what is needed, when it is needed.

  • A DataLoader then actually reads the data and puts it into memory, batch by batch (see the minimal sketch after this list).
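
To make these two parts concrete, here is a minimal sketch using synthetic data (the SquaresDataset name and its contents are made up for illustration; the real driverless-car dataset follows below):

import torch
from torch.utils.data import Dataset, DataLoader

class SquaresDataset(Dataset):
    # Part 1: a Dataset subclass that knows how to fetch one item
    def __init__(self, n=100):
        self.values = list(range(n))

    def __getitem__(self, index):
        x = torch.tensor(float(self.values[index]))
        return x, x ** 2          # (sample, label)

    def __len__(self):
        return len(self.values)

# Part 2: a DataLoader that batches and shuffles the items
loader = DataLoader(SquaresDataset(), batch_size=4, shuffle=True)
xs, ys = next(iter(loader))       # two tensors of shape [4]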

The example shown here loads the data for our driverless-car demo.

Dataset parent class

So let's create a class that inherits from the Dataset class. Here we provide methods to fetch individual items and to report the number of items, but we never load the whole dataset into memory.

import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import numpy as np
import os
from PIL import Image
import matplotlib.pyplot as plt
FOLDER_DATASET = "./Track_1_Wheel_Test/"
plt.ion()

class DriveData(Dataset):

    def __init__(self, folder_dataset, transform=None):
        self.transform = transform
        # Keep the image paths and labels as instance attributes
        # (class-level lists would be shared between instances)
        self.__xs = []
        self.__ys = []
        # Open the text file that lists the whole training data
        with open(folder_dataset + "data.txt") as f:
            for line in f:
                # Image path
                self.__xs.append(folder_dataset + line.split()[0])
                # Steering-wheel angle label
                self.__ys.append(np.float32(line.split()[1]))

    # Override to let PyTorch fetch any image in the dataset by index
    def __getitem__(self, index):
        img = Image.open(self.__xs[index])
        img = img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        # Convert image and label to torch tensors
        img = torch.from_numpy(np.asarray(img))
        label = torch.from_numpy(np.asarray(self.__ys[index]).reshape([1,1]))
        return img, label

    # Override to report the dataset size to PyTorch
    def __len__(self):
        return len(self.__xs)

Instantiating the dataset and passing it to the DataLoader

dset_train = DriveData(FOLDER_DATASET)
train_loader = DataLoader(dset_train, batch_size=10, shuffle=True, num_workers=1)

Now PyTorch will take care of shuffling and loading your data (using worker processes) for you.

# Get a batch of training data
imgs, steering_angle = next(iter(train_loader))
# Without a transform the images stay as [batch, height, width, channels]
# uint8 tensors, so they can be displayed directly with imshow
print('Batch shape:', imgs.numpy().shape)
plt.imshow(imgs.numpy()[0, :, :, :])
plt.show()
plt.imshow(imgs.numpy()[-1, :, :, :])
plt.show()

# If you want to consume the batches in a for-loop:
# for batch_idx, (data, target) in enumerate(train_loader):
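
Expanding on that comment, a minimal version of such a loop could look like this sketch (the loop body is only a placeholder for an actual training step):

for batch_idx, (data, target) in enumerate(train_loader):
    # data:   [batch_size, H, W, 3] uint8 tensor (no transform applied)
    # target: [batch_size, 1, 1] float32 tensor of steering angles
    if batch_idx == 0:
        print(batch_idx, data.shape, target.shape)
    # ... forward pass, loss and backward pass would go here ...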

Transformations

PyTorch, through the torchvision package, also provides a mechanism to apply simple transformations to the images, such as resizing, cropping, or converting them to tensors. The pipeline is passed to the dataset through the transform argument and applied inside __getitem__.
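
As a sketch (the resize dimensions below are arbitrary and not part of the demo), a torchvision transforms.Compose pipeline could be built and handed to DriveData like this. Note that ToTensor returns a [channels, height, width] float tensor with values in [0, 1], so images transformed this way would need to be permuted back to [height, width, channels] before being shown with imshow.

from torchvision import transforms

# A sketch of a simple transform pipeline (sizes chosen arbitrarily)
transformations = transforms.Compose([
    transforms.Resize((128, 128)),   # resize every image to 128x128
    transforms.ToTensor(),           # PIL image -> [C, H, W] float tensor in [0, 1]
])

dset_train = DriveData(FOLDER_DATASET, transform=transformations)
train_loader = DataLoader(dset_train, batch_size=10, shuffle=True, num_workers=1)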
