从内存墙到高效计算|如何提升大模型在推理中的计算速度:MQA、GQA优化策略
Editor's Note
随着大型语言模型的不断发展,其在实际应用中的运算效率成为一个不可忽视的问题。本文以通俗易懂的方式,系统地介绍了Transformer在文本生成任务中计算效率低下的根源,以及所提出的MQA和GQA等解决方案。
The following article is from 安迪的写作间 Author 安迪的写作间
好几周前读完 GQA 论文就想写的,但一直拖着。直到最近 LLAMA2,还有 NV 小伙伴问到我这一点,我解释了下,他让我顺便写下来写成一篇文章吧。
我想着推导一下 MQA 就好了,但发现这个话题并没那么简单,这是由一系列限制导致的优化,涉及到模型结构甚至涉及到硬件,我个人最开始也没看懂 MQA 操作。于是一写就各种瞎扯了。
MQA,全称 Multi Query Attention, 而 GQA 则是前段时间 Google 提出的 MQA 变种,全称 Group-Query Attention.
MQA 提出时间挺早的,是 Noam Shazeer 这位谷歌老炮 19 年提出的。而 Noam 也是 Transformer 结构提出者之一,现在也就理所当然地早就不在 Google,是 Character.ai 的合伙人。
这位老哥对 Transformer 提出过好几种结构改进,比如 Talking-Heads、GLU 激活、还有这里谈的 MQA,他这些论文都简单粗暴,上方法看效果,不玩虚的,主打一个好用。
MQA 提出时,并没获得太大关注,包括我也是翻 Talking-Heads 才知道这篇论文,当时看了下也没觉得哪里好,就忘了。而且作者可能也没太当一回事,论文从头到脚都能看到两个字,随意。
最近才越来越多被提到,包括 Falcon,Star Coder,还有最近的 Llama 2 都有用到。
心里疑惑,为什么之前就没有被关注呢?
GPT 的阿克琉斯之踵
GPT,也就是 Transformer Decoder 结构做文本生成时有一个致命问题。
先来看看 Encoder 推理是怎么做的,每个 timestep 都能看到所有 timestep ,推理时所有 timestep 一层层向后计算,一把过。
于是内存相关开销就是 , 而计算相关开销就是 ,其中 N 为序列长度。
而 Decoder 推理时,最大不同在于自回归结构,可以看到图中每个 timestep 的输出都是下一 timestep 的输入,所以无法像 Encoder 一样一次过,每次都要 attend 之前的所有 timestep.
同样计算一下开销,计算开销是 也就是 ,而内存开销则是 O().
大家用 ChatGPT 接口也会有类似感觉,Context 部分成本很低,也很快,因为它做的类似于 Encoder 的并行。主要成本在生成那块,速度较慢,但也已经是优化过后的了。
(注:使用 ChatGPT 开放接口时,需要构建一个API请求,其中包含一个称为"content"的部分。这个"content"部分指的是你希望将输入传递给ChatGPT的文本内容。它可以是一个问题、一段对话、一句话,或者任何你希望与模型进行交流的文本。)
下面就来讲讲优化方法。
(注:在Transformer模型中,Decoder是负责生成输出序列的部分,比如在机器翻译任务中,将源语言的输入翻译为目标语言的输出。Decoder由多个Decoder层组成,每个Decoder层中都包含了自注意力机制(self-attention)和注意力机制(attention),用来处理序列中的不同位置和上下文信息。
在Decoder的每个Decoder层中,输入序列会被分成一系列的时间步(timesteps),每个时间步都对应着输入序列中的一个位置。在处理序列数据时,时间步是指序列中的特定位置,类似于时间上的步骤,尤其在自然语言处理任务中,虽然没有真正的时间概念,但这个术语被广泛应用。
每个时间步都涉及到对应位置的词嵌入(word embedding)以及上一层Decoder的输出,以便进行自注意力和注意力计算。这些计算有助于模型在生成输出时捕捉不同位置的信息和上下文关系,以便更好地生成合适的下一个单词。
总结起来,在Transformer Decoder中,时间步是指输入序列中的不同位置,每个时间步都会在Decoder的各个层中进行自注意力和注意力计算,以利用序列中的不同位置信息来生成相应的输出序列。这种处理方式使得Transformer能够捕捉到序列数据中的长距离依赖关系,并生成准确的输出。)
KV Cache
Decoder 每次前向,当前 timestep 计算 Attention 要用到的部分,如之前 timestep 的 KV (Key 和 Value)值都计算过的,只是之前每次前向完后给计算结果都丢掉,只保留最后输出。
于是一个很自然的想法就是 Cache。这很像斐波那契递归函数,naive 版本,也会出现不断重复计算问题,加个 cache 瞬间提速。
每次前向完,给 KV 都保留下来,用于之后计算。
代码表示如下(关于 Self-Attention 不清楚可看我这篇 《Transformer 三部曲:RNN 的继承者》):
#q、k、v 当前 timestep 的 query,key,value
# K_prev,V_prev 之前所有 timestep 的 key 和 value
for _ in range(time_step):
...
K = torch.cat([K_prev, k], dim=-2) #[b, h, n, d]
V = torch.cat([V_prev, v], dim=-2) #[b, h, n, d]
logits = torch.einsum("bhd,bhnd->bhn", q, K)
weights = torch.softmax(logits/math.sqrt(d), dim=-1)
outs = torch.einsum("bhn,bhnd->bhd", weights, V)
...
K_prev, V_prev = K, V
于是 Decoder 就被优化成,计算开销变成了 O(),存储复杂度则是 O(),只给 K 和 V 不断保存在缓存中就行。问题解决了!
但残酷现实会立马跳出来给你一棒子,上面假设 K 和 V 能直接存在缓存中,模型规模小还好,一旦模型规模很大长度很长时,KV 根本就存不进缓存。
比如 Llama 7B 模型,hidden size 是 4096,那么每个 timestep 需缓存参数量为 4096*2*32=262144,假设半精度保存就是 512KB,1024 长度那就要 512MB. 而现在英伟达最好的卡 H100 的 SRAM 缓存大概是 50MB,而 A100 则是 40MB. 而 7B 模型都这样,175B 模型就更不用说了。
那为什么我们不直接做大 SRAM 内存呢,不就直接解决问题了吗,但是这样又会产生一个新问题 SRAM 太贵了,所以这条路现在是不太行的。
于是退一步,放不进缓存可以放 DRAM 上去,而 DRAM 内存也就是我们常说的 GPU 显存。
但 DRAM 读取到计算芯片和 SRAM 到计算芯片的速度,差了一个量级的,这会让计算芯片一直在等待。
现在我们遇到了当今芯片领域,冯诺依曼架构下最大的一个问题,也就是:Memory Wall(内存墙)。
冯诺依曼架构和 Memory Wall
冯诺依曼架构熟悉有计算机相关基础的,应该都稔熟于胸。输入,输出,计算单元,加上存储单元。
现在随着摩尔定律的见顶,虽然计算和内存的发展速度在变缓,但这并不是最大的问题,最大的问题是存储单元与计算单元间的交互。
冯诺依曼架构需要先从内存中调取数据,送入计算单元进行处理,但现在计算单元的速度是显著提升的,而从内存中读取数据的速度却没跟上,所以计算和内存这里就形成了一个瓶颈。因为短板效应,内存读取速度限制了整体速度。计算单元能很快将数据处理完,但新数据却还没到,于是就只能等待,造成利用率不高。这就是内存墙。
因为内存墙问题,现在 GPU,一张 A100 卡计算单元的利用率到四五十就不错了,用上各种技巧优化到 60% 已经很高了。而对于 H100 卡问题会更严重,因为它的计算速度相对 A100 提高了 6 倍,而内存读取带宽只增加了 1.6 倍,所以也要大量优化来提高利用率。
内存墙怎么越过呢?
硬件层面上,比如现在已在使用的 HBM(高速带宽内存)提高读取速度,或者更彻底些,抛弃冯诺依曼架构,改变计算单元从内存读数据的方式,不再以计算单元为中心,而以存储为中心,做成计算和存储一体的“存内计算”。
软件层面上的话,最近的很多优化,比如 Flash Attention,Paged Attention 都可以算。Flash Attention 就是减少了计算 Softmax 时从 DRAM 内存读取数据次数,从而提高了效率。
同样,MQA 也是一个软件层面上翻墙的一个方法。
MHA 到 MQA 到 GQA
MQA 的方法很简单,难的是看到这样的方法后,能立刻想到它为什么好。
一起看看 MQA 和 GQA 是怎么来的。
首先是原始的 MHA(Multi-Head Attention),QKV 三部分有相同数量的头,且一一对应。每次做 Attention,head1 的 QKV 就做好自己运算就可以,输出时各个头加起来就行。
而 MQA 则是,让 Q 仍然保持原来的头数,但 K 和 V 只有一个头,相当于所有的 Q 头共享一组 K 和 V 头,所以叫做 Multi-Query 了。实现改变了会不会影响效果呢?确实会影响但相对它能带来的收益,性能的些微降低是可以接受的。
能带来多大的收益呢,实验发现一般能提高 30%-40% 的吞吐。
收益主要就是由降低了 KV cache 带来的。实际上 MQA 运算量和 MHA 是差不多的,可理解为读取一组 KV 头之后,给所有 Q 头用,但因为之前提到的内存和计算的不对称,所以是有利的。
而 GQA 呢,是 MHA 和 MQA 的折衷方案,既不想损失性能太多,又想获得 MQA 带来的推理加速好处。具体思想是,不是所有 Q 头共享一组 KV,而是分组一定头数 Q 共享一组 KV,比如上面图片就是两组 Q 共享一组 KV。
LLAMA2 中给出了效果对比,可以看到相比起 MQA,GQA的指标看起来还是要好些的。
同时在推理上的加速还和 MQA 类似:
MQA 和 GQA 形式在推理加速方面,主要是通过两方面来完成:
降低了从内存中读取的数据量,所以也就减少了计算单元等待时间,提高了计算利用率; KV cache 变小了 head_num 倍,也就是显存中需要保存的 tensor 变小了,空出来空间就可以加大 batch size,从而又能提高利用率。
如果要用 MQA 和 GQA,可以是从头训练的时候就加上,也可以像 GQA 论文里面一样,用已有的开源模型,挑一些头取个 mean 用来初始化 MQA 或 GQA 继续训练一段时间。
下面是 MQA 推导过程,不感兴趣同学可跳过,感兴趣同学可推一下,理解更透彻。
MQA 的推导
正如在 memory wall 中提到的,现在内存读取相对计算速度太慢导致拖后腿。
那么定义一个变量,, M 是 Memory 表示内存开销,而 A 是 Arithmetic 表示计算开销。如果这个值大于1的话,就会出现很明显的 Memory Wall,而当这个值小于1很多时,表示拿到数据后马上能开动马力计算,内存墙问题就不存在了。因为估算还有各种没考虑因素问题,所以即使等于 1 也不代表就能打满计算单元。
那么先来看看 MHA 下推理时每一个 timestep 这个值的大小,主要参考 MQA 原论文的简化:
#三个投影矩阵分别为 P_q, P_k, P_v; 维度为 h(头数), a(隐层大小,等于hd), d(每个头大小)
#当前 timestep 输入为 x,维度为 b(batch大小), a
#K_prev, V_prev 为 KV cache的矩阵,维度为 b, h, m(之前的timestep数),d; m+1=n
q = torch.einsum('ba,had->bhd', x, P_q) #M:had+ba, A:ba^2
k = torch.einsum('ba,had->bhd', x, P_k) #M:had+ba, A:ba^2
v = torch.einsum('ba,had->bhd', x, P_v) #M:had+ba, A:ba^2
K = torch.cat([K_prev, k.unsqueeze(2)], dim=-2) #M:bhnd+bhd, A:0
V = torch.cat([V_prev, v.unsqueeze(2)], dim=-2) #M:bhnd+bhd, A:0
logits = torch.einsum("bhd,bhnd->bhn", q, K)#M:bhnd+bhd, A:bhnd
weights = torch.softmax(logits/math.sqrt(d), dim=-1)#M:bhn
outs = torch.einsum("bhn,bhnd->bhd", weights, V)#M:bhn+bhnd, A:bhnd
所以对于 M 来说是
对于 A 来说
假设隐层大小和 timestep 数接近,, 那么 A 就是 , 因此
可以看到要想让这个比例小,可以增大b,也就是增大 batch size,现在推理优化就会将用户的请求收集成 batch 推理,提高利用率。同时前面提到,MQA 可以降低显存使用扩大 batch size,所以能提高一定利用率。
根据假设 ,这个比例会接近 1,会导致一定 Memory Wall,如果 n 很长的话问题就更明显。
而 MQA 的情况下
#投影矩阵 P_k, P_v 维度变为 a(隐层大小,等于hd), d(每个头大小)
#K_prev, V_prev 为 KV cache的矩阵,维度为 b, m(之前的timestep数),d; m+1=n
q = torch.einsum('ba,had->bhd', x, P_q) #M:had+ba, A:ba^2
k = torch.einsum('ba,ad->bd', x, P_k) #M:ad+ba, A:bad
v = torch.einsum('ba,ad->bd', x, P_v) #M:ad+ba, A:bad
K = torch.cat([K_prev, k.unsqueeze(1)], dim=-2) #M:bnd+bd
V = torch.cat([V_prev, v.unsqueeze(1)], dim=-2) #M:bnd+bd
logits = torch.einsum("bhd,bnd->bhn", q, K)#M:bhd+bnd, A:bhnd
weights = torch.softmax(logits/math.sqrt(d), dim=-1)#M:bhn
outs = torch.einsum("bhn,bnd->bhd", weights, V)#M:bhn+bnd, A:bhnd
会发现 A 整体来说没有变,如之前说的只是共享了 KV, 计算量还是一样的 ,M 变化比较大
于是系数为
其中后面两项,d 一般比 h 要大,所以可以主要考虑 项。可看到之前占大头的 在分母加了个系数 h,这样就能降低 从而提高效率。
感兴趣的话,可自己推导一下 GQA 的情况,其中 的分母中会加入一个数 , 其中 为 group 数,如果 为 1 的情况那就和 MQA 一样了,这块开销主要就有 g 来调整了。
再见美好旧时光
看到这,大概也能明白为什么要用 MQA 了,以及为什么 MQA 最近才突然火起来。
主要就是因为大规模 GPT 式生成模型的落地需求导致的。
而在以前根本不需要关心这些,LSTM 只用维护一个状态,不存在要保留 Cache 什么。
到了 Transformer 提出后,虽然最早 Transformer 提出时是用在 Seq2Seq 任务上,也就是 Encoder 和 Decoder 都用,但可能模型量级不大,也没有太多落地需求,所以没引起太大关注。之后火了两年的 BERT 又是 Encoder 结构,直接前向一把过。
也只有到最近 GPT 大模型得到广泛应用时,才发现推理的这个瓶颈,于是大家翻出几年前的 trick,应用起来,发现非常好用。
同样原因,GPT 推理加速这块最近引起很多关注,大家都在想各种方法来提高推理效率。Huggingface 这两天也给 text-generation-inference 库的 license 给改了,应该也是想用这个挣点钱。
Reference
GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints Multi-Query Attention is All You Need Fast Transformer Decoding: One Write-Head is All You Need 芯片简史 How Nvidia’s CUDA Monopoly In Machine Learning Is Breaking - OpenAI Triton And PyTorch 2.0
考虑到微信的便捷性,最近搭建了一个接入了 ChatGPT 的微信机器人,并建了一个AI助手体验群,邀请朋友们进群体验,可以问智能助手各种问题。
也欢迎朋友们分享,和邀请更多朋友进群体验。
点下方二维码,长按扫码进群。
如果群二维码过期,关注公众号联系我,拉你进群。