写在前面
最近对数据集蒸馏比较感兴趣,抽时间看了下这篇经典的数据蒸馏论文《Dataset Distillation》,它是属于knowledge distillation领域的工作。
一、动机
本文提出的方法是数据集蒸馏(Dataset Distillation)
- 从大的训练数据中蒸馏知识到小的数据集
- 小的数据集不需要与原始的大的训练数据分布相同
- 只要在小的数据集上训练几步梯度下降就能达到和原始数据相近的模型效果
模型蒸馏(model层面)的目标是从一个复杂的模型中蒸馏知识到小的模型上。
本文考虑的是数据集上的蒸馏(dataset层面),具体来说,我们会固定住模型,然后尝试从较大的训练数据集中蒸馏知识到小的数据集上。
核心目的是将原始的大数据集压缩成一个小的数据集(不需要来自训练集的分布),并且在这个小数据集上训练模型的效果和原始较大数据集上的训练效果是接近的。
例如,可以用60000张MNIST训练集蒸馏出10张distilled的图片(每个类一张),在这10张图片只需再训练几轮,就能到达和原始效果接近的效果。
二、背景知识
2015 Hinton等人提出了network distillation(model compression),本文我们不蒸馏模型,我们蒸馏数据集。
通常来说如果你小数据的分布和真正测试集的分布不同,是很难训练出一个好的模型的,但是本文的工作表明,这完全是可能的。
本文提出了一种新的优化方法,尽管现在只有很小的数据集,但他不仅能够抓住原始大数据集的信息,而且只要几步梯度下降就能训练好模型,并且在真正的测试集上效果良好。
相关工作:简单列一下相关的方向
- Knowledge distillation
- Dataset pruning,core-set construction,and instance selection
- Gradient-based hyperparameter optimization
- Understanding datasets
三、方法
3.1 训练 x ~ \tilde{\mathbf{x}} x~和 η ~ \tilde \eta η~的过程
传统的模型训练会使用随机梯度下降进行参数优化,假设现在进行第t次参数更新,使用的minibatch的训练集为 x t = { x t , j } j = 1 n \mathbf{x}_{t}=\left\{x_{t, j}\right\}_{j=1}^{n} xt={xt,j}j=1n
θ t + 1 = θ t − η ∇ θ t ℓ ( x t , θ t ) \theta_{t+1}=\theta_{t}-\eta \nabla_{\theta_{t}} \ell\left(\mathbf{x}_{t}, \theta_{t}\right) θt+1=θt−η∇θtℓ(xt,θt)
通常来说这种训练方式,需要更新上万次参数才能收敛。
本文的目标是学习到一小部分的合成的distilled的训练集 x ~ = { x ~ i } i = 1 M \tilde{\mathbf{x}}=\left\{\tilde{x}_{i}\right\}_{i=1}^{M} x~={x~i}i=1M,其中M远小于总的训练集数量N,以及学习对应的学习率 η ~ \tilde \eta η~,使得只要一次参数更新就能得到一个在真实测试集上效果很好的模型参数。(为什么要学习 η ~ \tilde \eta η~?因为作者想要通过少量的梯度下降便得到一个比较好的模型,因此学习率既不能太大,也不能太小,因而需要学习获得)
θ 1 = θ 0 − η ~ ∇ θ 0 ℓ ( x ~ , θ 0 ) \theta_{1}=\theta_{0}-\color{red}\tilde{\eta} \color{black} \nabla_{\theta_{0}} \ell\left(\color{red} \tilde{\mathbf{x}}\color{black} , \theta_{0}\right) θ1=θ0−η~∇θ0ℓ(x~,θ0)
那么如何学习 x ~ \tilde{\mathbf{x}} x~和 η ~ \tilde \eta η~呢?
很简单,我们希望通过 x ~ \tilde{\mathbf{x}} x~和 η ~ \tilde \eta η~得到的 θ 1 \theta_1 θ1,能够使 ℓ ( x , θ 1 ) \ell\left(\mathbf{x}, \theta_{1}\right) ℓ(x,θ1)最小,其中 x x x是原始的大训练集。
因此对应的优化目标如下:
x
~
∗
,
η
~
∗
=
arg
min
x
~
,
η
~
L
(
x
~
,
η
~
;
θ
0
)
=
arg
min
x
~
,
η
~
ℓ
(
x
,
θ
1
)
=
arg
min
x
~
,
η
~
ℓ
(
x
,
θ
0
−
η
~
∇
θ
0
ℓ
(
x
~
,
θ
0
)
)
\begin{aligned} \tilde{\mathbf{x}}^{*}, \tilde{\eta}^{*} &= \underset{\tilde{\mathbf{x}}, \tilde{\eta}}{\arg \min } \mathcal{L}\left(\tilde{\mathbf{x}}, \tilde{\eta} ; \theta_{0}\right)\\ &=\underset{\tilde{\mathbf{x}}, \tilde{\eta}}{\arg \min } \ell\left(\mathbf{x}, \theta_{1}\right) \\ &=\underset{\tilde{\mathbf{x}}, \tilde{\eta}}{\operatorname{\arg \min}} \ell\left(\mathbf{x}, \theta_{0}-\color{red}\tilde{\eta} \color{black} \nabla_{\theta_{0}} \ell\left(\color{red} \tilde{\mathbf{x}}\color{black}, \theta_{0}\right)\right) \end{aligned}
x~∗,η~∗=x~,η~argminL(x~,η~;θ0)=x~,η~argminℓ(x,θ1)=x~,η~argminℓ(x,θ0−η~∇θ0ℓ(x~,θ0))
说明1:在这种优化方式下,我们学习的参数只有 x ~ \tilde{\mathbf{x}} x~和 η ~ \tilde \eta η~,通过随机梯度下降求解 x ~ \tilde{\mathbf{x}} x~和 η ~ \tilde \eta η~的过程和普通的优化没有什么区别,也是要进行上万次更新。
说明2:对于数据 x ~ \tilde{x} x~的其它部分,例如标签,只把它固定而不进行学习。
3.2 在合成数据集 x ~ \tilde{\mathbf{x}} x~上训练模型(fix init)
当我们训练好得到合成数据集 x ~ \tilde{\mathbf{x}} x~和对应的学习率 η ~ \tilde \eta η~后,我们就可以在这个合成数据集 x ~ \tilde{\mathbf{x}} x~上训练模型了。
那么这个模型的初始化参数应该是什么呢?
3.3 训练 x ~ \tilde{\mathbf{x}} x~和 η ~ \tilde \eta η~时用随机初始化参数 θ 0 \theta_0 θ0(random init)
为了解决上面提到的问题,作者提出,训练 x ~ \tilde{\mathbf{x}} x~和 η ~ \tilde \eta η~的时候不用固定的 θ 0 \theta_0 θ0,而是每次从分布 p ( θ 0 ) p(\theta_0) p(θ0)中随机采样。
此时的优化目标如下:
x
~
∗
,
η
~
∗
=
arg
min
x
~
,
η
~
E
θ
0
∼
p
(
θ
0
)
L
(
x
~
,
η
~
;
θ
0
)
\tilde{\mathbf{x}}^{*}, \tilde{\eta}^{*}=\underset{\tilde{\mathbf{x}}, \tilde{\eta}}{\arg \min } \mathbb{E}_{\theta_{0} \sim p\left(\theta_{0}\right)} \mathcal{L}\left(\tilde{\mathbf{x}}, \tilde{\eta} ; \theta_{0}\right)
x~∗,η~∗=x~,η~argminEθ0∼p(θ0)L(x~,η~;θ0)
实验表明,在随机初始化参数 θ 0 \theta_0 θ0的条件下得到的合成数据集 x ~ \tilde x x~后,我们在合成数据集 x ~ \tilde x x~训练模型时可以随机初始化参数,而且效果也不错(不过还是没有固定 θ 0 \theta_0 θ0的效果好),另外在random initialization条件下得到distilled images通常包含有一定的信息,因为合成数据编码了每个类别的判别特征,如实验部分Figure3所示。
3.4 多步参数更新
前面介绍的从
θ
0
\theta_0
θ0到
θ
1
\theta_1
θ1我们只进行了signle GD step,实际上这一部分可以改成多步。只需要将Algorithm1中第6行改成多步即可。
θ
i
+
1
=
θ
i
−
η
~
i
∇
θ
i
ℓ
(
x
~
i
,
θ
i
)
\theta_{i+1}=\theta_{i}-\tilde{\eta}_{i} \nabla_{\theta_{i}} \ell\left(\tilde{\mathbf{x}}_{i}, \theta_{i}\right)
θi+1=θi−η~i∇θiℓ(x~i,θi)
每一步使用不同的distilled data x ~ i \tilde x_i x~i和 η ~ i \tilde \eta_i η~i
文章中还用了优化算法来加快梯度回传的过程。
3.5 不同的初始化参数 θ 0 \theta_0 θ0的方式
除了固定初始化模型参数和随机初始化模型参数外,文章还提出了使用其他任务中预训练好的模型参数来模型的参数,所以共有以下4种方式构建初始化参数,其中最后一种方式的效果是最好的。
四、实验结果
数据集:MNIST、CIFAR10
Fixed initialization得到的distilled images:
Random initialization得到的distilled images:
五、本文总结
这篇数据集蒸馏的文章我觉得非常有意思,刷新了我之前的认知,值得反复阅读。
未来工作:
- 将数据集蒸馏应用到大规模的图片数据集(ImageNet)以及其他类型的数据上(如语音、文本)
- 我们的方法对初始化的分布比较敏感,我们会研究其他的初始化策略。
相关资料
- 论文 | 《dataset distillation》数据集知识蒸馏文章解读
- Dataset Distillation论文笔记