当前位置:首页 期刊杂志

基于VAE和注意力机制的小样本图像分类方法

时间:2024-05-04

郑欣悦 黄永辉

(中国科学院国家空间科学中心复杂航天系统电子信息技术重点实验室 北京 100190)(中国科学院大学 北京 100049)

0 引 言

近年来,人工智能技术研究飞速发展,深度学习算法已在图像识别领域取得了突破性的进展,但算法也逐渐显露出泛化能力差、所需训练数据大等缺点。目前,以CNN为基础的图像识别方法通常需要海量的训练数据和充足的迭代次数,才可对特定的图像类别进行精准的分类。然而,实际应用中研究者常面临数据稀缺的情况,比如罕见物种图片、珍贵的医疗诊断图片、稀有遥感图像等,采集这些数据的难度大且成本高。而少量的样本通常不足以训练出一个较好的深度神经网络。因此,如何实现小样本图像识别成为了计算机视觉领域的重要研究方向。

针对小样本学习问题,深度学习领域存在着许多不同的解决方案,其中元学习方法取得了尤为显著的成效。元学习(Meta-learning)是指导分类器学会如何学习的过程。元学习器在有限的样例中对结构基础层次和参数空间进行优化,以获得跨任务泛化性能[1],具备小样本学习的能力。训练完成的元学习器可以仅根据1至5个输入-输出样例对新的测试样本进行分类。

目前元学习的方法可以归类为以下几种:基于记忆存储的方法[2,9]通过权重更新来调整偏差,并不断地从记忆中学习。Santoro等[2-3]利用神经图灵机引入的外部存储器来实现短期记忆并在标签和输入图像之间建立连接,使输入能够与存储器中的相关图像进行比较,以实现更好的预测。基于梯度的方法[4-5]通常通过训练额外的网络来预测分类器更新策略,如Larochelle等[5]提出训练LSTM优化器以学习分类器网络的参数优化规则。关系网络[6]和匹配网络[7]采用度量学习的思想,不再使用人工设计的指标,而是完全利用神经网络来学习深度距离度量。Finn等[8]提出了一种称为MAML模型无关的元学习方法。该方法的基本思想是同时启动多个任务,然后获取不同学习任务的合成梯度方向来更新神经网络。这样的优化方式能找到最适合网络的初始化位置,这里的初始化位置被定义为:仅通过几个小样本的训练可以调整到最好表现。Reptile[10]是OpenAI提出的简化版MAML算法,MAML需要在反向传播中计算二阶导数,而Reptile只需要计算一阶导数,消耗更少的计算资源且更易于实现。

虽然上述方法取得了令人瞩目的成果,但普遍存在两个缺陷:一是算法引入人为设计的规则来约束学习;二是需要更多额外的存储空间对经验进行存储,并且没有提供将知识转移到其他任务的理论手段。因此本文提出一种结合表征学习和注意力机制[18-19]的元学习方法VAE-ATTN。表征模块利用过去的知识,将高维图像数据表达为有意义的高级表征;注意力模块引导学习器关注关键特征,以快速适应新的学习任务。

VAE-ATTN算法首先运用变分自编码器VAE[11-12]通过无监督学习方法获取各个任务内部共享的特征。VAE的编码器保留预训练后的网络模型参数,将提取的低维高级表征迁移到不同的识别任务中。同时,在通道维度引入注意力机制,通过计算概率分布选择性加强对当前学习任务更重要的特征。本文使用Reptile元学习算法作为基准算法。实验结果表明,VAE-ATTN算法整体性能优于MAML、MatchingNets、Meta-LSTM等对比算法,验证了有效的表征学习和注意力机制的结合能获得更加精准的小样本分类结果。

1 β-VAE工作原理

本文使用变分自编码器进行表征学习,表征学习的目标是从数据中自动学习到从原始数据到数据表征之间的映射。VAE作为深度神经网络,由编码器和解码器构成。如图1所示,VAE本质是提取数据的隐特征,构建从隐特征到生成目标的模型。编码器从原始数据中提取潜在的合理变量,再对编码结果加上高斯噪声加以约束,使之成为服从高斯分布的隐含特征。解码器构建的模型将隐特征映射到重新生成的概率分布中,重构的分布需尽量与原始分布相同。

图1 VAE的工作原理

网络有两个组件:具有参数φ的编码器网络E和具有参数θ的解码器D,其损失函数为:

L(φ,θ,x)=Eqφ(z|x)[logpθ(x|z)]-DKL(qφ(z|x)‖pθ(z))

(1)

式中:qφ(z|x)表示从数据空间到隐含空间的编码器;pθ(x|z)表示从隐含空间到数据空间的解码器。

损失函数由两方面构成:式(1)第一项为重构误差,驱使重构的pθ(x|z)分布更接近于输入分布pθ(x);第二项旨在减小KL散度,驱使qφ(z|x)更接近于先验分布pθ(z)。为了实现这种重构,VAE将捕捉到可以代表原始输入数据的最重要的特征因素。

特别地,我们尝试用VAE的变体β-VAE[13-15]进行实验。β-VAE引入解缠性先验[21],假设数据是基于互相独立的因素生成的,因此可以用表征中不同的独立变量表示这些因素。该解缠性先验可促进编码器学习数据简洁的抽象表示,从而用于各种下游任务并提升样本效率。

如式(2)所示,β-VAE引入了一个可调节的超参数β,它可控制隐变量的维度以及重建精度之间的平衡,同时高斯先验的各向同性性质也给学习的后验带来了隐形的约束。β变化会改变训练期间学习程度,从而鼓励不同的学习表征,实验中需要调整的值以促进使用解缠后的表征。

L(φ,θ,x)=Eqφ(z|x)[logpθ(x|z)]-βDKL(qφ(z|x)‖pθ(z))

(2)

VAE的无监督学习阶段需要编码器对输入数据降维,并从中提取通用而高级的表征,以适用于小样本学习中不同图像类别的一系列任务分布。从元学习的角度处理这个问题,将目标定义为一个有效的学习过程,可以从无标记数据转移到少标记样本的任务。

2 注意力模块学习方法

Bengio等[21]提出具有适合特定任务和数据域的表征可以显著提高训练模型的学习成功率和稳健性。因此,本文对VAE提取的高级表征构建注意力机制,使元学习器能在全局信息中关注更有利于当前学习任务的目标表征。自注意机制与人类视觉注意力机制起着类似的作用,从大量的信息中筛选出部分关键的信息,并聚焦到这些重要的信息上。

图2阐述了注意力模型的内部结构。该模块通过分析输入数据的总特征,捕获通道间依赖关系,预测通道重要性,以此选择性地强调某些特征。

图2 注意力模型的网络结构以及相应特征的维度

根据预训练过编码器产生的隐特征γ构建注意力模块的输入,γ∈Rb×h×w×c,其中b为批大小(Batch size),h和w为特征图的长和宽,c是通道数。由式(3)-式(6)所示,Q和K由输入特征γ通过1×1卷积的跨通道信息整合而得的新的特征图,并将维度变换为Rx×c,其中x=h×w,接着在Q和K的转置之间执行矩阵乘法,最后使用softmax函数进行归一化,得到维度为c×c注意力概率分布αji。这样设计的意义在于计算γ的每个通道数之间的影响力权重,可以突出关键特征图的作用,减少冗余特征对整体分类性能的影响。

Q=reshape(FCNN(γ;θ1))

(3)

K=reshape(FCNN(γ;θ2))

(4)

V=reshape(γ)

(5)

(6)

最后,将权重系数αij与原始特征进行加权求和,再用尺度系数β加以调整,即可获得辨别性高的特征表达Oj:

(7)

其中:β初始化为0,在学习的过程中逐渐分配到更大的权重。

该注意力模块能自适应地整合局部特征并明确全局依赖,使得元学习器能注意到更有用的特征,在样本匮乏的情况下出色地完成分类工作。

3 VAE-ATTN小样本图像分类架构

针对传统深度学习方法的局限性,VAE-ATTN提供了很好的解决方案。VAE-ATTN提出通过预训练VAE学习任务高级表征,混合使用注意力机制快速运用关键表征的方法,最大化从少量样本中获取的有效信息。

方法分为两个阶段,第一阶段为表征模块的预训练。算法使用深度生成模型VAE构建一个提供数据嵌入或特征表征的模型。预训练集由大规模图像分类数据集ImageNet上随机抽取的150个类组成,这些类别和元数据集中的类别没有重叠。VAE从预训练集中学习各个图像类别共享的特征子集。特别地,实验尝试使用β-VAE作为表征模块,相比于线性嵌入或从常规变分自编码器获得的特征,β-VAE能够提取解缠的特征,具有更加有效的表征能力。

