This post gives a quick overview of axes in NumPy (NB: NumPy and PyTorch use the same axis convention).
import numpy as np
Axes
2D matrices
For both NumPy and PyTorch, axis 0 indexes the rows and axis 1 indexes the columns.
Note that when we pass axis=0 to sum, we collapse that (row) axis: the rows are added together, leaving one total per column.
np.sum([[1, 0], [3, 5]], axis=0)  # -> [1+3, 0+5]
>>array([4, 5])
If we sum over axis 1 instead, the column axis is collapsed: each row's entries are added together (i.e. we don't sum down the columns :-) )
np.sum([[1, 0], [3, 5]], axis=1) #->[1+0, 3+5]
>>array([1, 8])
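A quick way to check which axis disappears is to look at shapes. The sketch below uses a 2x3 matrix (chosen non-square so the two results have different shapes): the summed axis vanishes from the result's shape, and passing `keepdims=True` to `np.sum` keeps it as a length-1 axis instead.

```python
import numpy as np

m = np.array([[1, 0, 2],
              [3, 5, 4]])  # shape (2, 3)

# axis=0 collapses the rows: one total per column
col_totals = np.sum(m, axis=0)
# axis=1 collapses the columns: one total per row
row_totals = np.sum(m, axis=1)

print(col_totals, col_totals.shape)  # [4 5 6] (3,)
print(row_totals, row_totals.shape)  # [ 3 12] (2,)

# keepdims=True keeps the collapsed axis with length 1
print(np.sum(m, axis=0, keepdims=True).shape)  # (1, 3)
print(np.sum(m, axis=1, keepdims=True).shape)  # (2, 1)
```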
3D arrays / tensors
Moving to 3D gets a bit more complicated; the trick is to use the brackets to work out the axes.
a = np.random.randint(10, size=(2,2,2))  # random ints in [0, 10); your values will differ
a, a.shape
>>(array([[[9, 4],
           [5, 4]],

          [[4, 3],
           [9, 0]]]),
   (2, 2, 2))
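To follow along with the exact numbers used in the rest of this post, one option is to build the same array by hand rather than drawing it at random (a sketch; the values are simply copied from the output above). Note that the number of axes, `a.ndim`, matches the number of nested bracket levels.

```python
import numpy as np

# The same values as the random draw shown above, written out explicitly
a = np.array([[[9, 4],
               [5, 4]],
              [[4, 3],
               [9, 0]]])

print(a.shape)  # (2, 2, 2)
print(a.ndim)   # 3 -> three axes, three levels of brackets
```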
Axis 0 refers to the arrays inside the outermost brackets, i.e. [ [[]], [[]] ]
a[0], a[1]
>> (array([[9, 4],
           [5, 4]]),
    array([[4, 3],
           [9, 0]]))
Axis 1 refers to the elements inside each axis-0 item, i.e. [[ [] ], [ [] ]]
a[0][0], a[0][1], a[1][0], a[1][1]
>>(array([9, 4]), array([5, 4]), array([4, 3]), array([9, 0]))
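Chained indexing such as `a[0][1]` first pulls out `a[0]` and then indexes into that result; NumPy also accepts a single tuple index, `a[0, 1]`, which addresses the same axes in one step. A small sketch comparing the two, and showing that each index consumes one axis:

```python
import numpy as np

a = np.array([[[9, 4], [5, 4]],
              [[4, 3], [9, 0]]])

# Each index consumes one axis, so the result has one fewer dimension
print(a[0].shape)     # (2, 2) -> axis 0 consumed
print(a[0, 1].shape)  # (2,)   -> axes 0 and 1 consumed

# Chained and tuple indexing select the same elements
print(np.array_equal(a[0][1], a[0, 1]))  # True
```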
Axis 2 refers to the elements inside each axis-1 item, i.e. [[[items]], [[items]]]
In other words, we peel away one layer of brackets for each axis we go deeper.
a[0][0][0], a[0][0][1], a[0][1][0], a[0][1][1] #etc
>>(9, 4, 5, 4)
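To see all three levels at once, `np.ndindex` can enumerate every valid index triple: the first number picks an item in the outer brackets (axis 0), the second an item one level in (axis 1), and the last a single number in the innermost brackets (axis 2).

```python
import numpy as np

a = np.array([[[9, 4], [5, 4]],
              [[4, 3], [9, 0]]])

# np.ndindex walks every (axis 0, axis 1, axis 2) index triple in row-major order
for i, j, k in np.ndindex(a.shape):
    print((i, j, k), a[i, j, k])
```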
Summing along axis 0 adds the axis-0 items together element by element: the first entry of each main sub-array, then the second, and so on. This reduces our array to 2 dimensions.
b = np.sum(a, axis=0) #-> [[9+4, 4+3],[5+9, 4+0]]
b, b.shape
>>(array([[13, 7],
          [14, 4]]),
   (2, 2))
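Since summing over axis 0 just adds the axis-0 items together, we can check it directly against `a[0] + a[1]`:

```python
import numpy as np

a = np.array([[[9, 4], [5, 4]],
              [[4, 3], [9, 0]]])

b = np.sum(a, axis=0)
# Adding the two outer sub-arrays element by element gives the same result
print(np.array_equal(b, a[0] + a[1]))  # True
print(b.shape)  # (2, 2) -> axis 0 is gone
```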
Now sum along axis 1. This is the same as summing over axis 0 (as in the 2D case) for each of the 2D arrays inside the overall array.
np.sum(a, axis=1) #->[[9+5, 4+4],[4+9, 3+0]]
>>array([[14, 8],
         [13, 3]])
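The "same as axis 0 in 2D" claim can be checked by looping over the 2D sub-arrays ourselves. In this sketch, iterating over `a` yields `a[0]` and `a[1]`, each is summed over its own axis 0, and `np.stack` reassembles the results:

```python
import numpy as np

a = np.array([[[9, 4], [5, 4]],
              [[4, 3], [9, 0]]])

# Sum each 2D sub-array over its own axis 0, then stack the results
per_matrix = np.stack([np.sum(m, axis=0) for m in a])
print(np.array_equal(per_matrix, np.sum(a, axis=1)))  # True
```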
Summing along axis 2 is the same as summing over axis 1 for each of the 2D arrays in our 3D array.
np.sum(a, axis=2) #-> [[9+4, 5+4], [4+3, 9+0]]
>>array([[13, 9],
         [ 7, 9]])
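The same loop-and-stack check works for axis 2, this time summing each 2D sub-array over its own axis 1. The second half of the sketch uses an all-zeros array with three different axis lengths (an illustrative choice, not from the post) to show the general rule: whichever axis you sum over is dropped from the shape.

```python
import numpy as np

a = np.array([[[9, 4], [5, 4]],
              [[4, 3], [9, 0]]])

# Sum each 2D sub-array over its own axis 1, then stack the results
per_matrix = np.stack([np.sum(m, axis=1) for m in a])
print(np.array_equal(per_matrix, np.sum(a, axis=2)))  # True

# Rule of thumb: the summed axis is removed from the shape
c = np.zeros((2, 3, 4))
for axis in range(c.ndim):
    print(axis, np.sum(c, axis=axis).shape)
# 0 (3, 4)
# 1 (2, 4)
# 2 (2, 3)
```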