其他
黄药子苦豆子川楝子,给我中药识别AI一点面子
本文作者:韩爱庆 北京中医药大学管理学院 副教授
中药是中医临床治疗的主要载体、也是中医药文化最核心的载体,蕴含着大量的科技资源、文化资源、产业资源。由于中药材的真伪、优劣与临床用药的安全有效有直接关联,因此对中药材进行客观、科学的鉴定识别,有利于规范中药材市场,加强质量监控,促进中药材产业的发展。 很多学者对应用客观指标来进行中药材的鉴定识别进行了广泛的研究,相关研究所采用的方法包括近红外光谱、指纹图谱、电子鼻以及化学模式识别,或者采用多方法相结合来开展研究。
这些方法的共同特征是需要基于专家经验或先验知识来人工设计或提取特征,前期的特征提取工作量特别大,得到的模型在训练集上准确率虽然高,但在测试集上普遍存在泛化能力差的情况。
另外使用这些方法来进行中药识别需要专用设备,便利性较差,应用门槛高,不利于推广应用。因此,开发一款能够自动提取特征,并能够使用通用终端设备进行识别,准确率高,泛化性能好的中药识别系统是行业的迫切需求。 近年来,以深度学习为代表的人工智能技术飞速发展。由于本轮人工智能可落地性非常强,可快速为行业应用赋能,所以在工业、商业、金融等各领域亦备受追捧,目前正快速应被推广应用到各个领域。
在此背景下,世界知名公司纷纷推出深度学习框架,比如Google推出TensorFlow,Facebook推出PyTorch。国内有飞桨,这是一个集深度学习核心框架、基础模型库、端到端开发套件、工具组件和服务平台于一体的开源深度学习平台。
基于飞桨模型,并借助百度AI Studio开发平台以及平台提供的Tesla V100 GPU算力,我们开发了基于深度学习的中药材识别模型,并完成了微信小程序开发和部署。在下文中,我将为大家解析此过程。
方案解析
数据采集与预处理
应用留出法,将80%的样本设置为训练集,20%的样本设置为测试集。为了增加训练集的数据量,提高模型的泛化能力,对训练集进行数据增强处理。
应用数据增强技术,对已有图片做缩放、随机旋转、随机裁剪、对比度调整、色调调整以及饱和度调整,使得总训练样本量达到213140张,数据增强后,大幅提升了训练样本数量。
配置网络
配置网络包括三个部分:网络模型、损失函数及优化函数。
本研究采用的网络模型为ResNeXt50卷积神经网络。目前主流深度学习分类模型包括LeNet、AlexNet、VGG、GoogleNet、Inception、ResNet、ResNext等。
其中,ResNet是2015年ILSVRC的冠军,而ResNeXt模型是ResNet模型的升级版,是2016年ILSVRC的亚军。
ResNeXt同时采用 VGG 堆叠的思想和 Inception 的 split-transform-merge 思想,以一种简单可扩展的方式延续split-transform-merge策略。
整个网络的buildingblock都是一样的,不用在每个stage里对每个buildingblock的超参数进行调整,只用一个结构相同的buildingblock,重复堆叠即可形成整个网络。
模型的可扩展性比较强,可以认为是在增加准确率的同时基本不改变或降低模型的复杂度。以下为ResNet(左图)与ResNeXt(右图)基本block对比。
def bottleneck_block(self,input,num_filters,stride,cardinality,reduction_ratio,name=None):
conv0 = self.conv_bn_layer(
input=input,
num_filters=num_filters,
filter_size=1,
act='relu',
name='conv' + name + '_x1')
conv1 = self.conv_bn_layer(
input=conv0,
num_filters=num_filters,
filter_size=3,
stride=stride,
groups=cardinality,
act='relu',
name='conv' + name + '_x2')
conv2 = self.conv_bn_layer(
input=conv1,
num_filters=num_filters * 2,
filter_size=1,
act=None,
name='conv' + name + '_x3')
scale = self.squeeze_excitation(
input=conv2,
num_channels=num_filters * 2,
reduction_ratio=reduction_ratio,
name='fc' + name)
short = self.shortcut(input, num_filters * 2, stride, name=name)
return fluid.layers.elementwise_add(x=short, y=scale, act='relu')
ResNeXt-50的整体网络结构如下,主要是将ResNet单元换成了ResNeXt单元。训练网络
针对个人和机构AI研究者普遍缺乏算力的现状,AI Studio平台免费提供基础版(CPU:2 Cores RAM:8GB,Disk:100GB)和高级版(GPU:Tesla V100,Video Mem:16GB;CPU:8Cores,RAM:32GB, Disk:100GB)两种运行环境。
由于本项目数据量较大,模型训练过程选用GPU高级版运行环境。 训练分为三步:
第一步配置好GPU训练环境;
第二步用训练集进行训练;
第三步保存好训练的模型。
第一步,定义GPU计算场所,创建一个executor,对program进行参数初始化。
第二步,设置好训练的轮数,用训练集进行训练。遍历batch_reader迭代器,喂入一个批次的数据。
为方便后续分析和过程可视化,为每个pass的每个批次数据加上索引step_id,每喂入500个batch,保存一次Pass_Num,trainbatch_Num,Train_loss,Train_acc1和time,并使用print语句输出训练的中间结果,随着训练的进行,损失率逐渐下降,准确率逐渐提高,模型逐渐优化。
第三步,模型保存。由于数据量较大,需要训练几十个小时。为防止训练过程意外中断,在训练过程中,每喂入500个批次的数据保存一次中间模型,一旦出现意外中断,下次训练直接导入中间模型继续训练,不需重新开始。
最终训练完成时保存最终训练模型,为预测模型做准备。
模型预测并设计微信小程序
预测程序为独立代码模块,可独立运行。预测主要分为四步: 第一步:配置预测环境;第二步:预处理预测图片。将非RGB图片进行模式转换,转为RGB模式;对预测图片进行裁剪和缩放,调整大小为[3, 224, 224];第三步:加载预测模型并将预测图像放入模型进行预测;第四步:输出预测结果,确定结果所属类别。
本研究共预测图片6766张,预测准确率predict accuracy=94%,部分图片预测结果如下图所示。
目前小程序已收录了257味常见中药材的药性和作用,如下图(下)所示。