©PaperWeekly 原创 · 作者 | 苏剑林

单位 | 追一科技

研究方向 | NLP、神经网络

当前,像 ChatGPT 之类的 LLM 可谓是“风靡全球”。有读者留意到,几乎所有 LLM 都还是用最初的 Multi-Head Scaled-Dot Attention,近年来大量的 Efficient 工作如线性 Attention、FLASH 等均未被采用。是它们版本效果太差,还是根本没有必要考虑效率?

其实答案笔者在《线性Transformer应该不是你要等的那个模型》已经分析过了,只有序列长度明显超过 hidden size 时,标准 Attention 才呈现出二次复杂度,在此之前它还是接近线性的,它的速度比很多 Efficient 改进都快,而像 GPT3 用到了上万的 hidden size,这意味着只要你的 LLM 不是面向数万长度的文本生成,那么用 Efficient 改进是没有必要的,很多时候速度没提上去,效果还降低了。 

那么,真有数万甚至数十万长度的序列处理需求时,我们又该用什么模型呢?近日,Google 的一篇论文《Resurrecting Recurrent Neural Networks for Long Sequences》重新优化了 RNN 模型,特别指出了 RNN 在处理超长序列场景下的优势。那么,RNN 能否再次辉煌?

论文标题:

Resurrecting Recurrent Neural Networks for Long Sequences

论文链接:

https://arxiv/abs/2303.06349

线性化

文章提出的 RNN 叫做 LRU(Linear Recurrent Unit,线性循环单元),它是既可以并行又可以串行的极简线性 RNN,训练和推断都具备高效的优势。LRU 跟 SSM(Structured State Model)[1]、RWKV [2] 等工作有颇多相似之处。

事实上,LRU 的出发点就是发现 SSM 在 LRA 上表现很好,于是想办法将原生的 RNN 也能在 LRA 表现良好,其结果就是 LRU。遗憾的是,原论文只在 LRA(Long Range Arena,一个测试远程依赖能力的榜单)上做了实验,本文最后则会补充一些自己在语言模型上的实验结果。 

原论文的介绍从 SSM 出发,并且花了不少篇幅描写 LRU 与 SSM 的关联。而在本文中,我们略过这些关联的描写,直接将 LRU 作为一个独立的 RNN 模型进行推演介绍。我们知道,最简单的 RNN 可以写为:

其中 ,f 是激活函数。一般情况下 之前、 之后都还有一个投影矩阵,但这里我们重点关注循环本身,因此就不把它显式写出来了。

传统的认知中,激活函数是非线性的,常见的选择有 等,特别是有工作表明带有 或 激活函数的单层 RNN 是图灵完备的,这就让人坚信非线性激活函数的必要性。然而,在深度学习中,实验才是检验真理的唯一标准,作者发现,如果将 Transformer 的 Self Attention 替换为 RNN 的话,线性 RNN 效果才是最好的:

▲ 在 LRA 的各个任务上,线性 RNN 反而是最好的

这是一个让人意外的好消息。“意外”是因为可能会颠覆某些读者关于模型对非线性需求的认知;当然有些读者可能也不意外,因为 MetaFormer [3] 等工作也表明过,得益于 FFN 层的强大,Self Attention 等负责混合 token 的层的非线性可以很弱,甚至 Pooling 层都行。至于“好消息”,则是因为线性 RNN 有并行的实现算法,计算速度会大大快于非线性 RNN。

于是,作者围绕线性 RNN,进行了一系列探讨。

对角化

去掉激活函数,RNN 就再次简化为:

反复迭代得到:

可以看到,主要的计算量集中在矩阵 A 的幂运算上。这时候不难联想到矩阵对角化,它是计算矩阵幂的高效方法,然而一般的矩阵在实数域不一定能对角化。这时候我们该怎么办?考虑若当标准型?不,格局打开点,既然实数域做不了,我们到复数域去!几乎所有矩阵都可以在复数域对角化,这意味着 A 总能写成:

其中 , 是特征值组成的对角阵。代入式(3)我们得到:

刚才我们说了,一般情况下 之前、 之后都还有一个投影矩阵,只要我们约定这两个投影矩阵都是复数矩阵,那么理论上 、 都可以合并到它们的投影运算中,这就意味着,如果一切运算都在复数域中考虑,那么将线性 RNN 中的一般矩阵 A 换成对角阵 ,模型能力不会有任何损失!所以我们只需考虑如下的极简 RNN:

参数化