第二阶段为元学习阶段。将预训练完成的VAE编码器,作为特征提取器迁移至新的识别任务中。VAE输出的通道响应彼此关联,每个通道映射可以被视作特定于类别的响应。因此对VAE的输出特征引入注意力机制,利用通道映射之间的相互依赖性,选择性地强调相互依赖的特征映射,并改进特定类别的特征表示。本文使用的基准元学习算法为模型无关的Reptile元学习方法,Reptile掌握任务分布规律,从特征空间和参数空间对元学习器进行联合优化。

图3为基于VAE和注意力机制的元学习图像分类架构。编码器是深度为4的卷积网络,解码器由4层反卷积构成。对编码器提取的特征输入注意力模块,进行特征加强。最后通过由全连接层和Softmax层组成的分类器,得到图像分类成果。这样的结构即保留了抽象的图像特征,又为在面临新任务的学习时保留了调整的余地。算法运行的伪代码如算法1所示。

图3 VAE-ATTN图像分类框架

算法1VAE-ATTN元学习算法

1 预训练VAE模型,重复步骤1)-步骤2)直至图像重构误差小于σ:

1) 从预训练集中采样n张图片P(0)~P(n-1);

2) 在每幅图像上执行随机梯度下降,优化网络编码器参数φ和解码器参数θ。

2 将预训练好的编码器的参数值φ固定,连接Attention模块。

3 Attention模块参数A在元数据集上通过Reptile算法进行训练以学会强调关键的特征图,步骤1)-步骤3)预定义的J次:

1) 从元数据集中采样n个任务τ(0)~τ(n-1);

2) 在每个任务τi上执行连续k步梯度下降,计算权值Wi=SGD(Lτi,k,A);

4 在测试集上验证模型,获得最终准确率。

Reptile[10]作为基准元学习算法,本质上是通过不断地采样不同类别的任务,在任务层面实现知识的泛化。算法的优化目标如下:

(8)

4 实 验

为了验证基于VAE和注意力机制的元学习方法的有效性,实验选取两个重要的基准数据集Mini-ImageNet和Omniglot进行实验,并将测试结果与其他元学习方法进行比较。Omniglot[16]是Lake等提出的语言文字数据集,该数据集包含50种文字,1 623类手写字符,每一类字符仅拥有20个样本,且这些样本均为不同的人绘制而成。Mini-ImageNet[7]数据集由DeepMind于2016年提出,是计算机视觉领域的重要基准数据集,它通过从ImageNet随机抽样100个类并为每个类选择600个样本创建而成。其中:训练集包含64个类别,共计38 400幅图像;测试集包含20个类别,共计12 000幅图像;验证集包含16个类,9 600张图像。

4.1 实验设置

预训练阶段:变分自编码器从原始的,未标记的预训练集数据中进行学习。从ImageNet中随机抽取150类,每类600张图片组成预训练集。预训练集没有与Mini-ImageNet数据集中的类别重叠。在β-VAE训练阶段,本文采用Adam优化器,固定学习率为0.001。编码器模型运用4层CNN卷积层,每层使用64个大小为3×3的卷积核,输出为100维的隐变量。损失函数一方面通过交叉熵来度量图片的重构误差,另一方面,通过KL散度来度量隐变量的分布和单位高斯分布的差异。根据损失函数的收敛特性,本文选取的批大小为32,以获得随机性避免陷入局部最优化。

元学习阶段:网络运用训练集中有标记的,训练集数据样本进行学习。在预训练阶段之后,β-VAE已经从预训练集中学习了低维的高级特征,元学习器只需要通过快速调整其注意力模块来学习如何适应新的学习任务。网络使用Reptile算法对注意力模块进行2万次的训练迭代,每次连续计算8步梯度下降来更新网络参数,详细超参设置见表1。

表1 元学习参数表

4.2 实验分析

实验考虑解决小样本分类中K-样本,N-类别[7]学习问题。对于K-样本,N-类别(K-shot,N-way)分类的每个任务,学习器训练N个相关类,每个类都有K个例子,首先从元数据集中采样N个类,为每个类选择K+1个样本。然后,将这些示例拆分为训练和测试集,其中训练集包含每个类的K个示例,测试集包含剩余样本。以5-样本,5-类别分类为例,实验中共抽取30个样例,使用其中25个样本5(图像)×5(类)训练学习器并使用剩余的示例来测试模型。

4.2.1β-VAE的重构分析

