How to load CIFAR10 Dataset in PyTorch?

To load the dataset, use the torchvision.datasets.CIFAR10 class.

Syntax: torchvision.datasets.CIFAR10(root: Union[str, Path], train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)

Parameters:

  • root (str or pathlib.Path) – Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True.
  • train (bool, optional) – If True, creates dataset from training set, otherwise creates from test set.
  • transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
  • download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.


The following example loads the CIFAR-10 training set and displays one batch of images with their class labels:

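Python
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms

# ToTensor scales pixels to [0, 1]; Normalize then maps each channel to [-1, 1]
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# Download (if needed) and load the 50,000-image training split
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)

# Take one random batch of 4 images and their integer labels
images, labels = next(iter(trainloader))
# Tile the batch, undo the normalization (x / 2 + 0.5), and move channels last for matplotlib
grid = torchvision.utils.make_grid(images) / 2 + 0.5
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.title(' '.join(trainset.classes[label.item()] for label in labels))
plt.show()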

Output: a single row of four randomly sampled training images, with their class names as the plot title.


The CIFAR-10 dataset is a popular resource for training machine learning models, especially for image recognition. It consists of 60,000 32×32 color images in 10 classes, with 6,000 images per class. The dataset is divided into 50,000 training images and 10,000 test images. This article shows how to load the CIFAR-10 dataset in PyTorch.
