Batch Normalization in PyTorch
In the following code, we build a simple neural network with batch normalization using PyTorch. We define a subclass of `nn.Module` and add an `nn.BatchNorm1d` layer after the first fully connected layer to normalize its activations.
We use `nn.BatchNorm1d` because each input sample is a one-dimensional feature vector; for two-dimensional feature maps, as in convolutional neural networks, `nn.BatchNorm2d` is used instead (a short sketch of the 2-D variant appears after the training loop below).
import torch
import torch.nn as nn

# Define a simple model
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(784, 64)
        self.bn = nn.BatchNorm1d(64)  # Add Batch Normalization over the 64 features
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.bn(x)  # Apply Batch Normalization before the activation
        x = self.relu(x)
        x = self.fc2(x)
        # Return raw logits: nn.CrossEntropyLoss applies log-softmax
        # internally, so adding an nn.Softmax layer here would be a bug.
        return x

# Instantiate the model
model = Model()

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train the model; train_loader is assumed to be a DataLoader yielding
# (inputs, labels) batches of flattened 784-dimensional inputs, e.g. MNIST
model.train()  # batch norm uses per-batch statistics and updates its running estimates
for epoch in range(5):
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
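After training, the model should be switched to evaluation mode before inference: in training mode the batch norm layer normalizes with the statistics of the current mini-batch, while in evaluation mode it uses the running estimates accumulated during training. A minimal sketch, using a random placeholder input rather than a real dataset:

model.eval()  # batch norm now uses the running mean/variance from training
with torch.no_grad():
    sample = torch.randn(1, 784)  # placeholder for one flattened 28x28 image
    logits = model(sample)
    prediction = logits.argmax(dim=1)

Evaluation mode also makes single-sample inference possible; in training mode a batch of one would fail, because meaningful batch statistics cannot be computed from a single example.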
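For the two-dimensional case mentioned earlier, here is a minimal sketch of `nn.BatchNorm2d` in a convolutional block, continuing from the imports above; the layer sizes are illustrative assumptions, not part of the model above:

# BatchNorm2d normalizes each of the 16 output channels over the
# batch, height, and width dimensions, with one learnable (gamma, beta)
# pair per channel.
conv_block = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)

images = torch.randn(8, 1, 28, 28)  # placeholder batch of 8 grayscale images
features = conv_block(images)       # output shape: (8, 16, 28, 28)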
What is Batch Normalization in Deep Learning?
Internal covariate shift is a major challenge encountered while training deep learning models, and batch normalization was introduced to address this issue. In this article, we cover the fundamentals of batch normalization and the need for it, and we also apply batch normalization in practice.
Table of Contents
- What is Batch Normalization?
- Need for Batch Normalization
- Fundamentals of Batch Normalization
- Batch Normalization in TensorFlow
- Batch Normalization in PyTorch
- Benefits of Batch Normalization
- Conclusion