How to Load the CIFAR-10 Dataset in PyTorch?
To load the dataset, use the torchvision.datasets.CIFAR10 class.
Syntax: torchvision.datasets.CIFAR10(root: Union[str, Path], train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)
Parameters:
- root (str or pathlib.Path) – Root directory of the dataset, where the directory cifar-10-batches-py exists or will be saved if download is set to True.
- train (bool, optional) – If True, creates the dataset from the training set; otherwise, creates it from the test set.
- transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version, e.g., transforms.RandomCrop.
- target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
- download (bool, optional) – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
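To make target_transform concrete, here is a minimal sketch of a custom label mapping that collapses the 10 classes into a binary animal-vs-vehicle task. The function name animal_vs_vehicle and the binary task itself are hypothetical illustrations, not part of torchvision. The class list is hard-coded (assumed to match the dataset's classes attribute) so the snippet runs without downloading anything.

```python
# Hypothetical target_transform, not a torchvision API: 1 = animal, 0 = vehicle.
# Class names in label order, assumed to match CIFAR10(...).classes.
CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']
VEHICLES = {'airplane', 'automobile', 'ship', 'truck'}


def animal_vs_vehicle(label):
    """Map an integer CIFAR-10 label (0-9) to 1 for animals, 0 for vehicles."""
    return 0 if CIFAR10_CLASSES[label] in VEHICLES else 1


# Wiring it into the dataset (needs torchvision and network access, so left as a comment):
# testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,
#                                        target_transform=animal_vs_vehicle)

print([animal_vs_vehicle(i) for i in range(10)])  # [0, 0, 1, 1, 1, 1, 1, 1, 0, 0]
```

Returning a plain int keeps the labels compatible with the DataLoader's default batching, which stacks them into a 1-D tensor just like the original labels.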
The following example loads the CIFAR-10 training set and displays one batch of images with their labels:
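import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# ToTensor() scales pixels to [0, 1]; Normalize maps each RGB channel to [-1, 1]
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# Load the 50,000-image training split, downloading it into ./data on first run
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
# Serve shuffled mini-batches of 4 images
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)
# One batch: images has shape (4, 3, 32, 32), labels has shape (4,)
images, labels = next(iter(trainloader))
# Tile the batch, move channels last (CHW -> HWC) and undo Normalize for display
grid = torchvision.utils.make_grid(images).permute(1, 2, 0) / 2 + 0.5
plt.imshow(grid.numpy())
plt.title(' '.join(trainset.classes[label.item()] for label in labels))
plt.show()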
Output: a row of four randomly sampled training images, titled with their class names.
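The / 2 + 0.5 and .permute(1, 2, 0) in the display line undo two things: the normalization applied for the model, and PyTorch's channels-first layout, which matplotlib does not expect. The pure-Python sketch below mirrors both steps. The helper names normalize, unnormalize and chw_to_hwc are hypothetical, not torchvision functions; normalize follows the per-channel formula that transforms.Normalize documents, (x - mean) / std.

```python
# Hypothetical helpers illustrating what the display line undoes; not torchvision APIs.
MEAN = (0.5, 0.5, 0.5)  # same values passed to transforms.Normalize above
STD = (0.5, 0.5, 0.5)


def normalize(pixel):
    """Per-channel (x - mean) / std on an (R, G, B) tuple already scaled to [0, 1]."""
    return tuple((x - m) / s for x, m, s in zip(pixel, MEAN, STD))


def unnormalize(pixel):
    """Inverse of normalize(): y * std + mean, which is y / 2 + 0.5 for these values."""
    return tuple(y * s + m for y, m, s in zip(pixel, MEAN, STD))


def chw_to_hwc(img):
    """Reorder a nested [C][H][W] list to [H][W][C], like tensor.permute(1, 2, 0)."""
    channels, height, width = len(img), len(img[0]), len(img[0][0])
    return [[[img[c][h][w] for c in range(channels)] for w in range(width)]
            for h in range(height)]


rgb = (0 / 255, 128 / 255, 255 / 255)  # an 8-bit pixel after ToTensor() scaling
norm = normalize(rgb)
print([round(v, 3) for v in norm])                 # [-1.0, 0.004, 1.0]
print([round(v, 3) for v in unnormalize(norm)])    # [0.0, 0.502, 1.0]
print(chw_to_hwc([[[1, 2]], [[3, 4]], [[5, 6]]]))  # [[[1, 3, 5], [2, 4, 6]]]
```

With mean and std both 0.5, the general inverse y * std + mean reduces to y / 2 + 0.5, which is why the display code can use that shorter form. If you change the Normalize values, the display line has to change with them.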
The CIFAR-10 dataset is a popular resource for training machine learning models, especially for image recognition. It consists of 60,000 32×32 color images in 10 classes, with 6,000 images per class, divided into 50,000 training images and 10,000 test images. This article shows how to load the CIFAR-10 dataset in PyTorch.
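The split sizes follow from simple per-class arithmetic: each class contributes 5,000 training images and 1,000 test images. The sketch below hard-codes the class names in label order so it runs offline. This is an assumption about the order; with the dataset loaded, trainset.classes provides the list directly.

```python
# Class names in label order, assumed to match CIFAR10(...).classes
CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']

TRAIN_PER_CLASS = 5_000  # training images per class
TEST_PER_CLASS = 1_000   # test images per class (6,000 - 5,000)

num_classes = len(CIFAR10_CLASSES)
train_total = num_classes * TRAIN_PER_CLASS
test_total = num_classes * TEST_PER_CLASS

print(num_classes, train_total, test_total, train_total + test_total)  # 10 50000 10000 60000
print(CIFAR10_CLASSES[3])  # cat  (integer label 3 maps to this name)
```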