Load a Computer Vision Dataset in PyTorch
Computer vision is a subfield of artificial intelligence that enables computers to understand images. In deep learning, convolutional neural networks (CNNs) are commonly used to process images. Building a good model requires a large number of images.
There are several ways to load a computer vision dataset in PyTorch, depending on the format of the dataset and the specific requirements of your project.
One popular method is to use the built-in dataset classes in torchvision.datasets. They provide a convenient way to load and preprocess common computer vision datasets, such as CIFAR-10 and ImageNet. For example, to load the CIFAR-10 dataset, you can use the following code:
Python3
# Import the necessary library
import torchvision.datasets as datasets

# Download the CIFAR-10 dataset
cifar10_train = datasets.CIFAR10(root="./data", train=True, download=True)
cifar10_test = datasets.CIFAR10(root="./data", train=False, download=True)
Output:
The code above downloads the CIFAR-10 dataset and saves it in the './data' directory.
Another method is to use the torch.utils.data.DataLoader class to load the data. This is especially useful when the data is on your local machine and you want to shuffle it, specify the batch size, and combine loading with data augmentation. It lets you customize the data loading order, batching, and single- or multi-process data loading.
Here we can use the transforms.Compose function from torchvision to randomly flip and rotate each image, convert it to a tensor, and normalize it.
Python3
# Import the necessary libraries
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Image transformation
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.35, 0.35, 0.406], [0.30, 0.34, 0.35])
])

# Load the dataset with the transformation
# (download=False assumes the files are already in ./data)
cifar10_train = datasets.CIFAR10(root="./data", train=True, download=False, transform=transform)
cifar10_test = datasets.CIFAR10(root="./data", train=False, download=False, transform=transform)

# Create loaders that yield batches of size 32
train_loader = DataLoader(cifar10_train, batch_size=32, shuffle=True, num_workers=2)
test_loader = DataLoader(cifar10_test, batch_size=32, shuffle=False, num_workers=2)
View the train and test data
Python3
# Train dataset
print(train_loader.dataset)

# Test dataset
print(test_loader.dataset)
Output:
Dataset CIFAR10
    Number of datapoints: 50000
    Root location: ./data
    Split: Train
    StandardTransform
Transform: Compose(
               RandomHorizontalFlip(p=0.5)
               RandomRotation(degrees=[-10.0, 10.0], interpolation=nearest, expand=False, fill=0)
               ToTensor()
               Normalize(mean=[0.35, 0.35, 0.406], std=[0.3, 0.34, 0.35])
           )
Dataset CIFAR10
    Number of datapoints: 10000
    Root location: ./data
    Split: Test
    StandardTransform
Transform: Compose(
               RandomHorizontalFlip(p=0.5)
               RandomRotation(degrees=[-10.0, 10.0], interpolation=nearest, expand=False, fill=0)
               ToTensor()
               Normalize(mean=[0.35, 0.35, 0.406], std=[0.3, 0.34, 0.35])
           )
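The mean and std values passed to Normalize are typically per-channel statistics of the training images. The text does not say how its values were obtained, so the following is only a sketch of one way to compute such statistics from a loader. It runs on random stand-in data rather than CIFAR-10, and channel_mean_std is a hypothetical helper name.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def channel_mean_std(loader, num_channels=3):
    """Per-channel mean and std over all pixels yielded by `loader`."""
    n_pixels = 0
    total = torch.zeros(num_channels, dtype=torch.float64)
    total_sq = torch.zeros(num_channels, dtype=torch.float64)
    for images, _ in loader:
        images = images.double()  # shape (B, C, H, W)
        n_pixels += images.shape[0] * images.shape[2] * images.shape[3]
        total += images.sum(dim=(0, 2, 3))
        total_sq += (images ** 2).sum(dim=(0, 2, 3))
    mean = total / n_pixels
    # Var[x] = E[x^2] - E[x]^2, computed per channel
    std = (total_sq / n_pixels - mean ** 2).sqrt()
    return mean, std


# Stand-in data: 64 random "images" in [0, 1], shaped like CIFAR-10 tensors
images = torch.rand(64, 3, 32, 32)
labels = torch.zeros(64, dtype=torch.long)
loader = DataLoader(TensorDataset(images, labels), batch_size=16)

mean, std = channel_mean_std(loader)
print(mean, std)
```

On a real dataset, the loader passed to such a helper would apply only ToTensor, with no Normalize and no random augmentation, so that the statistics describe the raw pixel values.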
Plot the images:
Python3
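# Import the plotting library
import matplotlib.pyplot as plt

# Take one batch from the training loader
inputs, labels = next(iter(train_loader))

# Define the class names
class_name = {0: 'airplane', 1: 'automobile', 2: 'bird', 3: 'cat', 4: 'deer',
              5: 'dog', 6: 'frog', 7: 'horse', 8: 'ship', 9: 'truck'}

# Plot the 32 images of the batch in a 4 x 8 grid
# (a moderate dpi keeps the 30 x 16 inch figure at a manageable pixel size)
plt.figure(figsize=(30, 16), dpi=100)
for i in range(32):
    plt.subplot(4, 8, i + 1)
    # Convert from (C, H, W) to (H, W, C) for matplotlib; the images are
    # normalized, so imshow clips values outside [0, 1]
    plt.imshow(inputs[i].numpy().transpose((1, 2, 0)))
    plt.axis('off')
    plt.title(class_name[int(labels[i])])
plt.show()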
Output: a 4 x 8 grid of the 32 training images in the batch, each titled with its class name.
Other libraries, such as albumentations, can also be used to preprocess and augment the data. The right choice depends on the format of your data and what you are trying to achieve.
You might also want to check the version of PyTorch you’re using, as well as the format of the dataset you’re trying to load. Some datasets are in a custom format, and you might need to write your own code to load them correctly.