Torchvision Dataset

Loading the demo ImageNet vision dataset in torchvision using PyTorch. The dataset is not downloaded automatically; you have to sign up on the ImageNet website to download it.

# import the torch and
# torchvision packages
import torch
import torchvision

# access a dataset in torchvision through the
# .datasets module followed by the dataset name;
# ImageNet expects the manually downloaded archives
# to be present in the root folder
imagenet_data = torchvision.datasets.ImageNet('path/to/imagenet_root/',
                                              split='train')


Code Explanation:

  • The procedure is almost the same as loading the audio data.
  • Here, torchvision is imported instead of torchaudio.
  • Access the dataset through the torchvision.datasets module, followed by the dataset class name (here, ImageNet).
  • Pass the root folder that contains the dataset. torchvision cannot download ImageNet for you: after registering on the ImageNet website, download the archives to your local system and pass their folder to the constructor. torchvision then prepares and loads the images for the requested split.
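Because the ImageNet archives must be downloaded by hand, it can help to check the root folder before constructing the dataset. The following is a minimal sketch, not part of torchvision: `missing_imagenet_archives` is a hypothetical helper, and the archive names are an assumption based on torchvision's ImageNet implementation, so verify them against your installed version. The check only matters before the first run, because torchvision extracts the archives the first time the dataset is created.

```python
from pathlib import Path

# Archive names torchvision's ImageNet class is assumed to look for in its
# root folder (verify against your installed torchvision version).
DEVKIT_ARCHIVE = "ILSVRC2012_devkit_t12.tar.gz"
SPLIT_ARCHIVES = {
    "train": "ILSVRC2012_img_train.tar",
    "val": "ILSVRC2012_img_val.tar",
}


def missing_imagenet_archives(root, split="train"):
    """Hypothetical helper: return the archives still missing from `root`."""
    if split not in SPLIT_ARCHIVES:
        raise ValueError(f"split must be one of {sorted(SPLIT_ARCHIVES)}")
    root = Path(root)
    needed = [DEVKIT_ARCHIVE, SPLIT_ARCHIVES[split]]
    # keep only the files that do not exist yet
    return [name for name in needed if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_imagenet_archives("path/to/imagenet_root/")
    if missing:
        print("Download these files first:", missing)
    else:
        print("Archives found; torchvision.datasets.ImageNet can be created.")
```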

To load your own image data, wrap the image folder in a dataset (for example, datasets.ImageFolder) and pass it to torch.utils.data.DataLoader(data, batch_size, shuffle), as mentioned above.
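The batch_size and shuffle arguments can be illustrated without PyTorch. The sketch below uses a hypothetical plain-Python helper, `simple_batches`, and assumes DataLoader's default behaviour of keeping a final, smaller batch (drop_last=False). It does not collate samples into tensors or use worker processes, so treat it only as a model of what the two arguments do.

```python
import random


def simple_batches(data, batch_size, shuffle=False, seed=None):
    """Hypothetical helper: yield lists of at most `batch_size` items.

    Models DataLoader's batch_size/shuffle semantics with drop_last=False;
    it does not build tensors or use worker processes.
    """
    indices = list(range(len(data)))
    if shuffle:
        # draw a random permutation of the sample indices
        random.Random(seed).shuffle(indices)
    # slice the (possibly shuffled) indices into consecutive batches
    for start in range(0, len(indices), batch_size):
        yield [data[i] for i in indices[start:start + batch_size]]


samples = list(range(10))
print([len(b) for b in simple_batches(samples, batch_size=4)])  # [4, 4, 2]
print(next(simple_batches(samples, batch_size=4)))  # [0, 1, 2, 3]
```

With shuffle=True and no fixed seed, every call produces a new order, which loosely corresponds to DataLoader reshuffling the data each time you iterate over it.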

# import the necessary modules from torch,
# torchvision and matplotlib
import torch
from torchvision import transforms, datasets
import matplotlib.pyplot as plt

# specify the image dataset folder
# (one subfolder per class)
data_dir = r'path to dataset\train'

# define the transformations: resize the shorter side
# to 255 pixels, crop the central 224x224 region and
# convert the image to a tensor
transform = transforms.Compose(
    [transforms.Resize(255),
     transforms.CenterCrop(224),
     transforms.ToTensor()])

# pass the image folder and the transform
# to datasets.ImageFolder
dataset = datasets.ImageFolder(data_dir,
                               transform=transform)

# use DataLoader to load the dataset in shuffled
# batches of 32 transformed images
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=32,
                                         shuffle=True)

# next(iter(...)) fetches the first batch of
# images and labels from the dataloader
images, labels = next(iter(dataloader))

# print the number of samples in this batch
print('Number of samples: ', len(images))
image = images[2][0]  # first channel of the 3rd sample

# visualize the image
plt.imshow(image, cmap='gray')
plt.show()

# print the size of the image
print("Image Size: ", image.size())

# print the labels of the batch
print(labels)


Output:

Image Size:  torch.Size([224, 224])
tensor([0, 0, 0, 1, 1, 1])
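In the output above, the labels are 0s and 1s and the selected image is 224 x 224. The plain-Python sketch below shows where those numbers come from, assuming ImageFolder's convention (each subfolder is a class, numbered in sorted order from 0) and Resize's shorter-side rule with the new longer side rounded down. Both functions are hypothetical helpers for illustration; on a real dataset, torchvision exposes the mapping as `dataset.class_to_idx`.

```python
import os


def imagefolder_class_to_idx(root):
    """Hypothetical helper mirroring ImageFolder's labelling convention:
    every subfolder of `root` is a class, numbered in sorted order."""
    classes = sorted(e.name for e in os.scandir(root) if e.is_dir())
    return {name: idx for idx, name in enumerate(classes)}


def resize_then_center_crop(width, height, resize=255, crop=224):
    """Hypothetical helper: (width, height) after Resize(resize), then
    after CenterCrop(crop).

    Resize with one int scales the shorter side to `resize` and keeps the
    aspect ratio (longer side rounded down). Both sides are then at least
    `resize` >= `crop`, so CenterCrop cuts out a crop x crop region.
    """
    short, long = min(width, height), max(width, height)
    new_long = int(resize * long / short)
    resized = (new_long, resize) if width >= height else (resize, new_long)
    return resized, (crop, crop)


print(resize_then_center_crop(500, 375))  # ((340, 255), (224, 224))
```

ToTensor then stores the cropped image as [channels, height, width], which is why images[2][0] in the example above is a 224 x 224 tensor.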

Loading Data in PyTorch

In this article, we will discuss how to load different kinds of data in PyTorch.

For demonstration purposes, PyTorch has three domain libraries, torchaudio, torchvision, and torchtext, each of which provides demo datasets. We can use these datasets to learn how to load audio, image, and text data in PyTorch.
