当前位置:首页 期刊杂志

基于GAN网络的菌菇表型数据生成研究

时间:2024-05-24

袁培森 吴茂盛 翟肇裕 杨承林 徐焕良

(1.南京农业大学信息科学技术学院, 南京 210095; 2.马德里理工大学技术工程和电信系统高级学院, 马德里 28040)

0 引言

表型(Phenotype)研究核心是获取高质量的性状数据,进而对基因型和环境互作效应(Genotype-by-Environment) 进行分析[1-2],表型组学近年来发展迅猛,已成为分子育种和农业应用中的重要技术支撑[3-4]。然而,植物表型数据的获取需搭建实验环境,并需昂贵的数据采集工具,具有周期长、代价高昂等特点[1,5-6]。当前,以大数据为基础的深度学习正在成为表型数据分析的有力工具[7-8],深度学习相关算法的有效性在很大程度上取决于标记样本的数量,因此限制了其在小样本量环境中的应用[9]。数据的非均衡性是生物表型数据具有挑战性的问题[10-13]。

为了提升非均衡数据分析的性能和质量,文献[14-15]提出了数据生成的方法。然而,过采样技术SMOTE[15]、ADASYN[16]等对于处理经典学习系统中的类不平衡有效,但是此类方法生成的数据不能直接应用于深度学习系统[17]。近年来,生成式对抗网络(Generative adversarial networks,GAN)[18]的出现为计算机视觉应用提供了新的技术和手段,GAN采用零和博弈与对抗训练的思想生成高质量的样本,具有比传统机器学习算法更强大的特征学习和特征表达能力[19],是一种基于深度学习的学习模型,可以用于海量数据的智能生成,已经广泛用于图像、文本、语音、语言等领域[20-21]。

有学者提出将GAN网络技术用于生物学等领域的数据生成问题[9,22-25],结果显示生成数据的质量有显著提高。目前,记录约8万种真菌、近1 500种野生蘑菇种类的图像数据集,这对种类繁多和分布非均衡的菌类识别和分类具有重要的生态意义[26-28]。

本文提出基于生成对抗网络的菌菇表型数据生成方法(Mushroom phenotypic based on generative adversarial network, MPGAN)。以菌菇表型为研究对象,在特定目标域上训练GAN网络,作为GAN发生器网络的输入给出潜在模型,以期生成可控制和高质量的蘑菇图像。

1 GAN网络原理及系统框架

1.1 GAN网络基本原理

GAN[18]的核心思想来源于博弈论的纳什均衡,它设定双方分别为生成器和判别器,生成器的目的是尽量学习真实的数据分布,而判别器的目的是尽量正确判别输入数据是来自真实数据还是来自生成器。GAN中的生成器和判别器需要不断优化,各自提高生成能力和判别能力,其学习优化过程就是寻找二者之间的一个纳什均衡[29]。

1.2 GAN系统框架

GAN系统一般框架如图1所示,系统结构主要包括:生成器(用于生成虚拟图像),它通过接收随机噪声z,通过这个噪声生成网络G(z)。判别器是负责判断图像真假,输入图像x,输出对该图像的判别结果D(x)。

图1 一般的GAN框架Fig.1 Framework of GAN

首先,在给定生成器G的情况下,最优化判别器D。采用基于Sigmoid的二分类模型的训练方式,判别器D的训练是最小化交叉熵的过程,其损失函数表示为

(1)

式中x——采样于真实数据分布Pdata(x)

z——采样于先验分布Pz(z),例如高斯噪声分布

E(·)——计算期望值

式(1)中判别器的训练数据集来源于真实数据集分布Pdata(x)(标注为1) 和生成器数据分布Pg(x)(标注为0)。

给定生成器G,最小化式(1)得到最优解。对于任意的非零实数m和n,且实数值y∈[0,1],表达式为

Φ=-mlgy-nlg(1-y)

(2)

(3)

