让模型实现“终生学习”,佐治亚理工学院提出Data-Free的增量学习
关注公众号,发现CV技术之美
0
写在前面
目前的计算机视觉模型在进行增量学习新的知识的时候,就会出现灾难性遗忘的问题。缓解这种遗忘的最有效的方法需要大量重播(replay)以前训练过的数据;但是,当内存限制或数据合法性问题存在时,这种方法就存在一定的局限性。
在本文中,作者研究了无数据类增量学习(DFCIL)的问题,也就是增量学习能够学习新的知识,而不存储生成器或过去任务的训练数据。目前,DFCIL的一种方法是通过倒置学习分类模型的冻结副本,来合成图像用于训练,使得模型能够不忘记以前任务的知识,也不用replay以前训练过的数据。但是,作者通过实验表明了当使用标准蒸馏策略时,这种方法对于常见的类增量benchmark都是无效的。因此,在本文中,作者分析了这种方法失败的原因,并提出了一种新的DFCIL增量蒸馏策略,提供了一个改进的交叉熵训练和重要性加权特征蒸馏。最终作者通过实验表明,在类增量benchmark上,与SOTA DFCIL方法相比,本文提出的方法在精度上提高了25.1%,甚至优于几种需要存储图像的基于replay的方法。01
论文和代码地址
论文地址:https://arxiv.org/abs/2106.09701
代码地址:尚未开源
02
Motivation
目前,计算机视觉的一个局限是,它们通常使用一个包含在部署过程中所有可能遇到的数据的大型数据集,进行脱机训练。然而,现实情况是许多应用程序需要在遇到新的情况和数据后不断更新模型。这就是类增量学习的范式,在学习新任务的时候忘记以前学习到的知识的问题被称为在灾难性遗忘 。目前,比较成功的增量学习方法有一个缺点:它们需要大量的内存来replay以前看到过的或建模的数据,以避免灾难性遗忘问题。
这在很多计算机视觉的应用中也是不现实的,因为:1)许多计算机视觉应用程序都是在设备上的,因此内存有限;
2)在工业界,可能会存在很多不允许被存储的数据(比如用户的隐私信息)。
因此,作者就提出了这样一个问题:计算机视觉系统如何能在不存储数据的情况下增量地学习新信息?作者将这样的设置称为无数据类增量学习(DFCIL)。DFCIL的一种直观方法是同时训练生成模型进行采样以进行replay,以防止忘记以前的知识。但是与分类模型相比,训练生成模型的计算和内存都更密集。因此,作者探索了模型反演图像 合成的概念,就是通过反转已经提供的推理网络,来获得网络中与训练数据具有相似激活作用的图像。这样一来,就不需要训练额外的网络(因为它只需要现有的推理网络)。上图展示了DFCIL增量学习失败的原因(图a),用当前任务的真实图像和代表过去任务的合成图像训练模型时,特征提取模型提取的特征会变成:当前真实图像的特征分布与当前真实图像的特征分布(即使他们不属于同一个类)更接近,与合成图像的特征分布更不接近 ,这就导致了预测时候的偏差。这一现象表明,当训练一个具有两种数据分布的网络时,同时包含语义位移和分布位移,分布位移对特征嵌入有更高的影响。因此,来自以前任务的的测试图像将被识别为新的类,因为模型会更关注于它们的分布,而不是它们的语义内容(这就与分类任务的目标背道而驰了)。为了解决这个问题,作者提出了一种新的类增量学习方法,该方法学习了具有局部分类损失的新任务特征,依赖于重要性加权特征蒸馏和线性分类head微调来分离新任务和过去任务的特征嵌入。作者通过实验表明,在类增量benchmark上,与SOTA DFCIL方法相比,本文提出的方法在精度上提高了25.1%,甚至优于几种需要存储图像的基于replay的方法。
03
方法
3.1. 先验知识-类增量学习
在类增量学习中,一个模型需要学习了对应于M个语义对象类的数据,但这些数据是通过N个task依次暴露给模型的,每个任务中子类都不会重合。我们用来表示任务n中引入的类集,其中表示任务n中对象类的数量。每个类只出现在单个任务中,模型目标就是逐步学习引入的新对象类,并对它们进行分类,同时保留之前学习过的类的知识。为了描述推理模型,我们将表示在i时刻使用任务n的类训练的模型。3.2. Baseline Approach
在本节中,作者基于之前工作,提出了一个Data-Free的用于类增量学习的baseline。
3.2.1. Model-Inversion Image Synthesis
大多数模型反演图像合成方法都是通过直接对先验的鉴别模型进行优化来合成图像。然而,一次优化一个Batch的图像在计算上是效率低下的。因此作者选择使用卷积网络参数化函数用噪声生成合成图像进行近似优化。这就使每个任务只需要训练一次,当前任务结束时就可以直接丢弃。
首先,需要生成多样性的图片,因此作者优化合成了图像的类预测的多样性,以匹配均匀分布。将表示为模型θ对输入x产生的预测类分布,需要使合成样本的平均类预测向量的熵最大化,如下所是(label diversity loss):除了多样性之外,为了在DFCIL中合成有用的图像,图像还需要校准的类置信度、特征统计数据的一致性和局部平滑的潜在空间。对于校准的类置信度 ,作者使用了Content Loss,通过对图像张量的类预测一致性最大化,这样就能对所有输入做出足够confident的预测了。Content Loss的具体计算表示如下所示:
对于特征统计数据的一致性 ,先前的工作发现,模型反演的复杂性会导致特征的分布大大偏离合成图像的分布。因此,合成图像的Batch统计应该与中的Batch Norm层相匹配。基于此,作者进一步提出了stat alignment loss:
对于局部平滑的潜在空间 ,先验知识告诉我们,自然图像在像素空间中比初始噪声更局部平滑。因此作者又提出了一个损失函数smoothness prior loss,这个函数就是生成图像和高斯模糊版本的生成图像的L2距离:
3.2.2. Distilling Synthetic Data for Class-Incremental Learning
在类增量学习中,对合成图像的知识蒸馏通常被用于正则化,迫使它学习,学习的同时,将的知识遗忘减到最小。对于任务,我们从任务期间训练的的冻结副本中合成图像。这些合成图像帮助我们将任务中学习的知识提炼到我们当前的模型中。
在Baseline方法中,作者采用了DeepInversion中使用的蒸馏方法。具体表示为,给定当前的任务数据和合成的蒸馏数据,我们最小化:3.3. Diagnosis: Feature Embedding Prioritizes Domains Over Semantics
为了探究为什么DFCIL的Baseline方法会失败,作者使用度量(MID)分析了嵌入特征之间的表征距离,这种度量用于捕获两个分布样本的平均图像embedding之间的距离。作者将这种度量实例化为Mean Image Distance (MID) score,高分表示不同的特征,低分表示相似的特征。计算如下:
3.4. A New Distillation Strategy for DFCIL
基于上面的分析,作者提出了持续的学习应该在以下几个方面保持平衡:(1)针对新任务的学习特征;(2)最小化超过上一个任务的特征偏移;(3)在embedding空间中分离新的类和以前的类之间的类重叠。
对于上面的三个平衡,(1)和(3)可以通过实现。但是作者认为,通过将其分成两种不同的损失,可以在学习新任务的时候,不区分真实图像和合成图像的特征。根据这个想法,作者提出了一种为DFCIL设计的新的类增量学习方法,该方法独立地解决这些目标。(蓝色箭头表示之前合成的任务数据的计算路径,绿色箭头表示真实的当前任务数据的计算路径,黄色箭头表示真实数据和合成数据的计算路径。)
3.4.1. Learning current task features
作者方法背后的intuition是需要学习当前task的特征的同时,绕过偏向最近task真实数据的特征表示。具体实现上,作者通过只计算在新的 线性分类head上的局部交叉熵分类损失来实现这一点。有了这种模式,作者阻止了模型学习通过domain分离新的和过去的类数据,损失函数如下:
3.4.2. Minimizing feature drift over previous task data
与真实的当前任务图像相比,蒸馏图像属于另一个domain,因此作者寻找了另一个损失函数,直接减轻遗忘的损失函数。要实现这个目标,一个选择是特性蒸馏:
因此,作者提出了一种重要性加权特征蒸馏,它只强化了过去任务数据中最重要的组成部分,同时允许不那么重要的特性来适应新任务。表示如下:
3.4.3. Separating Current and Past Decision Boundaries
最后,模型需要分离当前类和过去类的决策边界,而不允许特征空间来区分真实数据和合成数据。作者通过用交叉熵损失函数来fine-tuning线性分类head来实现。除了线性分类head之外,这个损失函数并不会更新中的任何参数:
3.4.4. Final Objective
最终模型的损失函数为上述损失函数之和,如下所示:
04
实验
4.1. DFCIL (CIFAR-100 )
4.2. CIL with Replay Data (CIFAR-100 )
4.3. Ablation Study(CIFAR-100 )
4.4. DFCIL (ImageNet)
05
总结
在本文中,作者表明现有的类增量学习方法在使用真实训练数据学习新任务和使用合成蒸馏数据保存过去的知识时,performance较差。因此,作者提出了一种新的方法来实现了无数据类增量学习的SOTA性能,并与基于replay的SOTA方法性能相当。
作者提出无数据类增量学习是希望消除在类增量学习中存储回放数据的需要,使计算机视觉的广泛和实际应用成为可能。不存储数据的增量学习解决方案,将对计算机视觉应用产生直接影响,进一步促进计算机视觉任务的落地应用。▊ 作者简介
厦门大学人工智能系20级硕士
研究领域:FightingCV公众号运营者,研究方向为多模态内容理解,专注于解决视觉模态和语言模态相结合的任务,促进Vision-Language模型的实地应用。
知乎/公众号:FightingCV
END,入群👇备注:CV