
This post gives a quick overview of axes in numpy (NB: numpy and pytorch use the same axis convention)

import numpy as np

Axes

2D matrices

For both numpy and pytorch, axis 0 indexes the rows and axis 1 indexes the columns
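As a quick check of the row/column convention, here is a minimal sketch using a non-square 2x3 matrix (the values are made up for illustration), so the two axes have different lengths and cannot be confused:

```python
import numpy as np

# A non-square matrix makes the axes easy to tell apart: 2 rows, 3 columns
m = np.array([[1, 2, 3],
              [4, 5, 6]])

n_rows, n_cols = m.shape   # axis 0 has length 2, axis 1 has length 3
second_row = m[1]          # indexing along axis 0 picks a row: [4, 5, 6]
second_col = m[:, 1]       # indexing along axis 1 picks a column: [2, 5]

print(n_rows, n_cols)
print(second_row)
print(second_col)
```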

Note that when we specify axis=0 for sum, we collapse that (row) axis: the rows are added together, leaving one total per column

np.sum([[1, 0], [3, 5]], axis=0) #->[1+3, 0+5]
>>array([4, 5])
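To see that collapsing axis 0 really means "add the rows together", this small sketch compares np.sum with an explicit row-by-row addition on the same matrix:

```python
import numpy as np

m = np.array([[1, 0], [3, 5]])

# Summing over axis 0 adds the rows together element by element,
# so the result has one entry per column
by_axis = np.sum(m, axis=0)
by_hand = m[0] + m[1]

print(by_axis, by_hand)  # [4 5] [4 5]
```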

If we sum over axis 1 instead, we collapse the column dimension: the entries in each row are added together, leaving one total per row (ie we don't sum down the columns :-) )

np.sum([[1, 0], [3, 5]], axis=1) #->[1+0, 3+5]
>>array([1, 8])
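The collapsed axis disappears from the result's shape. A short sketch showing the column-by-column equivalent, plus numpy's `keepdims=True` option, which keeps the summed axis as length 1 instead of dropping it:

```python
import numpy as np

m = np.array([[1, 0], [3, 5]])

# Summing over axis 1 adds the columns together, one result per row
row_totals = np.sum(m, axis=1)                     # [1, 8], shape (2,)
by_hand = m[:, 0] + m[:, 1]                        # same thing, column by column

# keepdims=True keeps the collapsed axis as length 1 instead of removing it
row_totals_2d = np.sum(m, axis=1, keepdims=True)   # [[1], [8]], shape (2, 1)

print(row_totals, row_totals.shape)
print(row_totals_2d, row_totals_2d.shape)
```

Keeping the dimension is handy when the totals need to broadcast back against the original matrix, e.g. `m / row_totals_2d` divides each row by its own total.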

3D arrays / tensors

Moving to 3D gets a bit more complicated; the trick is to use the brackets to work out the axes

a = np.random.randint(10, size=(2,2,2))  # random integers 0-9, so your values will differ from those shown
a, a.shape
>>
(array([[[9, 4],
         [5, 4]],

        [[4, 3],
         [9, 0]]]),
 (2, 2, 2))
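The bracket trick can be made concrete: the number of opening brackets at the start of numpy's printed array equals the number of axes. The helper `leading_brackets` below is hypothetical (written just for this illustration) and is only a rule of thumb that holds for non-empty arrays; a fixed `np.arange` array replaces the random one so the output is reproducible:

```python
import numpy as np

def leading_brackets(arr):
    """Hypothetical helper: count the '[' characters at the start of
    numpy's printed form of arr. For non-empty arrays this matches arr.ndim."""
    text = np.array2string(arr)
    return len(text) - len(text.lstrip("["))

a = np.arange(8).reshape(2, 2, 2)   # fixed 3D array instead of random values
print(np.array2string(a))
print(leading_brackets(a), a.ndim)  # 3 3
```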

Axis 0 refers to the arrays inside the outermost bracket: [ [[]],[[]] ]

a[0], a[1]
>>
(array([[9, 4],
        [5, 4]]),
 array([[4, 3],
        [9, 0]]))
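A sketch of the same axis-0 selection, assuming the example values printed in this post (hard-coded so the result is reproducible). It shows that `a[0]` is shorthand for fixing axis 0 and keeping the other two axes, and that looping over a 3D array also steps along axis 0:

```python
import numpy as np

# The example values from this post, hard-coded for reproducibility
a = np.array([[[9, 4], [5, 4]],
              [[4, 3], [9, 0]]])

# Indexing once steps along axis 0 and strips the outermost bracket
first = a[0]
same = a[0, :, :]          # explicit form: fix axis 0, keep axes 1 and 2
print(first.shape)         # (2, 2)
print(np.array_equal(first, same))

# Iterating over a 3D array also walks along axis 0
for i, block in enumerate(a):
    print(i, block.tolist())
```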

Axis 1 refers to the elements within each axis-0 item, ie [[ [] ], [ [] ]]

a[0][0], a[0][1], a[1][0], a[1][1]
>>
(array([9, 4]), array([5, 4]), array([4, 3]), array([9, 0]))
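Chained indexing like `a[0][0]` can also be written as a single tuple index. Slicing with `:` on axis 0 while fixing axis 1 picks the same axis-1 item from every block at once. A small sketch, again using the post's example values hard-coded:

```python
import numpy as np

a = np.array([[[9, 4], [5, 4]],
              [[4, 3], [9, 0]]])

# Chained indexing a[0][0] and the tuple form a[0, 0] select the same row
print(np.array_equal(a[0][0], a[0, 0]))

# Fixing axis 1 while keeping axis 0 picks the first row of every block
first_rows = a[:, 0]
print(first_rows.tolist())   # [[9, 4], [4, 3]]
```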

Axis 2 refers to the elements within each axis-1 item, ie [[[items]],[[items]]]

ie we gradually peel away brackets as we go deeper

a[0][0][0], a[0][0][1], a[0][1][0], a[0][1][1] #etc
>>
(9, 4, 5, 4)
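Peeling away every bracket gives one index per axis. This sketch (using the post's example values) visits all index triples with `np.ndindex`, which runs in row-major order, so axis 2 changes fastest:

```python
import numpy as np

a = np.array([[[9, 4], [5, 4]],
              [[4, 3], [9, 0]]])

# np.ndindex yields every (i, j, k) tuple in row-major order:
# the last axis (axis 2) changes fastest, axis 0 slowest
values = [int(a[i, j, k]) for i, j, k in np.ndindex(a.shape)]
print(values)   # [9, 4, 5, 4, 4, 3, 9, 0]

# Peeling brackets one at a time reaches the same element as one tuple index
print(a[1][0][1], a[1, 0, 1])   # both select the element 3
```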

Collapse along axis 0, ie add the corresponding elements of the two main sub-arrays (a[0] + a[1]). This reduces our array to 2 dimensions

b = np.sum(a, axis=0) #-> [[9+4, 4+3],[5+9, 4+0]]
b, b.shape
>>
(array([[13,  7],
        [14,  4]]),
 (2, 2))
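A quick sketch confirming that collapsing axis 0 is just an element-wise addition of the two 2x2 blocks, using the example values from this post:

```python
import numpy as np

a = np.array([[[9, 4], [5, 4]],
              [[4, 3], [9, 0]]])

# Collapsing axis 0 adds the two 2x2 blocks element by element
b = np.sum(a, axis=0)
print(b.tolist())                       # [[13, 7], [14, 4]]
print(np.array_equal(b, a[0] + a[1]))   # True
```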

Now collapse along axis 1. This is the same as collapsing axis 0 in 2D, applied to each of the 2D arrays in the overall array

np.sum(a, axis=1) #->[[9+5, 4+4],[4+9, 3+0]]
>>
array([[14,  8],
       [13,  3]])
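The "same as the 2D axis-0 sum, per block" description can be checked directly. This sketch applies the 2D reduction to each block separately, stacks the results, and compares with `np.sum(a, axis=1)`:

```python
import numpy as np

a = np.array([[[9, 4], [5, 4]],
              [[4, 3], [9, 0]]])

# Apply the 2D axis-0 sum to each block separately, then stack the results
per_block = np.stack([np.sum(block, axis=0) for block in a])
print(per_block.tolist())                               # [[14, 8], [13, 3]]
print(np.array_equal(per_block, np.sum(a, axis=1)))     # True
```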

Collapse along axis 2. This is the same as collapsing axis 1 in 2D, applied to each of the 2D arrays in our 3D array

np.sum(a, axis=2) #-> [[9+4, 5+4], [4+3, 9+0]]
>>
array([[13,  9],
       [ 7,  9]])
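The same per-block check works for axis 2 (row totals of each block). The second part uses an all-zeros array of shape (2, 3, 4), an arbitrary choice with unequal axis lengths, to make visible which axis is removed by each reduction:

```python
import numpy as np

a = np.array([[[9, 4], [5, 4]],
              [[4, 3], [9, 0]]])

# Apply the 2D axis-1 sum (row totals) to each block, then stack
per_block = np.stack([np.sum(block, axis=1) for block in a])
print(per_block.tolist())                               # [[13, 9], [7, 9]]
print(np.array_equal(per_block, np.sum(a, axis=2)))     # True

# With unequal axis lengths it is easy to see which axis disappears
c = np.zeros((2, 3, 4))
shapes = {ax: np.sum(c, axis=ax).shape for ax in range(c.ndim)}
print(shapes)   # {0: (3, 4), 1: (2, 4), 2: (2, 3)}
```

In every case the summed axis is simply dropped from the shape, which is the same rule as in the 2D examples.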

Geophysicist and Deep Learning Practitioner