D(x)代表x来源于真实数据而非生成数据的概率。当输入数据采样自真实数据x时,D的目标是使得输出概率D(x)趋近于1,而当输入来自生成数据G(z)时,D的目标是正确判断数据来源,使得D(G(z))趋近于0,同时G的目标是使得其趋近于1。生成器G损失函数可表示为

OG(θG)=-OD(θD,θG)

(4)

其优化问题是一个极值问题,GAN的目标函数可以描述为

min(G)max(D){f(D,G)=Ex~Pdata(x)lgD(x)+Ez~Pz(z)lg(1-D(G(z)))}

(5)

GAN模型需要训练模型D最大化判别数据来源于真实数据或者伪数据分布G(z)的准确率,同时,需要训练模型G最小化lg(1-D(G(z)))。

GAN学习优化的方法为:先固定生成器G,优化判别器D,使得D的判别准确率最大化;然后固定判别器D,优化生成器G,使得D的判别准确率最小化。当且仅当Pdata=Pg时达到全局最优解。

2 MPGAN系统实现

2.1 MPGAN系统框架

MPGAN系统的框架如图2所示,蘑菇图像的生成过程为:生成器G(z)使用截断到一定范围内的随机正态分布数据作为输入,输入到卷积网络(Convolutional neural network, CNN),最后输出生成图像数据。判别器D(x)根据真实图像数据和生成图像数据输出判别结果,并对神经网络的所有参数进行反向更新操作。

图2 蘑菇表型数据生成的MPGAN框架Fig.2 MPGAN framework for mushroom phenotypic data generation

图3 生成器神经网络框架Fig.3 Neural network framework of generator

2.1.1生成器

生成器卷积神经网络结构的作用是通过输入随机数据生成128×128×3的图像,128表示像素数,3表示RGB的通道数。图3是生成器的框架。

生成器采用8层的卷积神经网络,首先是Input数据输入层,第2层是全连接层(Fully connected, FC),然后是连续5个反卷积层(Deconvolution, DeConv),其中分为DC反卷积层、BN批归一化层(Batch normalization,BN)和激活函数,批归一化层是对于同一批次数据按照给定的系数进行规范化处理,以防止梯度弥散,最后是Output数据输出层。生成器的反卷积层如图4所示,各层具体描述如下:

(1)FC全连接层设计输入为生成100个图像的随机数据,经过全连接层的8 192个神经元处理以及形状重塑后变为4×4×512大小的数据,再经过批归一化层及ReLU激活函数后将结果输出到下一层。

(2)生成器中包括5个反卷积层,卷积核的移动步长为2,卷积核尺寸为5×5,1~4层的每一层经过批归一化层及ReLU激活函数后将结果输出到下一层,其中:

第1层输入数据为4×4×512。反卷积层的卷积核数为256个,经过反卷积后得到的数据为8×8×256。

第2层输入数据为8×8×256。反卷积层的卷积核数为128个,经过反卷积后得到的数据为16×16×128。

第3层输入数据为16×16×128。反卷积层的卷积核数为64个,经过反卷积后得到的数据为32×32×64。

第4层输入数据为32×32×64。反卷积层的卷积核数为32个,经过反卷积后得到的数据为64×64×32。

图4 生成器的反卷积层Fig.4 Deconvolution layer of generator

第5层输入数据为64×64×32。反卷积层的卷积核数为3个。输入数据经过反卷积后得到的数据为128×128×3,再经过批归一化层及tanh激活函数后将结果输出到下一层。tanh函数表达式为

(6)

式中a——参数

不使用传统的Sigmod函数进行Output输出层,而是直接将上一层输入结果输出。生成器网络参数如表1所示。

表1 生成器网络参数Tab.1 Summary of generator network parameters

图5 判别器神经网络框架Fig.5 Neural network framework of discriminator

2.1.2判别器

判别器的作用是尽量拟合样本之间的Wasserstein距离,从而将分类任务转换成回归任务。判别器采用7层的卷积神经网络,首先是Input数据入层,接着是连续4个卷积层(Convolution,Conv),其中分为卷积层、归一化层和激活函数,然后是全连接层FC,最后是数据输出层Output。判别器的架构如图5所示。

