Numpy Axes

©Valery Brozhinsky/Shutterstock.com

This post gives a quick overview of axes in NumPy (NB: NumPy and PyTorch use the same convention).

Axes

For both NumPy and PyTorch, in a 2D array axis 0 indexes the rows and axis 1 indexes the columns.
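A minimal sketch of this convention, using a small 2 x 3 example array chosen for illustration. The shape lists the axes in order, and indexing on axis 0 picks out a row while indexing on axis 1 picks out a column:

```python
import numpy as np

# A 2 x 3 array: 2 rows (axis 0) and 3 columns (axis 1)
a = np.array([[1, 2, 3],
              [4, 5, 6]])

print(a.shape)   # (2, 3): shape[0] counts rows, shape[1] counts columns
print(a[1])      # indexing axis 0 picks a row: [4 5 6]
print(a[:, 1])   # indexing axis 1 picks a column: [2 5]
```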

Note how when we specify axis=0 for sum, we collapse that (row) axis: the rows are added together, giving one total per column.

If we sum along axis 1 we collapse the column dimension: each row is added up across its columns, giving one value per row (i.e. we don't sum each column :-) ).
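A short sketch of both reductions on the same illustrative 2 x 3 array. The axis you pass to `sum` is the one that disappears from the result's shape:

```python
import numpy as np

a = np.array([[1, 2, 3],
              [4, 5, 6]])

# axis=0 collapses the rows: each column is summed, one value per column
col_sums = a.sum(axis=0)
print(col_sums)   # [5 7 9]

# axis=1 collapses the columns: each row is summed, one value per row
row_sums = a.sum(axis=1)
print(row_sums)   # [ 6 15]
```

The input has shape (2, 3); summing over axis 0 leaves shape (3,), and summing over axis 1 leaves shape (2,).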

3D arrays / tensors

Moving to 3D gets a bit more complicated; the trick is to use the brackets to work out the axes.

Axis 0 refers to the arrays inside the outermost brackets: [ [[]], [[]] ]

Axis 1 refers to the elements inside each axis-0 item, i.e. [[ [] ], [ [] ]]

Axis 2 refers to the elements inside each axis-1 item, i.e. [[[items]], [[items]]]

In other words, we peel away one layer of brackets for each axis we go deeper.
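A minimal sketch of the bracket-peeling idea, using an illustrative 2 x 2 x 3 array (the values are arbitrary). Each index strips one layer of brackets and moves one axis deeper:

```python
import numpy as np

# Shape (2, 2, 3): 2 items on axis 0, each holding 2 rows (axis 1) of 3 values (axis 2)
t = np.array([[[1, 2, 3],
               [4, 5, 6]],
              [[7, 8, 9],
               [10, 11, 12]]])

print(t.shape)      # (2, 2, 3)
print(t[0])         # peel one bracket: an axis-0 item, a 2 x 3 array
print(t[0][1])      # peel two brackets: an axis-1 item, [4 5 6]
print(t[0][1][2])   # peel three brackets: an axis-2 item, 6
```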

Sum along axis 0, i.e. add the main sub-arrays together element by element (the first items of each sub-array are added together, then the second items, and so on). This reduces our array to 2 dimensions.
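A sketch of the axis-0 sum on the same illustrative 2 x 2 x 3 array. The result matches adding the two top-level sub-arrays directly:

```python
import numpy as np

t = np.array([[[1, 2, 3],
               [4, 5, 6]],
              [[7, 8, 9],
               [10, 11, 12]]])

# Summing over axis 0 adds the two axis-0 sub-arrays element by element
s0 = t.sum(axis=0)
print(s0.shape)                          # (2, 3): one dimension fewer
print(s0)
# [[ 8 10 12]
#  [14 16 18]]
print(np.array_equal(s0, t[0] + t[1]))   # True
```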

Now sum along axis 1: this is the same as summing on axis 0 in 2D, applied to each of the 2D arrays in the overall array.
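A sketch checking this equivalence on the illustrative array: summing the 3D array over axis 1 gives the same result as summing each 2D sub-array over its own axis 0 and stacking the results. Iterating over a NumPy array walks along axis 0, so `for block in t` yields the 2D sub-arrays:

```python
import numpy as np

t = np.array([[[1, 2, 3],
               [4, 5, 6]],
              [[7, 8, 9],
               [10, 11, 12]]])

s1 = t.sum(axis=1)

# Same thing done by hand: sum each 2D sub-array down its rows (2D axis 0)
per_block = np.stack([block.sum(axis=0) for block in t])

print(s1)
# [[ 5  7  9]
#  [17 19 21]]
print(np.array_equal(s1, per_block))   # True
```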

Sum along axis 2: this is the same as summing on axis 1 for each of the 2D arrays in our 3D array.
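The same kind of check for axis 2, again on the illustrative array: each 2D sub-array is summed across its columns (2D axis 1), leaving one total per row:

```python
import numpy as np

t = np.array([[[1, 2, 3],
               [4, 5, 6]],
              [[7, 8, 9],
               [10, 11, 12]]])

s2 = t.sum(axis=2)

# Same thing done by hand: sum each 2D sub-array across its columns (2D axis 1)
per_block = np.stack([block.sum(axis=1) for block in t])

print(s2)
# [[ 6 15]
#  [24 33]]
print(np.array_equal(s2, per_block))   # True
```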