对角矩阵的好处是一切运算都是 element-wise 的,所以每个维度的运算可以充分并行,同时也意味着只要分析一个维度就相当于分析了所有维度,模型的分析只需要在一维空间进行。不妨设 , 代表 中的一个,同时在不至于混淆的情况下,、 同样也用来表示 在它们之中对应的分量,于是(6)简化为标量运算:

注意别忘了, 是复数,所以我们可以设 ,其中 都是实数:

求和过程中 t-k 都是负数,因此 ,要不然历史项的权重将会逐渐趋于无穷大,这跟直觉不符(直觉上对历史信息的依赖应该是逐步减弱的),也会梯度爆炸的风险;另一方面,如果 ,那么就会有梯度消失的风险。这就对 r 提出了两个要求:1)保证 ;2)初始化阶段 r 应该尽量接近 1。

为此,我们先设 ,那么 就要求 ,于是我们再设 ,这时候就有 而转化为无约束优化了。这里的 是另一个变量的记号,并非代表什么特殊的运算。而既然 被参数化为了 ,那么为了保持一致性,我们也将 参数化为 。

可能读者要问,约束 的方法有很多呀,为什么要搞这么复杂?直接加 sigmoid 不好吗?首先,将 r 参数化为 后,幂运算可以跟 的结合在一起,即 ,这样不管从实现角度还是计算角度都比较好;接着,因为 ,能将任何实数能映射为非负数的最简单的光滑函数,可能就是指数函数的,于是容易想到 。SSM 中采用的 激活,即直接 ,但这会有个饱和区,可能不利于优化。

初始化

接下来考虑初始化问题。我们回到原始形式(2),一个 的实矩阵,标准的 Glorot 初始化是均值为 0、方差为 1/d 的正态分布或者均匀分布(参考《从几何视角来理解模型参数的初始化策略》[4])。可以从理论或者实验上表明,这样的初始化矩阵,其特征值大致上均匀分布在复平面上的单位圆内:

▲ Glorot 初始化的矩阵的特征值均匀分布在单位圆盘内

由此,我们可以想到 的标准初始化方式是在复平面上的单位圆内均匀取点。而从笛卡尔坐标换到极坐标,我们有 ,这就告诉我们,要实现单位圆内均匀取点,只需要 以及 。

然而,刚才我们说为了尽可能地预防梯度消失,我们至少要在初始化阶段让 r 尽量接近于 1,所以改进方式是改在 的圆环内均匀采样,这样采样方式就变为 以及 。原论文的实验结果显示, 对多数实验都有较好效果。

▲ 改为圆环初始化,大部分任务的效果更好

这里有一个问题,就是 r 初始化接近 1,而初始阶段 也比较接近独立同分布的,那么式(8)就接近若干个模长不变的求和(而不是平均),这就可能有爆炸风险。为了分析这一点,我们先写出:

这里的 * 是复数的共轭运算, 是复数的模。接着两端求期望,这里我们假设 独立地服从同一均值为 0 的分布,那么当 时,,于是只剩下 的项非零,于是:

由于 ,当 t 足够大时 。这也就是说,当 t 比较大时,平均意义下 的模长与 的模长之比为 ,当 r 很接近 1 时,这个比例很大,也就是序列经过 RNN 后会膨胀得比较大,这不利于训练的稳定性。于是作者想了个简单的技巧,多引入一个 element-wise 的参数 ,初始化为 ,然后将式(7)改为:

这样一来,至少在初始阶段模型的输出就稳定了,剩下就让模型自己学就好了。综合以上结果,就是原论文所提的 LRU(Linear Recurrent Unit)模型了,如下图:

▲ LRU 模型示意图

相关化

这里介绍 LRU 的两个相关变体。

SLRU

LRU 的出发点是对一般的线性 RNN 模型(2)进行简化,而为了在理论上达到一般矩阵的效果,就不得不引入复的投影矩阵,以及复的特征值对角阵 。如果我们不考虑达到一般矩阵的效果,纯粹关心 r 所带来的衰减作用,那么我们可以进一步简化 LRU 模型——假设投影矩阵和特征值对角阵都是实数——这个简化版我们称为 SLRU(Simpler Linear Recurrent Unit)。

原论文并没有研究 SLRU,但笔者感觉它更符合我们的直觉(主要是相位 的变化不容易从直觉上理解),所以在后面也补充了 SLRU 的实验。

RWKV

谈到 RNN,可能有读者听说过最近小有名气的 RWKV [2],它可以看作 SLRU/Hydra Attention [5] 和GLU(Gated Linear Unit)的结合。RWKV 的 RNN 部分为:

可以看到,递归部分就是两个 SLRU,RWKV 的特点是两个 SLRU 的结果相除,起到归一化的效果,所以就不需要 LRU 中的 gamma 技巧了。另外也许是为了跟 Self Attention 对齐参数量,或者是为了进一步提升效果,在归一化之后 RWKV 再添加了一个门 与之相乘。虽然作者在 LM 任务上已经验证过了 RWKV 的有效性,但它与常见模型的对照实验似乎没有出现过,本文也将补充这部分。

注:这里的 RWKV 特指负责 token 混合的 RNN 模块,并非指作者给出的完整模型(即没有用作者的 Channel-Mix 层、Time Shift 等内容)。

代码化

这一节我们来讨论 LRU 的实现问题。原论文附录中给出了 Jax 版本的 LRU 参考代码,这里笔者也给出 Keras 版本的:

Github:

https://github/bojone/rnn

实现 LRU 有两个技术难点:复数化并行化

复数化

LRU 的投影矩阵和特征值都是复的,作者给出的 Jax 版代码是直接使用复数矩阵的,换到 Keras 这意味着我们无法用回已有的 Dense 层,这未免有些遗憾。事实上,根据 我们可以看出,复数投影矩阵只不过是将投影维度增加一倍而已,所以投影部分我们就不用复数矩阵了,直接用两倍 units 的 Dense 层就行。

接着是 部分,这既可以直接展开为纯实数运算,也可以直接按照公式用复数运算。如果展开为实数运算的话,其形式跟 RoPE 是一样的,所以笔者刚开始看到 LRU 时就很激动,以为这不就是“RoPE is all you need”哈。不过笔者对比过速度,发现直接按照公式实现的复数版速度会稍快一些,所以建议还是用复数版的。

最后,就是复数输出投影回实矩阵问题,根据 ,这意味着我们只需要将实部和虚部拼接起来,然后接一个 Dense 层就能实现了。

并行化

如果直接按照递归公式实现串行版的 RNN,那么训练速度将会非常慢(预测都是串行的自回归,所以预测没问题)。前面说了,线性 RNN 的一个重要特性是它本身有并行算法,可以大大加快训练速度。

事实上,我们可以将(7)改写为:

这其实已经告诉了我们一种快速的算法:每个 都乘以 ,这是 element-wise 的,可以并行;然后 这一步实际上就是 cumsum 运算,各个框架自带的实现都很快;最后就是 cumsum 的结果都乘以各自的 ,这一步也是 element-wise 的,可以并行。

然而,因为 ,所以当 k 很大时 几乎必定会爆炸,别说 fp16 精度了,在长序列时 FP32 甚至 FP64 都不一定能兜住。因此,这个看上去很简明的方案,理论上没有问题,实际上却没什么价值。

并行加速的关键,是留意到分解(T > t):

这个分解告诉我们,对整个序列做(7)的结果,等价于将序列分为两半各自做(7),然后将前一半的最后一个结果加权到后一半各个位置上,如下图左:

这里的关键是“分开两半各自做(7)”这两半是可以并行的!于是递归下去,我们就将原本是 的循环步数改为了 ,从而大大加快训练速度,如上图右。

事实上,这就是 Prefix Sum [6] 问题的“Upper/Lower”并行算法,代码细节可以参考笔者上面给出的代码。因为 Tensorflow 1.x 不支持直接写递归,笔者是用 tf.while_loop 或者 for 从下到上实现的,训练时只能勉强接近 Self Attention 的速度。

事实上如果将循环部分重写为 CUDA 内核的话,应该是可以超过 Self Attention 速度的(可惜笔者不会)。RWKV 的作者只是将 RWKV 的 RNN 格式写成了 CUDA 内核,没有考虑并行化,但就这已经可以媲美 Self Attention 的速度了。

此外,Prefix Sum 还有“Odd/Even”并行算法,理论上它的计算效率更高一些,但它的结构更复杂些,如果用 tensorflow 实现的话,它涉及到更多的循环步数以及更多的 reshape 和 concat 操作,实际效率未必比得上“Upper/Lower”并行算法,因此笔者就没有实现它了(主要还是 tensorflow 1.x 不支持递归导致的,如果用递归写倒不是太复杂)。

效果化

这一节我们将演示原论文在 LRA 上的实验结果,以及笔者在语言模型(LM)任务上的实验结果。