判别器的Conv卷积层设计如图6所示。判别器共有4个卷积层,卷积核的移动步长为2,卷积核尺寸为5×5,经过归一化层及Leaky ReLU激活函数后将结果输出到下一层。

第1层输入数据为128×128×3。卷积层的卷积核数为64个,经过卷积后得到的数据为64×64×64。

第2层输入数据为64×64×64。卷积层的卷积核数为128个,经过卷积后得到的数据为32×32×128。

图6 判别器的卷积层操作Fig.6 Convolution layer of discriminator

第3层输入数据为32×32×128。卷积层的卷积核数为256个,经过卷积后得到的数据为16×16×256。

第4层输入数据为16×16×256。卷积层的卷积核数为512个,经过卷积后得到的数据为8×8×512。

FC全连接层设计的输入数据为8×8×512,经过全连接层处理以及形状重塑后变为大小为1的蘑菇图像,并将结果输出。判别器的网络参数如表2所示。

表2 判别器网络参数Tab.2 Summary of discriminator network parameters

2.2 网络优化设计

2.2.1Wasserstein距离

MPGAN系统采用带有梯度惩罚的Wasserstein距离[30],Wasserstein距离[9,31-32]又叫推土机(Earth-mover,EM)距离,定义为

(7)

式中Pr——真实数据分布

Pg——生成数据分布

r——真实样本

y——生成样本

γ——联合分布

∏(Pr,Pg)——Pr和Pg组合起来的所有可能的联合分布的集合

对于每个可能的联合分布γ而言,采样(x,y)~γ得到一个真实样本x和一个生成样本y,并计算这对样本之间的距离‖x-y‖,计算该联合分布γ下样本对距离的期望值E(x,y)~γ(‖x-y‖)。Wasserstein距离定义为在所有可能的联合分布中能够对这个期望值的下界[31]。

2.2.2系统损失函数

设定fw代表判别器网络,根据Lipschitz连续性条件的要求,该判别器网络含参数w,并且参数w不超过某个范围,根据式(7)定义的Wasserstein距离,MPGAN系统判别器的目的是近似拟合Wasserstein距离,因此判别器的损失函数可以表示为

LD=Ex~Pg(fw(x))-Ex~Pr(fw(x))

(8)

MPGAN系统生成器的目的是近似地最小化Wasserstein距离,即最小化式(8),因此生成器的损失函数可以表示为

LG=Ex~Pr(fw(x))-Ex~Pg(fw(x))

(9)

GULRAJANI等[30]提出的带有梯度惩罚的Wasserstein距离来满足Lipschitz连续性。当生成数据分布Pg接近真实数据分布Pr时,Lipschitz连续性可表示为

‖D(Pg)-D(Pr)‖≤K‖Pg-Pr‖

(10)

式(10)可转换为

(11)

式中Pc——生成数据分布与真实数据分布的差值

K——整数常量

先对真假样本的数据分布进行随机差值采样,即产生一对真假样本Xr和Xg,采样公式为

X=ξXr+(1-ξ)Xg

(12)

式中ξ——[0,1]区间的随机数

(13)

式中λ——调节梯度惩罚项大小的参数

K为使得Lipschitz连续性条件成立的常量,设定K为1,MPGAN系统的判别器损失函数式(9)和梯度惩罚项式(13),损失函数可表示为

(14)

2.3 MPGAN系统的训练过程

根据GAN网络的框架和优化过程,MPGAN系统的训练过程如图7所示。

图7 MPGAN系统的训练过程Fig.7 Training procedure of MPGAN system

图7中的训练过程描述如下:

(1)采用方差为0.02的截断正态分布初始化网络中的权值参数W和卷积核初始化网络的偏置值b,初始化学习率η,即每次参数更新幅度。在训练过程中,参数更新向着损失函数梯度下降的方向,表示为

