This post gives a quick overview of axes in numpy (NB: numpy and PyTorch use the same axis convention).

import numpy as np

Axes

2D matrices

For both numpy and PyTorch, axis 0 indexes the rows and axis 1 indexes the columns. Note how, when we specify axis=0 for sum, we collapse that (row) axis: the rows are added together, leaving one value per column.

np.sum([[1…
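To make the axis=0 versus axis=1 behaviour concrete, here is a minimal sketch using a small hypothetical 2×3 matrix `m` (chosen for illustration, not taken from the original example). It shows which axis disappears from the result's shape for each choice of `axis`, and how `keepdims=True` keeps the collapsed axis around with length 1.

```python
import numpy as np

# Hypothetical 2x3 matrix (2 rows, 3 columns), used only for illustration.
m = np.array([[1, 2, 3],
              [4, 5, 6]])

# axis 0 has length 2 (the rows), axis 1 has length 3 (the columns).
print(m.shape)  # (2, 3)

# axis=0 collapses the row axis: the rows are added element-wise,
# giving one sum per column.
col_sums = np.sum(m, axis=0)
print(col_sums)  # [5 7 9]

# axis=1 collapses the column axis: one sum per row.
row_sums = np.sum(m, axis=1)
print(row_sums)  # [ 6 15]

# The collapsed axis is removed from the result's shape...
print(col_sums.shape, row_sums.shape)  # (3,) (2,)

# ...unless keepdims=True, which keeps it with length 1 (handy for broadcasting).
print(np.sum(m, axis=0, keepdims=True).shape)  # (1, 3)
```

A useful rule of thumb for reductions like `np.sum` and `np.mean`: the axis you pass is the one that vanishes from the output shape.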