1. Introduction
PaddleSlim is a toolkit focused on deep learning model compression. It provides compression strategies such as pruning, quantization, distillation, and network architecture search, helping users quickly shrink their models.
2. Convolutional Filter Pruning
This tutorial performs one pass of channel pruning on a convolutional network. Pruning the channels of a convolution layer means pruning that layer's output channels. A convolution weight has shape [output_channel, input_channel, kernel_size, kernel_size], so pruning along the first dimension of the weight reduces the number of output channels.
In practice, pruning should account for each layer's sensitivity: prune each layer at several ratios, measure accuracy on a validation set to estimate how sensitive the layer is, and prune more aggressively from the layers whose accuracy drops the least.
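The sensitivity idea above can be sketched as a simple sweep. The function names and the toy evaluator below are purely illustrative (not PaddleSlim API): prune one layer at a time at several ratios, record the accuracy drop, and prefer layers with small drops.

```python
# Hypothetical sensitivity sweep: prune one layer at a time at several ratios
# and record the validation-accuracy drop relative to the unpruned baseline.
def sensitivity_sweep(layers, ratios, eval_fn):
    """eval_fn(layer, ratio) -> accuracy with only that layer pruned."""
    baseline = eval_fn(None, 0.0)
    return {layer: {r: baseline - eval_fn(layer, r) for r in ratios}
            for layer in layers}

# Toy stand-in evaluator with made-up numbers: pretend conv2d_22 is far less
# sensitive than conv2d_20.
def fake_eval(layer, ratio):
    per_unit_drop = {'conv2d_20': 0.30, 'conv2d_22': 0.05}
    return 0.937 - (per_unit_drop[layer] * ratio if layer else 0.0)

drops = sensitivity_sweep(['conv2d_20', 'conv2d_22'], [0.3, 0.5], fake_eval)
# conv2d_22 loses less accuracy at every ratio, so it can be pruned harder.
```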
2.1 PaddleSlim APIs
PaddleSlim provides three main pruners for dynamic graphs:
- L1NormFilterPruner: ranks the importance of the filters within a single convolution layer by their l1-norm and prunes the least important ones at the given ratio. Pruning filters is equivalent to pruning the layer's output channels.
- L2NormFilterPruner: ranks the importance of the filters within a single convolution layer by their l2-norm and prunes the least important ones at the given ratio.
- FPGMFilterPruner: ranks the importance of the filters within a single convolution layer using the statistic from the paper Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration <https://arxiv.org/abs/1811.00250>_ and prunes the least important ones at the given ratio.
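The l1-norm criterion can be illustrated concretely. The NumPy sketch below is a toy version of the idea, not the PaddleSlim internals: score each output filter by the l1-norm of its weights, then keep only the highest-scoring ones by slicing the first dimension.

```python
import numpy as np

# Toy weight for one conv layer: [out_ch, in_ch, kh, kw], 8 output filters.
rng = np.random.default_rng(0)
weight = rng.standard_normal((8, 3, 3, 3))

l1 = np.abs(weight).sum(axis=(1, 2, 3))   # one l1-norm score per filter
keep = np.sort(np.argsort(l1)[-4:])       # indices of the 4 largest scores
pruned = weight[keep]                     # slice along dim 0 = output channels

print(pruned.shape)   # (4, 3, 3, 3): output channels halved
```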
2.2 Example
Import the required packages:
import paddle
import paddle.vision.models as models
from paddle.static import InputSpec as Input
from paddle.vision.datasets import Cifar10
import paddle.vision.transforms as T
from paddleslim.dygraph import L1NormFilterPruner 
Define the network and load the dataset:
net = models.mobilenet_v1()
inputs = Input(shape=[None, 3, 32, 32], dtype='float32', name='image')
labels = Input(shape=[None, 1], dtype='int64', name='label')
optimizer = paddle.optimizer.Momentum(learning_rate=0.1, parameters=net.parameters())
model = paddle.Model(net, inputs, labels)
model.prepare(
    optimizer=optimizer,
    loss=paddle.nn.CrossEntropyLoss(),
    metrics=paddle.metric.Accuracy(topk=(1, 5))
)
transforms = T.Compose([
    T.Transpose(),
    T.Normalize([127.5], [127.5])
])
train_dataset = Cifar10(mode='train', transform=transforms)
test_dataset = Cifar10(mode='test', transform=transforms) 
Train the model:
model.fit(train_dataset, epochs=2, batch_size=128, verbose=1) 
Compute the model statistics before pruning, using the paddle.flops function:
flops = paddle.flops(net, input_size=[1, 3, 32, 32], print_detail=True) 
The output lists each layer's parameter count (Params) and FLOPs in detail. The names in the Layer Name column correspond to the parameter names (e.g. conv2d_22.w_0) that you pass to the pruner.
+-----------------------+-----------------+-----------------+---------+---------+
|       Layer Name      |   Input Shape   |   Output Shape  |  Params |  Flops  |
+-----------------------+-----------------+-----------------+---------+---------+
|        conv2d_0       |  [1, 3, 32, 32] | [1, 32, 16, 16] |   864   |  221184 |
|     batch_norm2d_0    | [1, 32, 16, 16] | [1, 32, 16, 16] |   128   |  16384  |
|        re_lu_0        | [1, 32, 16, 16] | [1, 32, 16, 16] |    0    |    0    |
|        conv2d_1       | [1, 32, 16, 16] | [1, 32, 16, 16] |   288   |  73728  |
|     batch_norm2d_1    | [1, 32, 16, 16] | [1, 32, 16, 16] |   128   |  16384  |
|        re_lu_1        | [1, 32, 16, 16] | [1, 32, 16, 16] |    0    |    0    |
|        conv2d_2       | [1, 32, 16, 16] | [1, 64, 16, 16] |   2048  |  524288 |
|     batch_norm2d_2    | [1, 64, 16, 16] | [1, 64, 16, 16] |   256   |  32768  |
|        re_lu_2        | [1, 64, 16, 16] | [1, 64, 16, 16] |    0    |    0    |
|        conv2d_3       | [1, 64, 16, 16] |  [1, 64, 8, 8]  |   576   |  36864  |
|     batch_norm2d_3    |  [1, 64, 8, 8]  |  [1, 64, 8, 8]  |   256   |   8192  |
|        re_lu_3        |  [1, 64, 8, 8]  |  [1, 64, 8, 8]  |    0    |    0    |
|        conv2d_4       |  [1, 64, 8, 8]  |  [1, 128, 8, 8] |   8192  |  524288 |
|     batch_norm2d_4    |  [1, 128, 8, 8] |  [1, 128, 8, 8] |   512   |  16384  |
|        re_lu_4        |  [1, 128, 8, 8] |  [1, 128, 8, 8] |    0    |    0    |
|        conv2d_5       |  [1, 128, 8, 8] |  [1, 128, 8, 8] |   1152  |  73728  |
|     batch_norm2d_5    |  [1, 128, 8, 8] |  [1, 128, 8, 8] |   512   |  16384  |
|        re_lu_5        |  [1, 128, 8, 8] |  [1, 128, 8, 8] |    0    |    0    |
|        conv2d_6       |  [1, 128, 8, 8] |  [1, 128, 8, 8] |  16384  | 1048576 |
|     batch_norm2d_6    |  [1, 128, 8, 8] |  [1, 128, 8, 8] |   512   |  16384  |
|        re_lu_6        |  [1, 128, 8, 8] |  [1, 128, 8, 8] |    0    |    0    |
|        conv2d_7       |  [1, 128, 8, 8] |  [1, 128, 4, 4] |   1152  |  18432  |
|     batch_norm2d_7    |  [1, 128, 4, 4] |  [1, 128, 4, 4] |   512   |   4096  |
|        re_lu_7        |  [1, 128, 4, 4] |  [1, 128, 4, 4] |    0    |    0    |
|        conv2d_8       |  [1, 128, 4, 4] |  [1, 256, 4, 4] |  32768  |  524288 |
|     batch_norm2d_8    |  [1, 256, 4, 4] |  [1, 256, 4, 4] |   1024  |   8192  |
|        re_lu_8        |  [1, 256, 4, 4] |  [1, 256, 4, 4] |    0    |    0    |
|        conv2d_9       |  [1, 256, 4, 4] |  [1, 256, 4, 4] |   2304  |  36864  |
|     batch_norm2d_9    |  [1, 256, 4, 4] |  [1, 256, 4, 4] |   1024  |   8192  |
|        re_lu_9        |  [1, 256, 4, 4] |  [1, 256, 4, 4] |    0    |    0    |
|       conv2d_10       |  [1, 256, 4, 4] |  [1, 256, 4, 4] |  65536  | 1048576 |
|    batch_norm2d_10    |  [1, 256, 4, 4] |  [1, 256, 4, 4] |   1024  |   8192  |
|        re_lu_10       |  [1, 256, 4, 4] |  [1, 256, 4, 4] |    0    |    0    |
|       conv2d_11       |  [1, 256, 4, 4] |  [1, 256, 2, 2] |   2304  |   9216  |
|    batch_norm2d_11    |  [1, 256, 2, 2] |  [1, 256, 2, 2] |   1024  |   2048  |
|        re_lu_11       |  [1, 256, 2, 2] |  [1, 256, 2, 2] |    0    |    0    |
|       conv2d_12       |  [1, 256, 2, 2] |  [1, 512, 2, 2] |  131072 |  524288 |
|    batch_norm2d_12    |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   2048  |   4096  |
|        re_lu_12       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |    0    |    0    |
|       conv2d_13       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   4608  |  18432  |
|    batch_norm2d_13    |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   2048  |   4096  |
|        re_lu_13       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |    0    |    0    |
|       conv2d_14       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |  262144 | 1048576 |
|    batch_norm2d_14    |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   2048  |   4096  |
|        re_lu_14       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |    0    |    0    |
|       conv2d_15       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   4608  |  18432  |
|    batch_norm2d_15    |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   2048  |   4096  |
|        re_lu_15       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |    0    |    0    |
|       conv2d_16       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |  262144 | 1048576 |
|    batch_norm2d_16    |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   2048  |   4096  |
|        re_lu_16       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |    0    |    0    |
|       conv2d_17       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   4608  |  18432  |
|    batch_norm2d_17    |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   2048  |   4096  |
|        re_lu_17       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |    0    |    0    |
|       conv2d_18       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |  262144 | 1048576 |
|    batch_norm2d_18    |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   2048  |   4096  |
|        re_lu_18       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |    0    |    0    |
|       conv2d_19       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   4608  |  18432  |
|    batch_norm2d_19    |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   2048  |   4096  |
|        re_lu_19       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |    0    |    0    |
|       conv2d_20       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |  262144 | 1048576 |
|    batch_norm2d_20    |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   2048  |   4096  |
|        re_lu_20       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |    0    |    0    |
|       conv2d_21       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   4608  |  18432  |
|    batch_norm2d_21    |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   2048  |   4096  |
|        re_lu_21       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |    0    |    0    |
|       conv2d_22       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |  262144 | 1048576 |
|    batch_norm2d_22    |  [1, 512, 2, 2] |  [1, 512, 2, 2] |   2048  |   4096  |
|        re_lu_22       |  [1, 512, 2, 2] |  [1, 512, 2, 2] |    0    |    0    |
|       conv2d_23       |  [1, 512, 2, 2] |  [1, 512, 1, 1] |   4608  |   4608  |
|    batch_norm2d_23    |  [1, 512, 1, 1] |  [1, 512, 1, 1] |   2048  |   1024  |
|        re_lu_23       |  [1, 512, 1, 1] |  [1, 512, 1, 1] |    0    |    0    |
|       conv2d_24       |  [1, 512, 1, 1] | [1, 1024, 1, 1] |  524288 |  524288 |
|    batch_norm2d_24    | [1, 1024, 1, 1] | [1, 1024, 1, 1] |   4096  |   2048  |
|        re_lu_24       | [1, 1024, 1, 1] | [1, 1024, 1, 1] |    0    |    0    |
|       conv2d_25       | [1, 1024, 1, 1] | [1, 1024, 1, 1] |   9216  |   9216  |
|    batch_norm2d_25    | [1, 1024, 1, 1] | [1, 1024, 1, 1] |   4096  |   2048  |
|        re_lu_25       | [1, 1024, 1, 1] | [1, 1024, 1, 1] |    0    |    0    |
|       conv2d_26       | [1, 1024, 1, 1] | [1, 1024, 1, 1] | 1048576 | 1048576 |
|    batch_norm2d_26    | [1, 1024, 1, 1] | [1, 1024, 1, 1] |   4096  |   2048  |
|        re_lu_26       | [1, 1024, 1, 1] | [1, 1024, 1, 1] |    0    |    0    |
| adaptive_avg_pool2d_0 | [1, 1024, 1, 1] | [1, 1024, 1, 1] |    0    |   2048  |
|        linear_0       |    [1, 1024]    |    [1, 1000]    | 1025000 | 1024000 |
+-----------------------+-----------------+-----------------+---------+---------+
Total Flops: 12817920     Total Params: 4253864 
Evaluate the accuracy:
model.evaluate(test_dataset, batch_size=128, verbose=1) 
The accuracy before pruning is 0.937:
{'loss': [1.3092904],  'acc': 0.93706} 
Prune the model. Two different layers are pruned by parameter name: conv2d_22 at a ratio of 50% and conv2d_20 at 60%.
pruner = L1NormFilterPruner(net, [1, 3, 32, 32])
pruner.prune_vars({'conv2d_22.w_0':0.5, 'conv2d_20.w_0':0.6}, axis=0) 
Compute the FLOPs after pruning:
flops = paddle.flops(net, input_size=[1, 3, 32, 32], print_detail=True) 
Partial results are shown below; conv2d_22 was pruned at a ratio of 0.5 and conv2d_20 at 0.6:
+-----------------------+-----------------+-----------------+---------+---------+
|       Layer Name      |   Input Shape   |   Output Shape  |  Params |  Flops  |
+-----------------------+-----------------+-----------------+---------+---------+
|       conv2d_20       |  [1, 512, 2, 2] |  [1, 205, 2, 2] |  104960 |  419840 |
|       conv2d_22       |  [1, 205, 2, 2] |  [1, 256, 2, 2] |  52480  |  209920 |
+-----------------------+-----------------+-----------------+---------+---------+
Total Flops: 11067556     Total Params: 3615301 
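The pruned layer sizes in the table can be checked by hand. Both layers are 1x1 pointwise convolutions with 512 input and output channels before pruning, so params = in_channels * out_channels; conv2d_22's input channel count follows conv2d_20's output through the intervening depthwise layer:

```python
# conv2d_20 pruned by 60%, conv2d_22 by 50%; both are 1x1 convolutions.
out_20 = round(512 * (1 - 0.6))   # 205 filters remain in conv2d_20
out_22 = round(512 * (1 - 0.5))   # 256 filters remain in conv2d_22
params_20 = 512 * out_20          # input is still 512 channels
params_22 = out_20 * out_22       # input shrank to 205 channels
print(params_20, params_22)       # 104960 52480, matching the table
```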
Accuracy after pruning:
model.evaluate(test_dataset, batch_size=128, verbose=1) 
The top-5 accuracy drops from 0.937 to 0.765; pruning a model generally causes some accuracy loss.
{'loss': [2.2277398], 'acc_top5': 0.76516} 
Fine-tuning helps the model recover its original accuracy. After evaluating the pruned model, the following code runs one epoch of fine-tuning and then evaluates the fine-tuned model again:
optimizer = paddle.optimizer.Momentum(
    learning_rate=0.1,
    parameters=net.parameters())
model.prepare(
    optimizer,
    paddle.nn.CrossEntropyLoss(),
    paddle.metric.Accuracy(topk=(1, 5)))
model.fit(train_dataset, epochs=1, batch_size=128, verbose=1) 
Evaluate again:
model.evaluate(test_dataset, batch_size=128, verbose=1) 
After fine-tuning, the top-5 accuracy recovers to 0.944, even slightly higher than before pruning, because the pre-pruning model had not been trained to convergence.
{'loss': [1.2696353], 'acc_top1': 0.54044, 'acc_top5': 0.94396} 
2.2.1 Comparison before and after pruning
| | Before pruning | After pruning |
|---|---|---|
| params | 4253864 | 3615301 |
| flops | 12817920 | 11067556 |
| accuracy (top-5) | 0.937 | 0.944 |
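From the totals in the comparison table, the relative savings work out to roughly 13.7% of FLOPs and 15.0% of parameters:

```python
# Relative reduction implied by the totals in the comparison table above.
flops_before, flops_after = 12817920, 11067556
params_before, params_after = 4253864, 3615301

flops_saved = 1 - flops_after / flops_before     # fraction of FLOPs removed
params_saved = 1 - params_after / params_before  # fraction of params removed
print(f"FLOPs saved: {flops_saved:.1%}, params saved: {params_saved:.1%}")
```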
Evaluation speed: 1s/step before pruning, 905ms/step after pruning.