对于无监督学习阶段,实验考察了β参数对提取解缠特征的影响。实验发现β=8是对于最终学习器进行小样本分类的最合适的参数值,实验中大约一半的隐变量已经收敛到单位高斯先验。如图4所示,(a)为测试图片,(b)为β=8时的β-VAE重构图像。从图像重建的质量上分析,由于隐变量的维度受限,良好的解缠表征可能会导致模糊的重建[12]。但解缠表征例如旋转、大小、位置等有助于加速后期元学习阶段的学习,帮助注意力模块理解不同任务之间的共享特征,对提升小样本分类性能有更明显的成效。

(a) 测试图像

(b) β-VAE的输出(β=8)图4 测试图像与重构图像

4.2.2注意力影响可视化分析

该部分实验成果可视化了注意力机制给小样本分类带来的影响。实验使用t-SNE算法[20]将网络输出的特征值降维并投影至2维空间。图5是Mini-ImageNet实验中测试场景的特征可视化图,(a)为特征在进入注意力模块之前的前期特征,(b)为经过注意力机制增强之后的特征。为使图像表述更加清晰,t-SNE实验中共采样3种类别,每种类别200幅图像进行降维,图中的3种标记符号分别代表3个不同的类别。

(a) (b)图5 特征通过t-SNE投影至2维空间的可视化结果

可以看出,在经过注意力模块的特征改进之后,不同图像类别之间的分布差异更加明显,类内距离的标准差缩小,而类间距标准差增大。实验结果表明,注意力机制可以捕获高级特征里的关键特征,有助于元学习器更好地区分不同类别的图像。

4.2.3小样本图像分类结果

将VAE-ATTN元学习方法与现有元学习方法相比较,表2及表3展示了基础设置和直推设置的实验成果。在直推模式中,元学习器允许同时拥有标签训练样本和无标签测试样本,训练后的模型一次性对测试集中的所有样本进行分类,因此允许信息通过批量标准化在测试样本之间共享[9]。也就是说,测试样本的类标签预测过程会受到彼此的影响,不再是相互独立的。表2与表3中,Y表示运用了直推设置,N表示未运用直推设置。观察实验结果发现,使用直推设置的分类结果明显优于未使用该设置的结果。

表2 Mini-ImageNet 小样本分类结果 %

表3 Omniglot小样本分类结果 %

续表3 %

从表2中可以看出,在Mini-ImageNet上,本文提出的算法超过了当前性能优异的元学习算法,如MAML、MatchingNets、Meta-LSTM等。在5-样本,5-类别以及1-样本,5-类别的测试场景中分别获得72.5%和53.5%的准确率,显著超越原始Reptile算法的分类性能。由表3可知,在Omniglot数据集上,β-VAE在5-样本,20-类别以及1-样本,20-类别的测试场景中,取得了98.8%和96.5%的高分类准确率。实验结果说明基于表征学习和注意力机制的方法改善了元学习器,证明了VAE-ATTN算法的合理性。

图6是Mini-ImageNet中5-样本,5-类别的直推实验的分类准确率曲线图。可以看出,VAE-ATTN算法均超出Reptile基准元学习算法,且运用β-VAE进行预训练的分类效果也优于常规VAE的训练效果。这一结果说明β-VAE提取的解缠表征加速元学习器结构化地理解多样的任务,实现更高的小样本分类准确率。

图6 Mini-ImageNet实验分类准确率对比

5 结 语

小样本图像识别在人工智能领域是复杂且具有挑战性的研究方向,极具探索价值和意义。本文通过分析以往元学习方法存在的问题,提出结合表征学习和注意力机制的新元学习方法VAE-ATTN。算法运用β-VAE学习的高级的解缠表征,并通过注意力机制增强重要的信息并抑制冗余的信息,从而引导元学习器进行小样本学习。本文算法在Mini-ImageNet和Omniglot数据集上的小样本学习测试中均展现了良好的性能,表明了算法的有效性和可行性。

在后续工作中,我们将考虑更具泛化性的元学习方法,目标是提取可跨任务或远距离迁移的特征,使得小样本学习能根据更充分的先验知识进行新任务的快速学习。

免责声明

我们致力于保护作者版权,注重分享,被刊用文章因无法核实真实出处,未能及时与作者取得联系,或有版权异议的,请联系管理员,我们会立即处理! 部分文章是来自各大过期杂志,内容仅供学习参考,不准确地方联系删除处理!