PyTorch BatchNorm2D Weights Explained

Understanding PyTorch BatchNorm2D and its weights

Javier
3 min read · Jul 17, 2022

The Batch Normalization layer, proposed in 2015 in the famous paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift, has been the most widely used normalization layer in deep neural networks until very recently. As a widespread and well-established normalization technique, it has out-of-the-box implementations in every major deep learning framework: PyTorch, TensorFlow, MXNet…

Although there are many articles explaining the theory behind it, i.e. the internal covariate shift, very few go through the implementation details. So in this article we will focus on the BatchNorm2d weights as implemented in PyTorch, under the torch.nn.BatchNorm2d API, and will try to help you understand the core idea through some nice (hopefully) visualizations.
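As a quick orientation before diving in, here is a minimal sketch of what those weights look like in practice. The channel count of 64 is an arbitrary choice for illustration; the layer keeps one learnable weight (gamma) and one bias (beta) per channel, plus non-learnable running statistics:

```python
import torch.nn as nn

# Minimal sketch (64 channels is an illustrative assumption):
# BatchNorm2d stores a learnable scale (weight/gamma) and shift (bias/beta)
# per channel, plus running statistics kept as buffers for inference.
bn = nn.BatchNorm2d(64)

print(bn.weight.shape)        # torch.Size([64]) -- gamma, initialized to ones
print(bn.bias.shape)          # torch.Size([64]) -- beta, initialized to zeros
print(bn.running_mean.shape)  # torch.Size([64]) -- buffer, not a parameter
print(bn.running_var.shape)   # torch.Size([64]) -- buffer, not a parameter
```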

A few months ago I published a similar article covering the PyTorch Conv2D weights, and after seeing how well it was received, this one was next in line.

BatchNorm2d

The idea behind Batch Normalization is very simple: given a tensor with L feature maps (channels), it performs a standard normalization over each of them. That is, for every feature map l ∈ L, it subtracts the channel mean and divides by the channel standard deviation (the square root of the variance), both computed over the batch and spatial dimensions: (l − μ) / σ. Visually it can be depicted as shown below.
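To make that concrete, here is a minimal sketch (the tensor shapes are arbitrary, chosen only for illustration) that reproduces the normalization by hand and compares it against nn.BatchNorm2d in training mode:

```python
import torch
import torch.nn as nn

x = torch.randn(8, 3, 32, 32)         # (N, C, H, W) -- illustrative shapes

# affine=False disables the learnable gamma/beta, so we compare the raw normalization
bn = nn.BatchNorm2d(3, affine=False)
bn.train()
out_bn = bn(x)

# Manual version: per-channel mean and (biased) variance over the
# batch and spatial dimensions N, H, W
mu = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), keepdim=True, unbiased=False)
out_manual = (x - mu) / torch.sqrt(var + bn.eps)

print(torch.allclose(out_bn, out_manual, atol=1e-6))  # True
```

Note that this comparison holds in training mode, where the statistics come from the current batch; at inference time the layer normalizes with its stored running mean and variance instead.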

