This post gives a quick overview of axes in numpy (NB: numpy and pytorch use the same convention for axes)

`import numpy as np`

# Axes

## 2D matrices

For both numpy and pytorch, axis 0 = row, 1 = column
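A quick way to see which axis is which is to index along each one. Here is a minimal sketch using a small matrix `m` made up for illustration. It is deliberately non-square so the two axes have different lengths:

```python
import numpy as np

# Hypothetical 2x3 matrix: 2 rows (axis 0), 3 columns (axis 1)
m = np.array([[1, 2, 3],
              [4, 5, 6]])

print(m.shape)    # (2, 3): length of axis 0, then length of axis 1
print(m[0])       # stepping along axis 0 picks out a row: [1 2 3]
print(m[:, 0])    # stepping along axis 1 picks out a column: [1 4]
```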

Note that when we specify axis=0 for sum, we collapse that (row) axis: the rows are added together, leaving one sum per column

`np.sum([[1, 0], [3, 5]], axis=0) #->[1+3, 0+5]`

>>array([4, 5])

If we sum on axis 1, we remove the column dimension instead, i.e. each row is summed across its columns (so we don't sum down the columns :-) )

`np.sum([[1, 0], [3, 5]], axis=1) #->[1+0, 3+5]`

>>array([1, 8])
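To check the "collapsing" reading, a small sketch can compare each `np.sum` call with the element-wise addition it should be equivalent to. It also uses the `keepdims` argument of `np.sum` to show that the summed axis disappears from the shape unless you ask to keep it:

```python
import numpy as np

m = np.array([[1, 0], [3, 5]])

# axis=0 collapses the rows: add row 0 and row 1 element-wise
col_sums = np.sum(m, axis=0)
assert (col_sums == m[0] + m[1]).all()        # [4, 5]

# axis=1 collapses the columns: add column 0 and column 1 element-wise
row_sums = np.sum(m, axis=1)
assert (row_sums == m[:, 0] + m[:, 1]).all()  # [1, 8]

# The summed axis is dropped from the shape unless keepdims=True
print(col_sums.shape, np.sum(m, axis=0, keepdims=True).shape)  # (2,) (1, 2)
```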

# 3D arrays / tensors

Moving to 3D gets a bit more complicated; the trick is to use the brackets to work out the axes.

`a = np.random.randint(10, size=(2,2,2))`

`a, a.shape`

>>(array([[[9, 4], [5, 4]], [[4, 3], [9, 0]]]), (2, 2, 2))
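`np.random.randint` returns different values on each run, so your output will differ from the one above. To follow along with the exact numbers used in the rest of this post, one option is to build the same array by hand. The values below are simply copied from the output above:

```python
import numpy as np

# Same values as the random draw shown above, written out explicitly
a = np.array([[[9, 4],
               [5, 4]],
              [[4, 3],
               [9, 0]]])

print(a.shape)  # (2, 2, 2)
print(a.ndim)   # 3: one level of bracket nesting per axis
```

If you would rather keep a random array but make it reproducible, seeding a generator (e.g. `np.random.default_rng(0)`) is another option, though the numbers will not match the ones in this post.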

Axis 0 refers to the arrays inside the outermost brackets: [ *[[]]*, *[[]]* ]

`a[0], a[1]`

>> (array([[9, 4], [5, 4]]), array([[4, 3], [9, 0]]))

Axis 1 refers to the elements in each of the axis 0 items, i.e. [[ *[]* ], [ *[]* ]]

`a[0][0], a[0][1], a[1][0], a[1][1]`

>>(array([9, 4]), array([5, 4]), array([4, 3]), array([9, 0]))
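The chained form `a[i][j]` used above can also be written as the single tuple index `a[i, j]`, which is the more idiomatic numpy style. This sketch checks the equivalence and, for contrast, fixes axis 1 across *all* axis-0 items at once with a slice:

```python
import numpy as np

a = np.array([[[9, 4], [5, 4]],
              [[4, 3], [9, 0]]])

# Chained indexing a[i][j] is equivalent to the tuple form a[i, j]
assert (a[0][1] == a[0, 1]).all()
print(a[0, 1])        # [5 4]
print(a[0, 1].shape)  # (2,): only axis 2 remains

# Fixing axis 1 for every axis-0 item: the first row of each 2D block
print(a[:, 0])        # [[9 4]
                      #  [4 3]]
```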

Axis 2 refers to the elements in each of the axis 1 items, i.e. [[[*items*]], [[*items*]]]

In other words, we gradually peel away brackets as we go deeper.

`a[0][0][0], a[0][0][1], a[0][1][0], a[0][1][1] #etc`

>>(9, 4, 5, 4)
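Rather than writing out every `a[i][j][k]` by hand, `np.ndindex` can generate all the index triples. It visits them with the last axis changing fastest, which matches peeling the brackets from the outside in. A minimal sketch:

```python
import numpy as np

a = np.array([[[9, 4], [5, 4]],
              [[4, 3], [9, 0]]])

values = []
# Each (i, j, k) is an (axis 0, axis 1, axis 2) position
for i, j, k in np.ndindex(a.shape):
    values.append(int(a[i, j, k]))
    print((i, j, k), int(a[i, j, k]))

# Visiting the innermost axis fastest gives the same order as a true flatten
assert values == a.ravel().tolist()   # [9, 4, 5, 4, 4, 3, 9, 0]
```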

Collapse (sum) along axis 0, i.e. add the two main sub-arrays a[0] and a[1] element-wise, position by position. This reduces our array to 2 dimensions

`b = np.sum(a, axis=0) #-> [[9+4, 4+3],[5+9, 4+0]]`

`b, b.shape`

>>(array([[13, 7], [14, 4]]), (2, 2))
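A quick way to convince yourself of the axis-0 result is to compare it with adding the two axis-0 items directly. This sketch reuses the values of `a` from above:

```python
import numpy as np

a = np.array([[[9, 4], [5, 4]],
              [[4, 3], [9, 0]]])

b = np.sum(a, axis=0)
# Summing over axis 0 adds the axis-0 items element-wise
assert (b == a[0] + a[1]).all()
print(b)        # [[13  7]
                #  [14  4]]
print(b.shape)  # (2, 2): axis 0 has been removed from (2, 2, 2)
```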

Now collapse along axis 1. This is the same as summing on axis 0 in 2D, applied to each of the 2D arrays in the overall array

`np.sum(a, axis=1) #->[[9+5, 4+4],[4+9, 3+0]]`

>>array([[14, 8], [13, 3]])
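The "same as axis 0 in 2D for each block" claim can be checked directly. Iterating over `a` yields its axis-0 items `a[0]` and `a[1]`. Summing each of those on axis 0 and stacking the results should reproduce `np.sum(a, axis=1)`:

```python
import numpy as np

a = np.array([[[9, 4], [5, 4]],
              [[4, 3], [9, 0]]])

# 2D axis-0 sum applied to each block a[i], then stacked back together
per_block = np.stack([np.sum(block, axis=0) for block in a])
assert (np.sum(a, axis=1) == per_block).all()
print(per_block)  # [[14  8]
                  #  [13  3]]
```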

Collapse along axis 2. This is the same as summing on axis 1, applied to each of the 2D arrays in our 3D array

`np.sum(a, axis=2) #-> [[9+4, 5+4], [4+3, 9+0]]`

>> array([[13, 9], [7, 9]])
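Because every axis of `a` has length 2, the result shapes above all come out as (2, 2) and cannot show *which* axis was removed. The sketch below first checks the axis-2 claim the same way as for axis 1. It then uses a hypothetical array `c` with distinct axis lengths so the vanishing axis becomes visible:

```python
import numpy as np

a = np.array([[[9, 4], [5, 4]],
              [[4, 3], [9, 0]]])

# 2D axis-1 sum applied to each block a[i] matches the 3D axis-2 sum
per_block = np.stack([np.sum(block, axis=1) for block in a])
assert (np.sum(a, axis=2) == per_block).all()   # [[13, 9], [7, 9]]

# Hypothetical array with axis lengths 2, 3, 4, so each axis is distinguishable
c = np.arange(24).reshape(2, 3, 4)
for axis in range(c.ndim):
    print(axis, np.sum(c, axis=axis).shape)
# 0 (3, 4)
# 1 (2, 4)
# 2 (2, 3)
```

In each case the summed axis is the one that drops out of the shape, which is the same rule as in the 2D section.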