How Does torch.argmax Work for 4-Dimensions

If we don't set keepdims=True in the argmax() method, then for a 4-dimensional input tensor of shape [1,2,3,4] and axis=0 it returns an output tensor of shape [2,3,4]. Likewise, for axis=1 it returns a tensor of shape [1,3,4], and the same pattern holds for the other axes: axis=2 gives shape [1,2,4] and axis=3 gives shape [1,2,3]. So, by default, applying argmax along any axis/dimension collapses (removes) that dimension, because all the values along it are replaced by a single index. (PyTorch's own parameter names are dim and keepdim; the NumPy-style spellings axis and keepdims used in this article are accepted as aliases.)
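One way to see this rule for every axis at once is to loop over the dimensions and print the result shapes. This is a minimal sketch using the canonical dim parameter; because the input is random, only the shapes (and the int64 index dtype), not the index values, are reproducible. The names shapes and idx are illustrative.

```python
import torch

# Random input with the shape used in the text: [1, 2, 3, 4]
A = torch.randn(1, 2, 3, 4)

shapes = {}
for dim in range(A.dim()):
    # Default keepdim=False: the reduced axis disappears from the result
    idx = torch.argmax(A, dim=dim)
    shapes[dim] = tuple(idx.shape)
    print(f"dim={dim}: shape {tuple(idx.shape)}, dtype {idx.dtype}")
```

Each result has one dimension fewer than A, and the missing dimension is exactly the one passed as dim.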

If we set keepdims=True in the argmax() method, it does not remove the reduced dimension; instead, it keeps that dimension with size 1. For example, for a 4-D tensor of shape [1,2,3,4], argmax() along axis=1 returns a tensor of shape [1,1,3,4].
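The same loop with keepdim=True shows that every result keeps all four dimensions, with size 1 in the reduced position. As a sketch, it also checks two things on a random input: squeezing that size-1 axis gives back the keepdim=False result, and the NumPy-style spellings axis/keepdims produce the same tensor as dim/keepdim. The variable names are illustrative.

```python
import torch

A = torch.randn(1, 2, 3, 4)

kept_shapes = {}
for dim in range(A.dim()):
    # keepdim=True leaves the reduced axis in place with size 1
    idx = torch.argmax(A, dim=dim, keepdim=True)
    kept_shapes[dim] = tuple(idx.shape)
    print(f"dim={dim}: shape {tuple(idx.shape)}")

# Squeezing the size-1 axis recovers the default (keepdim=False) result
same_as_default = torch.equal(
    torch.argmax(A, dim=1, keepdim=True).squeeze(1),
    torch.argmax(A, dim=1),
)

# The NumPy-style spellings used in the article give the same tensor
same_as_alias = torch.equal(
    torch.argmax(A, axis=1, keepdims=True),
    torch.argmax(A, dim=1, keepdim=True),
)
print(same_as_default, same_as_alias)
```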

Example 1

In the program below, we generate a random 4-dimensional tensor with the randn() method, pass it to the argmax() method, and check the results along different axes with keepdims=False (set explicitly for axis 0 and left at its default for axis 2).

Python3
# import necessary libraries
import torch

# define a random 4D tensor of shape [1, 2, 3, 4]
A = torch.randn(1, 2, 3, 4)
print("Tensor-A:", A)
print(A.shape)

# argmax along axis-0 with keepdims=False: axis 0 is removed -> shape [2, 3, 4]
print('---Output tensor along axis-0---')
out0 = torch.argmax(A, axis=0, keepdims=False)
print(out0)
print(out0.shape)

# argmax along axis-2 (keepdims defaults to False) -> shape [1, 2, 4]
print('---Output tensor along axis-2---')
out2 = torch.argmax(A, axis=2)
print(out2)
print(out2.shape)


Output

Tensor-A: tensor([[[[ 0.2672,  0.6414, -0.7371, -0.8712],
          [ 0.9414, -1.2926, -1.0787,  1.7124],
          [-1.1063, -1.7132,  1.5767, -1.7195]],

         [[-0.7871, -1.3260,  0.1592, -0.0543],
          [ 1.8193, -1.8586, -0.6683,  0.3800],
          [ 1.8769, -0.9481, -0.4193,  0.4439]]]])
torch.Size([1, 2, 3, 4])
---Output tensor along axis-0---
tensor([[[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],

        [[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]]])
torch.Size([2, 3, 4])
---Output tensor along axis-2---
tensor([[[1, 0, 2, 1],
         [2, 2, 0, 2]]])
torch.Size([1, 2, 4])
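To check the axis-2 result by hand, the sketch below rebuilds Tensor-A from the printed Example 1 values (rounded to four decimals, as displayed) and looks each index back up with torch.gather. Because the default keepdims=False drops axis 2, the index has to be unsqueezed back to four dimensions before gather can use it. The names idx and picked are illustrative.

```python
import torch

# Tensor-A copied from the Example 1 output (values as printed, 4 decimals)
A = torch.tensor([[[[ 0.2672,  0.6414, -0.7371, -0.8712],
                    [ 0.9414, -1.2926, -1.0787,  1.7124],
                    [-1.1063, -1.7132,  1.5767, -1.7195]],
                   [[-0.7871, -1.3260,  0.1592, -0.0543],
                    [ 1.8193, -1.8586, -0.6683,  0.3800],
                    [ 1.8769, -0.9481, -0.4193,  0.4439]]]])

idx = torch.argmax(A, dim=2)  # shape [1, 2, 4], axis 2 removed
print(idx)

# gather needs an index with as many dims as A, so re-insert axis 2:
# idx.unsqueeze(2) has shape [1, 2, 1, 4]
picked = torch.gather(A, 2, idx.unsqueeze(2))

# The picked values should be the per-column maxima along axis 2
print(torch.equal(picked, A.amax(dim=2, keepdim=True)))
```

The printed indices match the axis-2 output of Example 1, and the gathered values equal the maxima along axis 2.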

Example 2

In this program, we again generate a random 4-dimensional tensor with the randn() method, pass it to the argmax() method, and check the results along different axes, this time with keepdims set to True.

Python3
# import necessary libraries
import torch

# define a random 4D tensor of shape [1, 2, 3, 4]
A = torch.randn(1, 2, 3, 4)
print("Tensor-A:", A)
print(A.shape)

# argmax along axis-2 with keepdims=True: axis 2 is kept with size 1 -> shape [1, 2, 1, 4]
print('---Output tensor along axis-2---')
out2 = torch.argmax(A, axis=2, keepdims=True)
print(out2)
print(out2.shape)

# argmax along axis-3 with keepdims=True: axis 3 is kept with size 1 -> shape [1, 2, 3, 1]
print('---Output tensor along axis-3---')
out3 = torch.argmax(A, axis=3, keepdims=True)
print(out3)
print(out3.shape)


Output

Tensor-A: tensor([[[[ 0.8328, -0.6209,  0.0998,  0.4570],
          [ 0.1988, -0.2921,  1.7013, -0.8665],
          [ 0.6360,  0.0828,  0.3932,  0.2918]],

         [[ 0.0380, -0.0488,  1.0596,  0.8984],
          [-1.5110, -0.1987,  1.0706,  1.5212],
          [-0.0235,  0.3309,  0.8487, -1.9038]]]])
torch.Size([1, 2, 3, 4])
---Output tensor along axis-2---
tensor([[[[0, 2, 1, 0]],

         [[0, 2, 1, 1]]]])
torch.Size([1, 2, 1, 4])
---Output tensor along axis-3---
tensor([[[[0],
          [2],
          [0]],

         [[2],
          [3],
          [2]]]])
torch.Size([1, 2, 3, 1])
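One practical reason to use keepdims=True is that the result lines up with the input for indexing and broadcasting. The sketch below rebuilds Tensor-A from the printed Example 2 values, feeds the axis-3 indices straight into torch.gather (no unsqueeze needed), and compares them with torch.arange to build a one-hot mask of the maximum positions. The mask construction and the names idx, maxvals, and mask are illustrative choices, not something the article prescribes.

```python
import torch

# Tensor-A copied from the Example 2 output (values as printed, 4 decimals)
A = torch.tensor([[[[ 0.8328, -0.6209,  0.0998,  0.4570],
                    [ 0.1988, -0.2921,  1.7013, -0.8665],
                    [ 0.6360,  0.0828,  0.3932,  0.2918]],
                   [[ 0.0380, -0.0488,  1.0596,  0.8984],
                    [-1.5110, -0.1987,  1.0706,  1.5212],
                    [-0.0235,  0.3309,  0.8487, -1.9038]]]])

idx = torch.argmax(A, dim=3, keepdim=True)  # shape [1, 2, 3, 1]

# idx already has 4 dims, so gather can use it directly -> shape [1, 2, 3, 1]
maxvals = torch.gather(A, 3, idx)

# arange(4) broadcasts against idx's size-1 last axis -> bool mask [1, 2, 3, 4]
mask = torch.arange(A.size(3)) == idx

print(idx.squeeze(-1))
print(mask.int())
```

With keepdims=False the index would have shape [1, 2, 3] and would need an unsqueeze(-1) before either step.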