原论文中,作者主要是通过理论和实验相结合的方式,演示了如何一步步地优化普通的 RNN,直到在 LRA 上取得接近 SOTA 的效果,这个分析和改进的过程可谓是引人入胜,值得反复品味。但由于原论文的实验都是在 LRA 上反复进行的,所以实验本身并无过多精彩之处,这里只演示论文中的 Table 8:

▲ LRU 论文的实验结果汇总

对于本文的读者来说,可能更关心它在 NLP 尤其是近来很火的 LM 上的效果,可惜原论文没有这部分内容,笔者自己做了一些对比实验,供大家参考。对比的模型包括 GAU(同 GAU-α)、SA(同 RoFormerV2)、LRU、SLRU 和 RWKV,其中 LRU、SLRU、RWKV 都只是将 RoFormerV2 中的 Self Attention 换成参数量和计算量相似的 LRU、SLRU、RWKV。模型参数量均为 1 亿左右的 base 版,在当前算是小模型了,初始化均使用 DeepNorm,优化器用的是 Tiger,其他所有超参数都一致,基本上做到了比较严格的控制变量。

可以看到,从效果上排序,应该是:

从实验结果上我们可以得出:

1. LRU 优于 SLRU,表明引入复投影矩阵和复特征值确实是有帮助的,但计算效率会有一定损失(哪怕保持参数量不变); 

2. 当序列长度增加时,Attention 系列(GAU、SA)的效果会变好,而 RNN 系列(LRU、SLRU、RWKV)的效果则会下降,这是两者的本质差异,原因应该是 RNN 的长程记忆能力受限于 hidden_size; 

3. RWKV 确实有可能是目前最好的 RNN 模型,但跟 Attention 类(GAU、SA)模型还有明显的差距; 

4. 根据第 2 点,RNN 系列需要追平 Attention 系列,那么应该需要继续放大 hidden_size,所以在 LM 任务上 RNN 系列或许需要更大尺度才有优势; 

5. 结合第 1 点和第 3 点,下一个改进版的 RNN 是否就是复数版 RWKV 了?

此外,还有几点实验过程中的经验。由于 GAU 是单头的,因此在长序列、大尺度的场景下它的计算效率明显优于 SA,并且它的效果也优于 SA,所以 GAU 应该是在相当大的一个范围内是语言模型的最佳选择,拍脑袋想的话,百亿参数以内、序列长度 5000 以内,都建议优先考虑 GAU。

但不可否认,同尺度的 RNN 系列模型在推理效率上更优(每步递归的计算量和 cache 大小都一致),而训练效率上也不输于 Attention 系列,因此模型放大之后,应该还是有机会跟 Attention 系列一较高低的。

特别要指出的是,RWKV 虽然整体表现不错,但它是笔者实验的唯一一个会出现 spike 现象的模型(如图的 Loss 曲线突然飙升,试过重跑也有出现,不是偶然),所以公平比较之下,RWKV 也没有传说中那么完美无暇,它也需要很多 trick。

事实上,RWKV 作者自己的实现中,就包含了一系列据说有助于训练 LM 但相当晦涩的 trick(按照作者的意思,他这些 trick 才是“精华”),这些 trick 需要读作者给的源代码才能发现,它们没有考虑进笔者的实验中。不排除这些 trick 有助于更好训练一个 LM 的可能性,但笔者更多的是想做一个公平的对照实验而非实际训练一个 LM 模型,一旦引入这些 trick,变量就太多了,笔者算力有限,无法一一对照。

当然,以上结论都只是在 1 亿级别的“小模型”中得出的,更大尺度的模型笔者还在尝试中,暂时没法给大家结论。

结论化 

本文介绍了 Google “拯救” RNN 的一次尝试,自上而下地构建了一个在 LRA 上表现接近 SOTA 的高效 RNN 模型。除了原论文在 LRA 上的实验外,本文还给出了笔者自己在语言模型上的实验结果,包括与 RWKV 等相关模型的对比。总的来说,经过优化的 RNN 模型在训练效率上并不逊色于 Attention 类模型,同时有着更好的推理性能,但语言模型效果上离 Attention 类模型还有一定差距,也许需要将模型做得更大,才能进一步体现出 RNN 的优势。

参考文献

[1] https://arxiv/abs/2111.00396

[2] https://github/BlinkDL/RWKV-LM

[3] https://arxiv/abs/2111.11418

[4] https://kexue.fm/archives/7180

[5] https://arxiv/abs/2209.07484

[6] https://en.wikipedia/wiki/Prefix_sum

更多阅读

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算

📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿

△长按添加PaperWeekly小编

🔍

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

·

·

更多推荐

Google新作试图“复活”RNN:RNN能否再次辉煌?