Paper title
Continual Learning with Optimal Transport based Mixture Model
Authors
Abstract
Online class-incremental learning (CIL) is a challenging setting in continual learning (CL): data for new tasks arrive as streams, and the model must learn from each incoming stream without revisiting previous ones. Existing works characterize each class with a single centroid that is adapted to the incoming data. This approach can be limiting when the data stream of a class is naturally multimodal. To address this issue, we first propose OT-MM, an online mixture model learning approach that builds on properties of the well-established optimal transport theory. Specifically, the centroids and covariance matrices of the mixture model are adapted incrementally according to the incoming data streams. The advantages are two-fold: (i) we can characterize complex data streams more accurately, and (ii) by using the centroids that OT-MM produces for each class, we can estimate the similarity of an unseen example to each class more reasonably at inference time. Moreover, to combat catastrophic forgetting in the CIL scenario, we further propose Dynamic Preservation. In particular, after applying the dynamic preservation technique across data streams, the latent representations of the classes in the old and new tasks become more compact within each class and better separated from each other. Together with a contraction feature extractor, this technique helps the model mitigate catastrophic forgetting. Experimental results on real-world datasets show that our proposed method can significantly outperform current state-of-the-art baselines.
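To make the multi-centroid idea concrete, the following is a minimal sketch of why several centroids per class can help on a multimodal stream. It is not the OT-MM algorithm: the abstract does not specify the optimal-transport update, so this sketch substitutes a plain nearest-centroid hard assignment with running mean and covariance updates. The names `OnlineClassMixture`, `predict`, and the toy data are all hypothetical.

```python
import numpy as np


class OnlineClassMixture:
    """Illustrative per-class mixture: several centroids updated one sample at a time.

    Assumption: hard nearest-centroid assignment stands in for the
    optimal-transport-based assignment, which the abstract does not detail.
    """

    def __init__(self, init_means):
        self.means = np.array(init_means, dtype=float)
        k, d = self.means.shape
        self.covs = np.stack([np.eye(d) for _ in range(k)])
        self.counts = np.zeros(k, dtype=int)

    def update(self, x):
        x = np.asarray(x, dtype=float)
        # Pick the closest centroid for this sample.
        k = int(np.argmin(np.linalg.norm(self.means - x, axis=1)))
        self.counts[k] += 1
        lr = 1.0 / self.counts[k]
        delta = x - self.means[k]
        # Running mean, then a Welford-style running covariance.
        self.means[k] += lr * delta
        self.covs[k] += lr * (np.outer(delta, x - self.means[k]) - self.covs[k])
        return k

    def distance(self, x):
        # Distance from x to the nearest centroid of this class.
        x = np.asarray(x, dtype=float)
        return float(np.min(np.linalg.norm(self.means - x, axis=1)))


def predict(models, x):
    """Return the label whose nearest centroid is closest to x."""
    return min(models, key=lambda label: models[label].distance(x))


# Toy stream: class "a" is bimodal (modes near x=-5 and x=+5), class "b" sits at the origin.
rng = np.random.default_rng(0)
models = {
    "a": OnlineClassMixture([[-1.0, 0.0], [1.0, 0.0]]),
    "b": OnlineClassMixture([[0.0, 0.0]]),
}
for _ in range(200):
    models["a"].update(rng.normal([-5.0, 0.0], 0.1))
    models["a"].update(rng.normal([5.0, 0.0], 0.1))
    models["b"].update(rng.normal([0.0, 0.0], 0.1))

print(predict(models, [4.5, 0.0]), predict(models, [0.3, 0.0]))
```

In this toy setup, a single centroid for class "a" would sit near the origin and collide with class "b", whereas two centroids keep the two modes apart. The covariances are tracked but not used for prediction here; a fuller version might use them in a Mahalanobis-style distance.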