0
点赞
收藏
分享

微信扫一扫

tf.keras.layers.BatchNormalization函数

phpworkerman 2022-04-15 阅读 58
python

函数原型

tf.keras.layers.BatchNormalization(axis=-1, 
								   momentum=0.99, 
								   epsilon=0.001, 
								   center=True, 
								   scale=True, 
								   beta_initializer='zeros', 
								   gamma_initializer='ones',
								   moving_mean_initializer='zeros', 
								   moving_variance_initializer='ones', 
								   beta_regularizer=None, 
								   gamma_regularizer=None,
								   beta_constraint=None, 
								   gamma_constraint=None, 
								   **kwargs
								   )

函数说明

批量标准化层应用了一种转换,使得数据的均值趋于0,标准差趋于1。该层实现的公式如下图所示。输入向量x,输出向量y。
在这里插入图片描述

批标准化层在训练和推理期间的工作方式不同。

在训练期间,也就是使用fit()函数或者调用参数模型training=True时,对于输入的数据batch,输出为gamma * (batch - mean(batch)) / sqrt(var(batch) + epsilon) + beta。

在推理期间,也就是使用predict()函数、evaluate()函数或者调用参数模型training=False时,对于输入的数据batch,输出为gamma * (batch - moving_mean) / sqrt(moving_var + epsilon) + beta,其中moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum),moving_var = moving_var * momentum + var(batch) * (1 - momentum)。moving_mean和moving_var这两个变量只在每个训练期间结束后更新,在推理期间不更新。

另外可通过传递scale=False和center=False来禁用gamma、beta这两个参数。

正常情况下,大多数参数都不会被用到,所以只需知道这一层是用来干嘛的,什么时候需要用到这一层就行。可以参考一下这个链接。
https://zhuanlan.zhihu.com/p/24810318

函数用法

BatchNormalization 广泛用于 Keras 内置的许多高级卷积神经网络架构,比如 ResNet50、Inception V3 和 Xception。BatchNormalization 层通常在卷积层密集连接层之后使用。

原始论文讲在CNN中一般应作用与非线性激活函数之前,但是,在caffenet-benchmark-batchnorm中,作者基于caffenet在ImageNet2012上做了如下对比实验:
在这里插入图片描述

从上图可以看出,放在前后的差异并不是很大,甚至放在激活函数之后效果可能会更好。

# 放在非线性激活函数之前
model.add(tf.keras.layers.Conv2D(256, 2, 2)))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Activation('relu'))

# 放在激活函数之后
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.keras.layers.BatchNormalization())

举报

相关推荐

0 条评论