Wn+1=Wn-ηΔ

(15)

式中Δ——梯度,即损失函数的导数

(2)采用区间为[-1,1]的均匀分布初始化随机噪声。

(3)采用数据集中随机获取批次大小的训练样本,并在输入队列中进行数据预处理。

(4)将步骤(2)中生成的随机噪声输入到生成器网络,生成虚拟图像数据,将生成的虚拟图像数据输入判别器,得到生成图像判别结果;将步骤(3)中获取的训练样本使用批归一化操作输入判别器,得到真实图像判别结果;计算判别器损失并反向更新判别器参数。

(5)计算梯度惩罚项,为判别器损失施加惩罚,然后使用优化器反向更新判别器参数,使用梯度惩罚项,替换原来的权重截断策略。

(6)判断是否达到指定判别器优化次数,即每优化一次生成器时优化N次判别器,若是则进入步骤(7),若否则重新进入步骤(3)。其中N由用户设定。

(7)将步骤(2)中生成的随机噪声输入到生成器网络,计算生成器损失并使用优化器反向更新判别器参数。

(8)判断是否达到指定迭代次数,即是否遍历完全部样本,若是则进入步骤(9),否则重新进入步骤(2)。

(9)判断是否达到EPOCH次数,EPOCH为总共训练的轮次,若是则结束,否则重新进入步骤(2)。

3 实验结果与分析

实验平台为Windows 10系统,16 GB内存,256 GB SSD,1 TB HD,Intel QuadCore i7-8700, 4.2 GHz, Nvidia GTX 1070,8 GB。算法采用Tensorflow V1.1 GPU框架[33]和Python 3.6实现。

3.1 数据集

采用两类数据集:开源蘑菇数据集Fungi[28],选择了其中375幅图像;私有数据集,共138幅图像。图像预处理方法包括随机翻转、随机亮度变换、随机对比度变换和图像归一化,前面几种预处理方法主要是为了增加样本数量,而图像归一化是为了降低几何变换带来的影响。

图8为开源数据集Fungi蘑菇示例图像,该数据集环境噪声大且背景复杂,背景中有草地、林地、树叶、木块等多种干扰物。

图8 开源数据集示例Fig.8 Examples of public dataset

私有蘑菇数据集采用凤尾菇作为对象,该数据集采用黑色作为背景,背景噪声小,且蘑菇形状不同,适合菌菇表型图像生成。图9为私有蘑菇数据集的示例图像。

图9 私有蘑菇数据集示例Fig.9 Examples of private dataset

3.2 参数设置

MPGAN系统默认使用Adam优化器[34],优化器超参数β1=0.5、β2=0.9、ε=1×10-8,学习率η默认为0.000 3,判别器优化次数N=5。

3.2.1生成器参数设置

由于生成器的输出层直接将前一层的值作为输入,最后激活函数选择tanh激活函数,该激活函数可以将输出层的输出约束到区间[-1,1]。

为了保证数据分布的一致性,并防止反向传播权值更新时发生梯度弥散并加速收敛,采用批归一化(Local response normalization),对同一批次数据按照给定的系数进行规范化处理。其处理步骤如下:

(1)沿通道计算同一批次内所有图像的均值μB,计算式为

(16)

(17)

(3)对图像做归一化处理,计算式为

(18)

ω——防止方差为0的参数

(4)加入缩放变量γ和平移变量φ,得出结果

yi=γi+φ≡BNγ,φ(xi)

(19)

式中yi——加入缩放变量γ和平移变量φ处理结果

3.2.2判别器参数设置

选择Leaky ReLU激活函数作为判别器激活函数,确保梯度更新整个图像。Leaky ReLU激活函数表达式为

(20)

式中α——(1,+∞)区间内的参数

MPGAN系统生成式对抗网络模型的梯度惩罚策略采用层归一化函数(Layer normalization,LN)。

3.3 Wasserstein距离与EPOCH

