How Does torch.argmax Work for 4-Dimensional Tensors
If keepdims=True is not set when calling argmax() on a 4-dimensional input tensor of shape [1,2,3,4], then with axis=0 the output tensor has shape [2,3,4], and with axis=1 it has shape [1,3,4]. The remaining axes behave the same way. By default, applying argmax along any axis collapses that axis, because the values along it are replaced by a single index: the position of the maximum. Note that the PyTorch documentation names these parameters dim and keepdim; axis and keepdims are accepted as NumPy-style aliases, which is the spelling used in the examples here.
If keepdims=True is set, argmax() does not remove the dimension; it keeps it with size 1. For example, for a 4-D tensor of shape [1,2,3,4], argmax() along axis=1 returns a tensor of shape [1,1,3,4].
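The shape rule described above can be checked for every axis at once. The following is a minimal sketch, not part of the original examples; the tensor of zeros and the `shapes` dictionary are illustrative names chosen here, since only the shapes matter.

```python
import torch

# Illustrative 4-D tensor; the values are irrelevant, only the shape matters.
A = torch.zeros(1, 2, 3, 4)

shapes = {}
for dim in range(A.dim()):
    collapsed = torch.argmax(A, dim=dim)            # keepdim defaults to False
    kept = torch.argmax(A, dim=dim, keepdim=True)   # reduced axis kept with size 1
    shapes[dim] = (tuple(collapsed.shape), tuple(kept.shape))
    print(dim, shapes[dim])

# 0 ((2, 3, 4), (1, 2, 3, 4))
# 1 ((1, 3, 4), (1, 1, 3, 4))
# 2 ((1, 2, 4), (1, 2, 1, 4))
# 3 ((1, 2, 3), (1, 2, 3, 1))
```

In each case the reduced axis either disappears or stays in place with size 1, while all other axes are unchanged.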
Example 1
In the program below, we generate a random 4-dimensional tensor with the randn() method, pass it to argmax(), and check the results along different axes with keepdims=False (the default). Because axis 0 of this tensor has size 1, every index returned along axis 0 is 0.
```python
# import necessary libraries
import torch

# define a random 4D tensor
A = torch.randn(1, 2, 3, 4)
print("Tensor-A:", A)
print(A.shape)

# use argmax method on 4d tensor along axis-0
print('---Output tensor along axis-0---')
print(torch.argmax(A, axis=0, keepdims=False))
print(torch.argmax(A, axis=0, keepdims=False).shape)

# use argmax method on 4d tensor along axis-2
print('---Output tensor along axis-2---')
print(torch.argmax(A, axis=2))
print(torch.argmax(A, axis=2).shape)
```
Output
```
Tensor-A: tensor([[[[ 0.2672,  0.6414, -0.7371, -0.8712],
          [ 0.9414, -1.2926, -1.0787,  1.7124],
          [-1.1063, -1.7132,  1.5767, -1.7195]],

         [[-0.7871, -1.3260,  0.1592, -0.0543],
          [ 1.8193, -1.8586, -0.6683,  0.3800],
          [ 1.8769, -0.9481, -0.4193,  0.4439]]]])
torch.Size([1, 2, 3, 4])
---Output tensor along axis-0---
tensor([[[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]],

        [[0, 0, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]]])
torch.Size([2, 3, 4])
---Output tensor along axis-2---
tensor([[[1, 0, 2, 1],
         [2, 2, 0, 2]]])
torch.Size([1, 2, 4])
```
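To see what the indices along axis-2 mean, it can help to repeat the computation on fixed values. The sketch below copies the first 3×4 block of the tensor printed above into a tensor of shape [1,1,3,4]; the names `block` and `idx` are illustrative and do not appear in the original program.

```python
import torch

# First 3x4 block of Tensor-A from the output above, copied by hand.
block = torch.tensor([[ 0.2672,  0.6414, -0.7371, -0.8712],
                      [ 0.9414, -1.2926, -1.0787,  1.7124],
                      [-1.1063, -1.7132,  1.5767, -1.7195]])

# Reshape to 4-D so that axis 2 is the axis of length 3 (the rows).
A = block.reshape(1, 1, 3, 4)

# For each of the 4 columns, argmax returns the row index of the largest value.
idx = torch.argmax(A, dim=2)
print(idx)        # tensor([[[1, 0, 2, 1]]])
print(idx.shape)  # torch.Size([1, 1, 4])
```

For instance, the first column holds 0.2672, 0.9414 and -1.1063, so the largest value sits in row 1, which matches the first entry of the axis-2 result above.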
Example 2
In this program, we generate a random 4-dimensional tensor with the randn() method, pass it to argmax(), and check the results along different axes with keepdims set to True.
```python
# import necessary libraries
import torch

# define a random 4D tensor
A = torch.randn(1, 2, 3, 4)
print("Tensor-A:", A)
print(A.shape)

# use argmax method on 4d tensor along axis-2
print('---Output tensor along axis-2---')
print(torch.argmax(A, axis=2, keepdims=True))
print(torch.argmax(A, axis=2, keepdims=True).shape)

# use argmax method on 4d tensor along axis-3
print('---Output tensor along axis-3---')
print(torch.argmax(A, axis=3, keepdims=True))
print(torch.argmax(A, axis=3, keepdims=True).shape)
```
Output
```
Tensor-A: tensor([[[[ 0.8328, -0.6209,  0.0998,  0.4570],
          [ 0.1988, -0.2921,  1.7013, -0.8665],
          [ 0.6360,  0.0828,  0.3932,  0.2918]],

         [[ 0.0380, -0.0488,  1.0596,  0.8984],
          [-1.5110, -0.1987,  1.0706,  1.5212],
          [-0.0235,  0.3309,  0.8487, -1.9038]]]])
torch.Size([1, 2, 3, 4])
---Output tensor along axis-2---
tensor([[[[0, 2, 1, 0]],

         [[0, 2, 1, 1]]]])
torch.Size([1, 2, 1, 4])
---Output tensor along axis-3---
tensor([[[[0],
          [2],
          [0]],

         [[2],
          [3],
          [2]]]])
torch.Size([1, 2, 3, 1])
```
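One practical reason to keep the reduced dimension is that the index tensor then has the same number of dimensions as the input, so it can be passed straight to torch.gather() to look up the maximum values. This is a sketch of one possible use, not something the examples above rely on; the seed and the names `idx`, `max_vals` and `expected` are assumptions made for illustration.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only to make the run repeatable
A = torch.randn(1, 2, 3, 4)

# Index of the maximum along the last axis, keeping that axis with size 1.
idx = torch.argmax(A, dim=3, keepdim=True)        # shape [1, 2, 3, 1]

# gather needs an index with the same number of dimensions as the input,
# which is exactly what keepdim=True provides.
max_vals = torch.gather(A, 3, idx)                # shape [1, 2, 3, 1]

# Cross-check against the maximum values computed directly.
expected = A.max(dim=3, keepdim=True).values
print(torch.equal(max_vals, expected))            # True
```

With keepdim=False the index tensor would have shape [1, 2, 3] and would need an unsqueeze(3) before it could be used with gather() in this way.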