TL;DR: The dimension is the index of the axis you want to delete. For example, if you have an array of shape (5,4,6) and you want a result of shape (5,6), you use dim=1, because the second dimension (index 1, counting from 0) is the one you want to delete.
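A minimal sketch of the TL;DR rule, using NumPy and an array of ones as placeholder data. Note that NumPy spells the keyword `axis`, while PyTorch spells it `dim`; both mean the index of the axis being reduced away.

```python
import numpy as np

# A (5, 4, 6) array filled with ones, just to have something to reduce.
x = np.ones((5, 4, 6))

# Summing along axis 1 removes the second dimension (the one of size 4).
s = x.sum(axis=1)
print(s.shape)  # (5, 6)

# Each entry of s is the sum of the 4 ones that were collapsed.
print(s[0, 0])  # 4.0
```

In PyTorch the same reduction would be written `torch.ones(5, 4, 6).sum(dim=1)`.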
You are a Python user, you started using NumPy / PyTorch, and you came across some code calling f(..., dim=x). Now, each time you want to do something along a dimension, you try dim=0, dim=1, dim=2 until your code works, but you have no idea why it works. Well, this post is for you.
Let’s say you have an ndarray (NumPy array) of shape (3,4,5). This means that you have 3 matrices of shape (4,5). You can think of it as a cube of size (3,4,5): the first dimension is the depth, the second is the height and the third is the width.
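To make the cube picture concrete, here is a small sketch with placeholder values from `np.arange`. Indexing the first axis picks one of the 3 matrices, and reducing along any axis removes exactly that axis from the shape.

```python
import numpy as np

# 3 matrices of shape (4, 5): depth 3, height 4, width 5.
a = np.arange(3 * 4 * 5).reshape(3, 4, 5)

# Indexing the first axis (depth) picks one of the 3 matrices.
print(a[0].shape)  # (4, 5)

# Reducing along each axis removes exactly that axis from the shape.
for axis in range(a.ndim):
    print(axis, a.sum(axis=axis).shape)
# 0 (4, 5)
# 1 (3, 5)
# 2 (3, 4)
```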
Usually you know what your output should look like. For example, if you think of each matrix as containing the values at a time \(t\) and you’d like the average over time, the result should have shape (4,5). Well, the code is simple:
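A minimal sketch, assuming time is the first axis (the depth) and using random numbers in place of real measurements; `values` and `avg` are illustrative names. Since time is axis 0 and the desired shape (4,5) is what remains once axis 0 is deleted, the reduction goes along axis 0.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical data: 3 time steps, each a (4, 5) matrix of values.
values = rng.random((3, 4, 5))

# Time is the first axis (index 0), so average over axis 0.
avg = values.mean(axis=0)
print(avg.shape)  # (4, 5)

# Sanity check: same as averaging the three matrices by hand.
manual = (values[0] + values[1] + values[2]) / 3
print(np.allclose(avg, manual))  # True
```

In PyTorch the equivalent call on a tensor would be `values.mean(dim=0)`.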

