- 余辉
-
本来想把关于CTC的所有东西都写在一篇文章,但后面发现内容太多,遂拆分成如下几个部分:
CTC算法详解之训练篇
CTC算法详解之测试篇
CTC算法详解之总结展望篇(待更)
在日常生活中,许多数据是序列化的,比如语音、文字和图像文本等。在处理序列任务时,一个经典的思路是“分而治之”,把输入序列拆分成最小语义单元,然后将序列任务转换成对单元的分类任务。然而在实际应用中,把序列中的单元精准地分开是很难的,人工标注的代价也很大,可不可以直接对序列数据进行“端到端”地预测呢?CTC(Connectionist Temporal Classification)则是解决了这样一个问题。CTC算法可以让以端到端地方式对序列数据进行学习,在语音识别、图像文字识别等领域取得了很好的应用效果。
本文先对CTC用于序列任务的流程做了大致介绍,并定义了相关的符号。然后再训练一节中介绍了如何通过前后向算法计算CTC的损失函数。yudonglee的博客[2]给了我很多帮助,训练一节中的很多图也是采用他的。我在写的过程中也是在不断学习,有错误和不到位的地方希望大家指出。
CTC用于序列任务,流程大致如下:神经网络把输入序列转换成序列在字典上的概率分布,从这个分布中我们可以得到若干条路径,每个路径都可以转换成输出序列,我们的任务就是找到输出概率最大的序列。具体的符号定义如下:
输入序列 长度为 ,用 表示神经网络,用于提取序列特征,网络的输出为 ,长度也为 ,用 表示输出 单元的激活概率 ,即序列在 时刻被分类成 的概率, 定义在类别集合 上, 为任务字典符号集,比如在文字识别任务中可以定义成中英文字符, 为CTC的 保留符号,用于分隔标签中的不同符号单元。CTC是一种 过分割 的序列解码算法,比如标签中的一个字符a可能在译码路径中被切分成多个连续的a,而标签中也可能存在连续但应该被区分的字符,比如apple中出现了两个p,那这时候译码路径要在这两个不同的p之间插入至少一个 。由网络的输出 可以计算任意译码路径 的概率 。 与输入 等长,我们最终是要得到标签 , 的长度小于 的长度,所以还要定义一个映射 来将译码路径 转换为标签 。映射规则为: 移除所有空白符号 并 合并所有的重复连续符号 ,比如 。可以看到,这个映射是多对一映射(many-to-one),也就是说正确的标签 可以来自许多不同的路径(不管是黑猫还是白猫,只要能捉到耗子就是好猫),后面我们会重点研究many-to-one所带来的计算速度和可微问题,尤其是在训练阶段。整个序列识别的流程可以参见下图:
为了使整个网络可用梯度下降优化,训练过程中必须算出可导的CTC损失函数,CTC也采用了常规的分类任务的最大似然误差(maximum likelihood error): 。因为B是many-to-one映射的缘故, ,计算Loss要穷举所有可行路径,然而穷举所有的路径是非常困难的,因为其空间复杂度为 (N为字典大小,T为路径长度),所以[1]借鉴了HMM中的前后向算法(Forward-Backward Algorithm,FBA),这是一种动态规划算法,下面我们来说一下算法思路。
在 种路径中,只有很少的一部分路径是有效的,我们只需要考虑这一小部分路径就行了。当我们把所有可行路路径列出来,会发现,如果按时间展开译码过程,我们可以以递推的方式计算出某个节点的前向(时间增大的方向)或后向(时间减小的方向)路径概率总和。这也是算法名称的由来。我们会先用一个”apple”的例子来直观解释FBA算法的递推关系,最后给出计算式。
给定一个标签 ,长度为 ,为了找出所有满足 的路径 ,我们要构建一个拓展标签 ,它是在原始标签的首尾和每个字符中间加上空格符号得到的,长度为 ,比如当 时 。我们接下来的搜索过程都在由 展开的搜索栅格上进行。
然而并不是在图3上的任意一条路径都是合法的,合法的路径要满足如下几点条件:
(1)转换只能向右或右下(纵轴上单调非减)
(2)相同的字符间至少有一个空格,否则标签中的连续相同字符会被错误地合并;
(3)除{blank}符号外不能跳过;
(4)路径起点必须从前两个符号开始,即 或 ;
(5)路径必须在最后两个符号结束。
最终所有可能的路径如下:
读者可以自行验证,在由 构成的搜索栅格上,遵守上述5条规则可以得到所有的正确的路径。我们并不需要穷举所有的 种路径就能计算出想要的结果,这就是动态规划的核心思想: 提前剔除掉不可能的结果,在更小的搜索空间上进行计算 。
如何计算图5路径的概率总和呢?我们首先定义 为t时刻取值为s的全部前缀路径概率总和:
累乘符号表示同一路径上的不同节点概率相乘,累加符号表示不同路径的概率相加。比如(t2,a)这个点的左边和左上角各有一条前向路径, 即为这两条路径的概率之和。
可以用递推方式求得,我们分三种情况讨论,在t时刻:
(1)s取值为 时, ,参见图6的红圈节点,有效的前向路径可来自左边或左上,左侧没有有效路径意味着 。
(2)s取值和s-2取值一样时, ,参见图6的蓝圈节点。
(3)其余情况下, ,参见图6的黑色圈。
所有情况可汇总如下:
初态:
最终的CTC损失函数为: 。因为整个计算过程涉及到的运算都是可微的,所以可以用链式求导计算导数,进行反向传播。类似的,也可以用反向路径概率和 来表示损失函数,读者可在[1]或[2]中找到相应内容。
参考资料
[1] Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks
[2] CTC Algorithm Explained Part 1:Training the Network(CTC算法详解之训练篇)
[3] Facebook大规模文本检测与识别系统Rosetta
[4] CTC Networks and Language Models: Prefix Beam Search Explained