
Batching is the process of grouping data samples into small chunks (batches) so that training is efficient. Automatic batching is the default behavior of DataLoader. When batch_size is specified, the DataLoader automatically collates the individually fetched samples into batches, typically with the first dimension serving as the batch dimension.
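The collation step can be seen in isolation. The following is a minimal sketch, assuming PyTorch 1.11 or later, where default_collate is exposed under torch.utils.data. The three samples are made up for illustration; a real dataset would supply them.

```python
import torch
from torch.utils.data import default_collate

# Three individual samples, each an (image, label) pair as a dataset might return them.
# The values are illustrative only.
samples = [(torch.randn(3, 64, 64), label) for label in (0, 1, 2)]

# default_collate stacks the images along a new first (batch) dimension
# and converts the Python int labels into a single tensor.
images, labels = default_collate(samples)

print(images.shape)  # torch.Size([3, 3, 64, 64])
print(labels)        # tensor([0, 1, 2])
```

This is the function DataLoader falls back to when no custom collate_fn is passed, which is why the batch dimension appears first.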

How does DataLoader create mini-batches?

When you create a DataLoader, specify batch_size to have it produce mini-batches. During training, the DataLoader splits your dataset into mini-batches of the given size. Each batch contains multiple data points (e.g., images or text samples). On each iteration, the DataLoader returns one batch of data (input features and labels) to the training loop.

batch_size (int, optional) -> how many samples per batch to load (default: 1).

import torch
from torch.utils.data import DataLoader, TensorDataset

# Sample dummy image tensors: 1000 RGB images of 64x64 pixels
image_data = torch.randn(1000, 3, 64, 64)
# One integer class label (0-9) per image
labels = torch.randint(0, 10, (1000,))

# Pair each image with its label
dataset = TensorDataset(image_data, labels)

# Split into shuffled batches of 32 samples
batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# View every batch produced by one pass over the dataset
for batch_images, batch_labels in dataloader:
    print(f"Batch shape: {batch_images.shape}, Labels: {batch_labels}")


Batch shape: torch.Size([32, 3, 64, 64]), Labels: tensor([0, 3, 7, 1, 4, 2, 9, 7, 4, 7, 1, 8, 3, 4, 6, 3, 8, 7, 2, 8, 3, 2, 4, 9,
6, 3, 8, 0, 0, 1, 3, 0])
Batch shape: torch.Size([32, 3, 64, 64]), Labels: tensor([1, 1, 5, 5, 4, 6, 1, 1, 2, 3, 8, 3, 7, 0, 6, 3, 1, 7, 7, 9, 4, 0, 8, 0,
7, 4, 8, 1, 0, 6, 2, 5])
Batch shape: torch.Size([32, 3, 64, 64]), Labels: tensor([8, 2, 6, 2, 8, 2, 0, 0, 4, 4, 9, 4, 8, 5, 2, 7, 6, 5, 0, 5, 4, 3, 9, 8,
8, 8, 1, 7, 7, 1, 7, 9])
Batch shape: torch.Size([32, 3, 64, 64]), Labels: tensor([6, 6, 9, 3, 7, 2, 3, 6, 9, 4, 7, 8, 7, 7, 9, 4, 8, 3, 9, 9, 9, 9, 4, 2,
0, 0, 5, 4, 5, 7, 5, 7])
Batch shape: torch.Size([32, 3, 64, 64]), Labels: tensor([3, 0, 2, 3, 7, 8, 0, 6, 6, 9, 1, 4, 6, 2, 7, 9, 1, 2, 9, 8, 4, 6, 2, 3,
7, 3, 5, 3, 6, 7, 3, 1])
Batch shape: torch.Size([32, 3, 64, 64]), Labels: tensor([8, 8, 0, 4, 9, 1, 5, 4, 9, 4, 7, 5, 2, 2, 1, 6, 0, 3, 1, 1, 1, 8, 8, 5,
0, 6, 3, 3, 9, 3, 4, 7])
Batch shape: torch.Size([32, 3, 64, 64]), Labels: tensor([7, 4, 9, 6, 2, 1, 0, 0, 1, 1, 3, 3, 9, 3, 5, 7, 0, 2, 3, 5, 9, 2, 3, 0,
9, 9, 0, 1, 9, 0, 3, 2])

The output above shows that each batch contains 32 images, each with 3 channels and a size of 64×64 pixels. batch_labels is a tensor of shape (32,) containing the class label for each image in the batch.
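Because 1000 is not a multiple of 32, the final batch of an epoch is smaller than the rest. The sketch below rebuilds the same kind of dummy dataset and counts the batches. The variable names are illustrative; drop_last is a standard DataLoader argument.

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

# Same kind of dummy data as in the example above.
dataset = TensorDataset(torch.randn(1000, 3, 64, 64), torch.randint(0, 10, (1000,)))

# Default behavior: the last, incomplete batch is kept.
loader = DataLoader(dataset, batch_size=32)
batch_sizes = [images.shape[0] for images, _ in loader]
print(len(loader), math.ceil(1000 / 32))  # 32 32
print(batch_sizes[-1])                    # 8 (1000 - 31 * 32)

# drop_last=True discards the incomplete batch instead.
loader_drop = DataLoader(dataset, batch_size=32, drop_last=True)
print(len(loader_drop))                   # 31
```

Dropping the last batch can be useful when a model or loss expects every batch to have the same size.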

Benefits of using mini-batches

  • Mini-batches allow parallel processing on GPUs, which speeds up computation.
  • They reduce the memory required, because only one batch is processed at a time rather than the entire dataset.
  • They make it feasible to train on datasets that are too large to fit in memory at once.
  • They provide more stable updates to the model weights than single-sample updates, because the gradient is averaged over the batch.
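A minimal sketch of a mini-batch training loop illustrates these points. The toy data, the linear model, and the learning rate are assumptions chosen for illustration, not part of the original example.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical toy data: 256 samples with 20 features and 10 classes.
features = torch.randn(256, 20)
targets = torch.randint(0, 10, (256,))
loader = DataLoader(TensorDataset(features, targets), batch_size=32, shuffle=True)

model = nn.Linear(20, 10).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

num_updates = 0
for batch_features, batch_targets in loader:
    # Only the current batch is moved to the device, not the whole dataset.
    batch_features = batch_features.to(device)
    batch_targets = batch_targets.to(device)

    optimizer.zero_grad()
    # The loss is averaged over the 32 samples in the batch.
    loss = loss_fn(model(batch_features), batch_targets)
    loss.backward()
    optimizer.step()  # one weight update per mini-batch
    num_updates += 1

print(num_updates)  # 8 updates for 256 samples at batch_size=32
```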

Choosing the right batch size

Smaller batch sizes produce noisier gradient estimates, so the weight updates fluctuate more from step to step. Larger batch sizes give more stable updates and can converge faster, but excessively large batches can slow convergence and require more memory. It is therefore worth experimenting with different batch sizes and monitoring training performance to find a suitable value. Batch sizes such as 32 and 64 are common starting points.
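One way to see this trade-off is to measure how much the gradient varies from batch to batch. The sketch below is an illustration under assumptions: the regression data and the gradient_spread helper are hypothetical, and the gradient is evaluated at a fixed all-zero weight vector rather than during real training.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical regression data: y = x @ w_true + noise.
x = torch.randn(2048, 5)
w_true = torch.tensor([1.0, -2.0, 0.5, 3.0, -1.0])
y = x @ w_true + 0.5 * torch.randn(2048)
dataset = TensorDataset(x, y)

def gradient_spread(batch_size):
    """Standard deviation of per-batch MSE gradients at a fixed weight vector."""
    w = torch.zeros(5, requires_grad=True)
    grads = []
    for xb, yb in DataLoader(dataset, batch_size=batch_size, shuffle=True):
        loss = ((xb @ w - yb) ** 2).mean()
        (grad,) = torch.autograd.grad(loss, w)
        grads.append(grad)
    # Spread across batches, averaged over the 5 weight components.
    return torch.stack(grads).std(dim=0).mean().item()

spread = {bs: gradient_spread(bs) for bs in (8, 64, 512)}
for bs, value in spread.items():
    print(f"batch_size={bs:4d}  updates/epoch={2048 // bs:4d}  gradient spread={value:.3f}")
```

In this setup, the spread shrinks as the batch size grows while the number of updates per epoch drops, which is the tension described above.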

PyTorch DataLoader

PyTorch’s DataLoader is a powerful tool for efficiently loading and processing data for training deep learning models. It provides functionalities for batching, shuffling, and processing data, making it easier to work with large datasets. In this article, we’ll explore how PyTorch’s DataLoader works and how you can use it to streamline your data pipeline.

Table of Contents

  • What is PyTorch DataLoader?
  • Importance of Batching, Shuffling, and Processing in Deep Learning
  • Batching
  • Shuffling
  • Processing Data
  • PyTorch Dataset class for Customizing data transformations


