PyTorch code: https://github.com/shanglianlm0525/CvPytorch
Material on Batch Normalization Fusion is relatively scarce, so I collected some related resources and organized them below.
About Batch Normalization
Please refer to my earlier post, Batch Normalization study notes and implementation; the basics are not repeated here.
1 Conv-BatchNorm-Scale-fusion
Fusion principle: the original pipeline performs a convolution and then BN. By folding BN into the convolution, the BN computation disappears entirely at inference time, greatly reducing the overall amount of computation.
The convolution formula:
$$y = w \cdot x + b$$
The BN formula, with running mean $\mu$, running variance $\sigma^2$, scale $\gamma$, shift $\beta$, and a small constant $\epsilon$:
$$y_{bn} = \gamma \cdot \frac{y - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$
Substituting $y = w \cdot x + b$ into the BN formula gives the new weight $w_{new}$ and bias $b_{new}$ of the fused convolution:
$$w_{new} = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \cdot w, \qquad b_{new} = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \cdot (b - \mu) + \beta$$
The new convolution alone then does the work of the original convolution and BN.
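To make the algebra concrete, here is a minimal single-channel sketch (the numeric values are illustrative assumptions, not from the original post) verifying that the fused parameters reproduce conv-then-BN:

import math

# illustrative single-channel values (assumptions for this sketch)
w, b = 0.5, 0.1                  # convolution weight and bias
gamma, beta = 1.2, -0.3          # BN scale and shift
mu, var, eps = 0.4, 0.25, 1e-5   # BN running mean/variance and epsilon

x = 2.0
y_ref = gamma * (w * x + b - mu) / math.sqrt(var + eps) + beta

scale = gamma / math.sqrt(var + eps)
w_new = scale * w                # fused weight
b_new = scale * (b - mu) + beta  # fused bias
y_fused = w_new * x + b_new

assert abs(y_ref - y_fused) < 1e-9  # identical up to floating-point noise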
import torch
import torch.nn as nn


def autopad(k, p=None):  # kernel, padding
    # Pad to 'same'
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


class Conv(nn.Module):
    # Standard convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1):  # ch_in, ch_out, kernel, stride, padding, groups
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)

    def forward(self, x):
        return self.bn(self.conv(x))

    def fuseforward(self, x):
        return self.conv(x)


def fuse_conv_and_bn(conv, bn):
    # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
    with torch.no_grad():
        # init a fused conv that carries a bias term
        fusedconv = nn.Conv2d(conv.in_channels,
                              conv.out_channels,
                              kernel_size=conv.kernel_size,
                              stride=conv.stride,
                              padding=conv.padding,
                              groups=conv.groups,  # preserve grouped convolutions
                              bias=True).to(conv.weight.device)

        # prepare filters: scale each output channel by gamma / sqrt(var + eps)
        w_conv = conv.weight.clone().view(conv.out_channels, -1)
        w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
        fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))

        # prepare spatial bias: fold the conv bias (zero here) and the BN shift together
        b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
        b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
        fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)

        return fusedconv


def fuse(model):  # fuse model Conv2d() + BatchNorm2d() layers
    # print('Fusing layers... ')
    for m in model.modules():
        if type(m) is Conv:
            m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility
            m.conv = fuse_conv_and_bn(m.conv, m.bn)  # update conv
            m.bn = None  # remove batchnorm
            m.forward = m.fuseforward  # update forward
    return model
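A quick end-to-end sanity check (a minimal sketch; the layer sizes and input shape are my own choices, not from the original post). Fusion must be done in eval mode, because it bakes the running statistics into the weights:

model = nn.Sequential(Conv(3, 16, k=3), Conv(16, 32, k=3))
model.eval()  # use running statistics, as fusion assumes

x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    y_ref = model(x)
    model = fuse(model)
    y_fused = model(x)

print((y_ref - y_fused).abs().max())  # tiny numerical noise, on the order of 1e-6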
2 InnerProduct-BatchNorm-Scale-fusion
Fusion principle: the original pipeline runs a fully-connected (InnerProduct) layer followed by BN; fusion folds BN into the fully-connected weights and bias.
The derivation is identical to the convolution case above; a sketch follows.
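Below is a minimal sketch of the fully-connected case, assuming an nn.Linear followed by nn.BatchNorm1d; the function name fuse_linear_and_bn and the sizes in the usage lines are mine, not from the original post:

def fuse_linear_and_bn(fc, bn):
    # fold BatchNorm1d running statistics into the preceding nn.Linear (inference only)
    with torch.no_grad():
        fused = nn.Linear(fc.in_features, fc.out_features, bias=True).to(fc.weight.device)

        scale = bn.weight.div(torch.sqrt(bn.running_var + bn.eps))  # gamma / sqrt(var + eps)
        fused.weight.copy_(fc.weight * scale.reshape(-1, 1))        # scale each output row

        b = torch.zeros(fc.out_features, device=fc.weight.device) if fc.bias is None else fc.bias
        fused.bias.copy_(scale * (b - bn.running_mean) + bn.bias)
        return fused

# usage sketch: randomize the BN statistics so the check is non-trivial
fc, bn = nn.Linear(128, 64, bias=False), nn.BatchNorm1d(64)
bn.running_mean.normal_()
bn.running_var.uniform_(0.5, 2.0)
nn.init.uniform_(bn.weight)
nn.init.uniform_(bn.bias)
bn.eval()  # fusion is only valid with fixed (running) statistics

x = torch.randn(4, 128)
print((bn(fc(x)) - fuse_linear_and_bn(fc, bn)(x)).abs().max())  # ~1e-7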
References
1 Fusing convolution and BatchNorm in PyTorch: https://zhuanlan.zhihu.com/p/49329030
2 A PyTorch implementation of Batch Norm Fusion for inference acceleration (up to ~30% speedup): https://www.pytorchtutorial.com/batch-norm-fusion-pytorch/
3 PyTorch version on GitHub: https://github.com/MIPT-Oulu/pytorch_bn_fusion
4 Caffe version on GitHub: https://github.com/shanglianlm0525/CNN-Conv-BatchNorm-fusion
5 Fusing batch normalization and convolution in runtime: https://tehnokv.com/posts/fusing-batchnorm-and-conv/