在学习率η为0.000 3时,使用开源数据集和私有数据集作为训练数据集,MPGAN系统的Wasserstein距离与EPOCH的关系如图10所示。

图10 Wasserstein距离收敛曲线Fig.10 Wasserstein distance convergence curves

由图10a可知,在开源数据集,EPOCH大于2 000后逐渐开始学习到真实图像的数据分布,在EPOCH达到10 000后逐渐趋于稳定,在这个阶段数据集本身噪声较大导致模型的学习能力有所下降,所以模型学习的特征被背景所干扰,并且在曲线尾部的振荡程度明显增大,此时减小学习率η可以使模型训练更加稳定。

由图10b可知,Wasserstein距离在EPOCH达到2 000后不断收敛,在10 000左右有小幅振荡,EPOCH在超过35 000之后,振荡幅度减小,模型比较稳定。

由图10可知,不同数据集训练的EPOCH次数不同,开源数据集的噪声较大,模型不容易收敛,并且相似度衡量指标Wasserstein距离在EPOCH为12 000时开始稳定在一个较高的程度;私有数据集上的噪声较小,当在该数据集,模型收敛更加快速,Wasserstein距离在EPOCH大于35 000时开始逐渐收敛稳定。

3.4 学习率与EPOCH

基于开源数据集的学习率与EPOCH关系如图11所示。从图11可看出,提高学习率η时,模型的收敛速度有明显的提升并在EPOCH为1 000后逐渐稳定,但是随着学习率的提高,收敛的振荡程度也在加大,因此可以在训练初期使用较大的学习率提高初始收敛速度,然后逐渐减小学习率保证训练过程稳定。由于在私有数据集上的结果类似,因此仅报告了开源数据集上的测试结果。

图11 基于开源数据集的学习率与EPOCH关系Fig.11 Learning rate and EPOCH relationship based on open source dataset

3.5 蘑菇图像生成

首先,系统测试了数据中的scalpturatum口蘑,EPOCH为1 000时,学习率η为0.000 1~0.000 5生成图像如图12所示。图12a为原始图像,从图12b可看出,学习率η为0.000 3时,生成的菌菇图像相对较好。

图12 不同学习率的菌菇图像生成结果对比Fig.12 Mushroom image generation results comparison at different learning rates

当学习率η为0.000 3时,在开源数据集和私有数据集上,测试了系统菌菇图像生成结果,生成图像尺寸设置为64像素×64像素,结果分别如图13和图14所示。图13为EPOCH为15 000时,开源数据集上的生成结果。图13b的生成图像能够清晰地显示出原始菌菇的表型特征。

图14为EPOCH为50 000时,私有数据集上的生成结果。图14b的生成图像能够清晰地显示出原始菌菇的表型特征。

图13 基于开源数据集上的蘑菇生成图像Fig.13 Illustration of generating Fungi images based on public dataset

图14 基于私有数据集上的蘑菇生成图像Fig.14 Illustration of generating Fungi images based on private dataset

对比图13b和图14b可以看出,图14b质量优于图13b,表明高质量的菌菇训练数据对图菌菇表型图像的生成有重要影响。

4 结论

(1)研究了菌菇表型数据生成技术,设计了用于菌菇表型数据生成的生成式对抗网络结构。使用Wasserstein距离和带有梯度惩罚的损失函数。

(2)利用开源数据和私有数据集进行了测试,结果表明,数据集噪声越小越好,噪声越小则损失越容易收敛,否则背景和主体目标发生混淆时,损失会在一个较大程度上振荡。

(3)测试了学习率η、EPOCH与Wasserstein距离关系,系统生成的菌菇表型数据可为后期菌菇数据分类与识别提供大数据基础,为解决菌菇分类的数据非均衡、长尾分布等问题提供研究基础。

免责声明

我们致力于保护作者版权,注重分享,被刊用文章因无法核实真实出处,未能及时与作者取得联系,或有版权异议的,请联系管理员,我们会立即处理! 部分文章是来自各大过期杂志,内容仅供学习参考,不准确地方联系删除处理!