Gboard 基于 RNN 的手写识别
文 / Sandro Feuz 和 Pedro Gonnet,手写团队高级软件工程师
2015 年,我们推出 Google 手写输入。该产品可作为任何 Android 应用的附加输入法,帮助用户在 Android 移动设备上手写文字。在首次发布的版本中,我们实现了对 82 种语言的支持,包括法语、盖尔语、汉语和马拉雅拉姆语等。为了提供更加流畅的用户体验和省去切换输入法的需求,去年我们为 Google 的移动设备键盘 Gboard for Android 添加了 针对 100 多种语言的手写识别支持。
此后,机器学习不断进步,催生出新的模型架构和训练方法,让我们得以修改最初的方法(即依靠手动设计的启发式方法将手写输入切割成单个字符),并构建单个机器学习模型。该模型不但能够应用于整体输入,而且与旧版相比,错误率显著下降。今年年初,我们为 Gboard 中所有基于拉丁字母的语言推出新模型,并发表论文《基于 LSTM 的多语言快速在线手写识别》("Fast Multi-language LSTM-based Online Handwriting Recognition"),更详细地介绍了为该版本进行的相关研究。在本文中,我们将简要概述这项研究。
接触点、贝塞尔曲线和递归神经网络
任何在线手写识别器都是从接触点入手。手写输入表现为一系列的笔画,而每个笔画又表现为一系列带时间戳的点。由于 Gboard 适用于各种设备和屏幕分辨率,因此我们首先要对接触点坐标进行归一化处理。然后,为了准确捕捉数据形状,我们将点序列转换为一系列三次方贝塞尔曲线,以作为递归神经网络 (RNN) 的输入值。RNN 经过训练,能够准确地识别手写字符(如需了解该步骤的更多详情,请参见下文)。虽然长久以来,我们一直在手写识别中使用贝塞尔曲线,但将其用作输入值还是很新鲜的做法,这使我们能够为不同的设备提供一致的输入表征(采样率和准确率并不相同)。该方法与之前使用所谓分段解码方法的模型截然不同,之前的方法是对如何将笔画分解成字符做出多个假设(分段),然后从此分解中找出最有可能的字符序列(解码)。
该方法的另一个优势是贝塞尔曲线序列比输入点的基础序列更加紧凑,这会使模型在输入的同时更容易获得临时依赖项。每条曲线都由起点和终点,以及其他两个控制点定义的多项式表示,以确定曲线形状。为了找出能够准确表示输入值的三次方贝塞尔曲线序列,我们采用迭代过程最小化归一化输入坐标与曲线之间的平方距离(以 x、y 和时间计算)。下图是曲线拟合过程的示例。用户的手写输入以黑色表示。该输入值由 186 个接触点组成,清晰指明 go 这个词。在黄色、蓝色、粉色和绿色部分,我们通过四条三次方贝塞尔曲线组成的序列看到字母 g 的表征(每条曲线有两个控制点),而相应地,橙色、蓝绿色和白色表示插入字母 o 的三条曲线。
字符解码
曲线序列表示输入值,但我们仍需要将输入曲线序列转化为实际的书面字符。为此,我们使用多层 RNN 来处理曲线序列,并生成输出解码矩阵。该矩阵提供每条输入曲线中所有可能字母的概率分布,以指示哪些书写字母是该曲线的一部分。
我们试验了多个类型的 RNN,最终决定使用双向版准递归神经网络 (QRNN)。QRNN 交替使用卷积层和递归层,为实现高效并行化提供理论上的可能性,并在保持权重数相对较少的同时提供良好的预测性能。权重数与需要下载的模型大小直接相关,所以权重数越少越好。
为了 “解码” 曲线,递归神经网络会生成一个矩阵,其中每一列对应一条输入曲线,每一行对应字母表中的一个字母。我们可以将表示特定曲线的列视作字母表中所有字母的概率分布。但每个字母可以由多条曲线构成(例如上面的 g 和 o,分别由四条和三条曲线构成)。递归神经网络输出序列的长度(与贝塞尔曲线的数量始终一致)与输入应该表示的实际字符数不一致。该问题的解决方法是添加特殊的空白符号,以表明特定曲线没有输出值,这与在联结主义时间分类 (CTC) 算法中的做法一样。我们使用有限状态机解码器将神经网络的输出值与编码为加权有限状态接收器的字符语言模型相结合。语言中常见的字符序列(例如德语中的 “sch”)获得奖励,并且更有可能成为输出值,而不常见的序列则会受到惩罚。此流程的可视化演示如下。
接触点序列(如上图所示,根据曲线段进行颜色编码)被转换为更短的贝塞尔系数序列(示例中为 7 个系数),每个系数对应一条曲线。基于 QRNN 的识别器将曲线序列转换为一系列长度相同的字符概率,并呈现在解码器阵列中,其中行对应 “a” 到 “z” 的字母,而空白符号(条目亮的地方)对应其相对概率。从左向右浏览解码器矩阵,我们发现大多数都是空白,另外还有字符 “g” 和 “o” 的亮点,因此文本输出为 “go”。
虽然新的字符识别模型比旧模型简单得多,但其错误率比旧模型低 20% 到 40%,而且速度也更快。但是,我们仍然需要在设备上执行这些操作!
在设备上运作
为提供最佳用户体验,识别模型只有高准确率还不够,还要有很快的速度。为尽可能降低 Gboard 中的延迟,我们将识别模型(在 TensorFlow 中训练)转换为 TensorFlow Lite 模型。这就需要在模型训练过程中量化所有权重,将每个权重使用四个字节减少为一个,进而缩小模型大小,并减少推理时间。此外,相比使用完整的 TensorFlow 实现,我们可以使用 TensorFlow Lite 缩减 APK 的大小,因为它专为小二进制文件进行优化,仅包括推理所需的部分。
后续计划
我们将继续挑战极限,不断改进拉丁字母的语言识别器。手写团队已经在努力开发新模型,以便为 Gboard 支持的所有手写语言提供助力。
致谢
感谢为提升 Gboard 手写体验做出贡献的每个人。我们要特别感谢来自 Gboard 团队的 Jatin Matani、语音和语言算法团队的 David Rybach、Expander 团队的 Prabhu Kaliamoorthi、TensorFlow Lite 团队的 Pete Warden,以及手写团队的 Henry Rowley、Li-Lun Wang、Mircea Trăichioiu、Philippe Gervais 和 Thomas Deselaers。
更多 AI 相关阅读: