Application: Finding Layer Activations
Forward hooks are very useful for capturing the activations that the model learns. Consider a model that detects cancer: using the model's activations, we can see where in the image the model is actually focusing.
Implementation:
Let us build a simple CNN model in PyTorch consisting of three layers: a convolution layer, a ReLU activation, and a max-pooling layer. We will get the activations from the pooling layer by registering a forward hook on it.
import torch
import torch.nn as nn

# Define a simple CNN model
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # 1 input channel (grayscale image), 16 output channels
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        return x

# Create an instance of the CNN model
model = CNN()
The forward hook function takes three arguments: module, input, and output. It can return a modified output, or None to leave the output unchanged. It should have the following signature:
hook(module, input, output) -> None or modified output
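As a minimal sketch of the "modified output" case (the layer and the doubling factor here are arbitrary choices for illustration), a hook that returns a value replaces the layer's output:

```python
import torch
import torch.nn as nn

# A forward hook that returns a value replaces the layer's output;
# returning None leaves the output unchanged. Here the hook simply
# doubles whatever the layer produces.
def doubling_hook(module, input, output):
    return output * 2

layer = nn.Identity()
handle = layer.register_forward_hook(doubling_hook)

y = layer(torch.ones(3))
print(y)  # tensor([2., 2., 2.])

handle.remove()  # detach the hook when it is no longer needed
```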
Now let us build a hook that can collect activations and store them in a dictionary data structure.
feats = {}

def hook_func(module, input, output):
    feats['feat'] = output.detach()
Registering a forward hook on the Pooling Layer
model.pool.register_forward_hook(hook_func)
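Note that register_forward_hook returns a handle that can be used to detach the hook once you are done collecting activations. A small self-contained sketch (using a standalone pooling layer rather than the full model above):

```python
import torch
import torch.nn as nn

feats = {}

def hook_func(module, input, output):
    feats['feat'] = output.detach()

pool = nn.MaxPool2d(2, 2)
handle = pool.register_forward_hook(hook_func)

# The hook fires on every forward pass and stores the activation
pool(torch.randn(1, 16, 28, 28))
print(feats['feat'].shape)  # torch.Size([1, 16, 14, 14])

handle.remove()  # stop collecting activations
```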
Suppose we feed the model an input of shape 1x1x28x28 (a single grayscale image of dimension 28×28) and now want the features.
x = torch.randn(1, 1, 28, 28)
output = model(x)
This step ensures that the activations are saved in the dictionary.
This code does not involve training the model. In practice, the model should be trained before the hooks are registered so that the activations are meaningful. Printing the shape of the stored tensor (the tensor itself is too large to display) gives:
print(feats['feat'].shape)
# output -> torch.Size([1, 16, 14, 14])
What are PyTorch Hooks and how are they applied in neural network layers?
PyTorch hooks are a powerful mechanism for gaining insights into the behavior of neural networks during both forward and backward passes. They allow you to attach custom functions (hooks) to tensors and modules within your neural network, enabling you to monitor, modify, or record various aspects of the computation graph.
Hooks provide a way to inspect and manipulate the input, output, and gradients of individual layers in your network. They are registered on specific layers, from which you can monitor activations and gradients, or even modify them to customize the network's behavior. Hooks are employed in neural networks for tasks such as visualization, debugging, feature extraction, gradient manipulation, and more.
Hooks can be applied to two kinds of objects:
- tensors
- 'torch.nn.Module' objects
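For the tensor case, Tensor.register_hook attaches a function that fires on that tensor's gradient during the backward pass. A minimal sketch (the function y = sum(x²) is chosen only so the gradient is easy to verify):

```python
import torch

grads = {}

def grad_hook(grad):
    # Record the gradient flowing into x during backward
    grads['x'] = grad.clone()

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
x.register_hook(grad_hook)

y = (x ** 2).sum()  # dy/dx = 2x
y.backward()

print(grads['x'])  # tensor([2., 4., 6.])
```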