
Speeding up a model by fusing batch normalization and convolution

This is a common technique in quantized and inference-optimized models.

Today we will try to understand how we can make our model a bit faster at inference.

A vast number of networks use Batch Normalization as a way to improve generalization. But during inference Batch Normalization is turned off, and the approximated per-channel mean $\mu$ and variance $\sigma^2$ are used instead. The cool thing is that we can implement the same behaviour through a $1\times 1$ convolution. And even better, we can merge it with the preceding convolution.

Batch Normalization

Let $x$ be a signal (activation) within the network that we want to normalize. Given a set of such signals $\{x_1, x_2, \ldots, x_n\}$ coming from processing different samples within a batch, each is normalized as follows:

\begin{aligned}\begin{gathered}\hat{x}_i = \gamma\frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta\\ \hat{x}_i = \frac{\gamma x_i}{\sqrt{\sigma^2 + \epsilon}} + \beta - \frac{\gamma\mu}{\sqrt{\sigma^2 + \epsilon}} \end{gathered} \end{aligned}

The values $\mu$ and $\sigma^2$ are the mean and variance computed over a batch, $\epsilon$ is a small constant added for numerical stability, $\gamma$ is the scaling factor and $\beta$ is the shift factor. During training, $\mu$ and $\sigma^2$ are recomputed for each batch:

\begin{aligned}\begin{gathered}\mu=\frac{1}{n}\sum_{i=1}^{n} x_i\\ \sigma^2=\frac{1}{n}\sum_{i=1}^{n} (x_i - \mu)^2\end{gathered}\end{aligned}
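
As a quick sanity check, here is a minimal sketch (assuming PyTorch; the tensor shapes and seed are arbitrary) that applies the two formulas above by hand and compares the result with nn.BatchNorm2d in training mode. Consistent with the $\frac{1}{n}$ in the formula, the biased variance is used for normalization.

    import torch
    
    torch.manual_seed(0)
    x = torch.randn(8, 4, 16, 16)                    # batch of 8 images, 4 channels
    bn = torch.nn.BatchNorm2d(4).train()
    
    # per-channel batch statistics, computed over the batch and spatial dimensions
    mu = x.mean(dim=(0, 2, 3), keepdim=True)
    var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
    
    gamma = bn.weight.view(1, -1, 1, 1)
    beta = bn.bias.view(1, -1, 1, 1)
    x_hat = gamma * (x - mu) / torch.sqrt(var + bn.eps) + beta
    
    print((x_hat - bn(x)).abs().max().item())        # ~1e-7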

The parameters $\gamma$ and $\beta$ are slowly learned with gradient descent together with the other parameters of the network. During test time, we usually do not run the network on a batch of images, so the formulas above for $\mu$ and $\sigma^2$ can't be used. Instead, we use their estimates computed during training with an exponential moving average. Let us denote these approximations as $\hat{\mu}$ and $\hat{\sigma}^2$.

Nowadays, batch normalization is mostly used in convolutional neural networks for processing images. In this setting, there is a mean and variance estimate, a shift and a scale parameter for each channel of the input feature map. We will denote these as $\mu_c$, $\sigma^2_c$, $\gamma_c$ and $\beta_c$ for channel $c$.
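
In PyTorch these per-channel quantities live on the nn.BatchNorm2d module as weight ($\gamma_c$), bias ($\beta_c$), running_mean ($\hat{\mu}_c$) and running_var ($\hat{\sigma}^2_c$). Analogously to the training-mode check above, a minimal sketch (arbitrary shapes; a few forward passes just to move the running estimates away from their defaults) confirms that in eval mode the layer computes exactly this per-channel affine transform:

    import torch
    
    torch.manual_seed(0)
    bn = torch.nn.BatchNorm2d(4)
    
    # forward passes in training mode update running_mean / running_var
    bn.train()
    for _ in range(10):
        bn(torch.randn(8, 4, 16, 16))
    bn.eval()
    
    x = torch.randn(2, 4, 16, 16)
    mu = bn.running_mean.view(1, -1, 1, 1)       # \hat{mu}_c
    var = bn.running_var.view(1, -1, 1, 1)       # \hat{sigma}^2_c
    gamma = bn.weight.view(1, -1, 1, 1)          # gamma_c
    beta = bn.bias.view(1, -1, 1, 1)             # beta_c
    
    y = gamma * (x - mu) / torch.sqrt(var + bn.eps) + beta
    print((y - bn(x)).abs().max().item())        # ~1e-7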

Solution

Implementing frozen Batch Normalization as a 1×1 Convolution

Given a feature map $F$ of shape $C\times H\times W$, to get its normalized version $\hat{F}$ we apply the formula above for $\hat{x}_i$ at every spatial position $i, j$:

$$ \begin{pmatrix} \hat{F}_{1,i,j} \cr \hat{F}_{2,i,j} \cr \vdots \cr \hat{F}_{C-1,i,j} \cr \hat{F}_{C,i,j} \cr \end{pmatrix} = \begin{pmatrix} \frac{\gamma_1}{\sqrt{\hat{\sigma}^2_1 + \epsilon}}&0&\cdots&&0\cr 0&\frac{\gamma_2}{\sqrt{\hat{\sigma}^2_2 + \epsilon}}\cr \vdots&&\ddots&&\vdots\cr &&&\frac{\gamma_{C-1}}{\sqrt{\hat{\sigma}^2_{C-1} + \epsilon}}&0\cr 0&&\cdots&0&\frac{\gamma_C}{\sqrt{\hat{\sigma}^2_C + \epsilon}}\cr \end{pmatrix} \cdot \begin{pmatrix} F_{1,i,j} \cr F_{2,i,j} \cr \vdots \cr F_{C-1,i,j} \cr F_{C,i,j} \cr \end{pmatrix} + \begin{pmatrix} \beta_1 - \gamma_1\frac{\hat{\mu}_1}{\sqrt{\hat{\sigma}^2_1 + \epsilon}} \cr \beta_2 - \gamma_2\frac{\hat{\mu}_2}{\sqrt{\hat{\sigma}^2_2 + \epsilon}} \cr \vdots \cr \beta_{C-1} - \gamma_{C-1}\frac{\hat{\mu}_{C-1}}{\sqrt{\hat{\sigma}^2_{C-1} + \epsilon}} \cr \beta_C - \gamma_C\frac{\hat{\mu}_C}{\sqrt{\hat{\sigma}^2_C + \epsilon}} \cr \end{pmatrix} $$

We clearly see that this is an affine map $f(\mathbf{x}) = \mathbf{W}\mathbf{x} + \mathbf{b}$, which can be implemented as a $1\times 1$ convolution. And even better, because BN usually comes right after a convolutional layer, we can fuse the two into one.
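
Here is a minimal sketch of that claim (PyTorch assumed; the channel count is arbitrary): we build a $1\times 1$ nn.Conv2d whose weight is the diagonal matrix above and whose bias is the right-most vector, and compare it with the frozen BN layer.

    import torch
    
    torch.manual_seed(0)
    C = 4
    bn = torch.nn.BatchNorm2d(C)
    bn.train()
    bn(torch.randn(8, C, 16, 16))    # populate the running estimates
    bn.eval()
    
    conv1x1 = torch.nn.Conv2d(C, C, kernel_size=1, bias=True)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)   # gamma_c / sqrt(sigma_c^2 + eps)
    with torch.no_grad():
        conv1x1.weight.copy_(torch.diag(scale).view(C, C, 1, 1))
        conv1x1.bias.copy_(bn.bias - scale * bn.running_mean)
    
    x = torch.randn(2, C, 16, 16)
    print((conv1x1(x) - bn(x)).abs().max().item())   # ~1e-7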

Fusing batch normalization with a convolutional layer

Let:
$\mathbf{W}_{BN}\in\mathbb{R}^{C\times C}$ and $\mathbf{b}_{BN}\in\mathbb{R}^{C}$ be the parameters of the BN layer derived above,
$\mathbf{W}_{conv}\in\mathbb{R}^{C\times(C_{prev}\cdot k^2)}$ and $\mathbf{b}_{conv}\in\mathbb{R}^{C}$ be the parameters of the convolutional layer that precedes BN,
$F_{prev}$ be the input to that convolution,
$C_{prev}$ be the number of channels of the input, and
$k$ be the filter size.

Each $k\times k$ patch of $F_{prev}$ is reshaped into a vector $\mathbf{f}_{i,j}$ of length $k^2\cdot C_{prev}$, so applying the convolution followed by BN at position $i,j$ gives: $$ \mathbf{\hat{f}}_{i,j}= \mathbf{W}_{BN}\cdot (\mathbf{W}_{conv}\cdot\mathbf{f}_{i,j} + \mathbf{b}_{conv}) + \mathbf{b}_{BN} $$

Thus, we can replace these two layers by a single convolutional layer with the following parameters: $$ \mathbf{W} = \mathbf{W}_{BN}\cdot\mathbf{W}_{conv}, \qquad \mathbf{b} = \mathbf{W}_{BN}\cdot\mathbf{b}_{conv} + \mathbf{b}_{BN} $$

Implementation in PyTorch

nn.Conv2d parameters used: weight ($\mathbf{W}_{conv}$) and bias ($\mathbf{b}_{conv}$, may be None), plus in_channels, out_channels, kernel_size, stride and padding to construct the fused layer.

nn.BatchNorm2d parameters used: weight ($\gamma$), bias ($\beta$), running_mean ($\hat{\mu}$), running_var ($\hat{\sigma}^2$) and eps ($\epsilon$).

Code


    import torch
    import torchvision
    
    def fuse(conv, bn):
    
        fused = torch.nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=conv.groups,
            bias=True
        )
    
        # fused weight: W = W_BN · W_conv
        # (each conv filter is flattened into a row, W_BN is the diagonal scale matrix)
        w_conv = conv.weight.clone().view(conv.out_channels, -1)
        w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps+bn.running_var)))
        fused.weight.copy_( torch.mm(w_bn, w_conv).view(fused.weight.size()) )
        
        # fused bias: b = W_BN · b_conv + b_BN
        if conv.bias is not None:
            b_conv = conv.bias
        else:
            b_conv = torch.zeros( conv.weight.size(0) )
        b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
                              torch.sqrt(bn.running_var + bn.eps)
                            )
        fused.bias.copy_( torch.matmul(w_bn, b_conv) + b_bn )
    
        return fused
    
    # Testing
    # disable autograd: we only run inference here and copy weights in place
    torch.set_grad_enabled(False)
    x = torch.randn(16, 3, 256, 256)
    resnet18 = torchvision.models.resnet18(pretrained=True)
    # switch to eval mode so that BN uses its running estimates
    resnet18.eval()
    model = torch.nn.Sequential(
        resnet18.conv1,
        resnet18.bn1
    )
    f1 = model(x)
    fused = fuse(model[0], model[1])
    f2 = fused(x)
    # maximum absolute difference between the two outputs (should be ~1e-6)
    d = (f1 - f2).abs().max().item()
    print("max abs error:", d)

Open in Google Colab

And that's all. Don't forget that you can run this code in Google Colab by clicking the "Open in Google Colab" button.

References

  1. Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift  [PDF]
    Sergey Ioffe, Christian Szegedy, 2015. Google.
  2. Exponential smoothing  [HTML]