
Accelerating Transformers with PyTorch 2.0: Wins for Both Training and Inference!

AI小将 机器学习算法工程师 2023-12-17


Transformers have become the most widely used model architecture across domains (text, images, speech), and the recently released PyTorch 2.0 further optimizes its Transformer modules to support efficient training and inference of Transformer-based models. Specifically, PyTorch 2.0 introduces a new function in torch.nn.functional: torch.nn.functional.scaled_dot_product_attention, abbreviated as SDPA below. SDPA is backed by high-performance kernels, so you can call it directly to accelerate both training and inference.

Let's first take a quick look at the signature and parameter descriptions of SDPA:

torch.nn.functional.scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
) → Tensor:
    """
    Args:
        query (Tensor): Query tensor; shape :math:`(N, ..., L, E)`.
        key (Tensor): Key tensor; shape :math:`(N, ..., S, E)`.
        value (Tensor): Value tensor; shape :math:`(N, ..., S, Ev)`.
        attn_mask (optional Tensor): Attention mask; shape :math:`(N, ..., L, S)`. Two types of masks are supported.
            A boolean mask where a value of True indicates that the element *should* take part in attention.
            A float mask of the same type as query, key, value that is added to the attention score.
        dropout_p (float): Dropout probability; if greater than 0.0, dropout is applied.
        is_causal (bool): If true, assumes causal attention masking and errors if both attn_mask and is_causal
            are set.
        scale (optional float): Scaling factor applied prior to softmax. If None, the default value is set
            to :math:`\frac{1}{\sqrt{E}}`.
    Returns:
        output (Tensor): Attention output; shape :math:`(N, ..., L, Ev)`.
    """

SDPA implements the core of the attention module (scaled dot-product attention); the function is equivalent to the following code:

scale_factor = 1 / math.sqrt(Q.size(-1)) if scale is None else scale
attn_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) if is_causal else attn_mask
attn_mask = attn_mask.masked_fill(attn_mask.logical_not(), float("-inf")) if attn_mask.dtype == torch.bool else attn_mask
attn_weight = torch.softmax((Q @ K.transpose(-2, -1) * scale_factor) + attn_mask, dim=-1)
attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
return attn_weight @ V
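
Before moving on, a quick sanity check: with no mask and no dropout, the reference code above and SDPA should produce numerically identical results. A minimal sketch, using arbitrary example shapes:

import math
import torch
import torch.nn.functional as F

# example shapes: (batch, heads, seq_len, head_dim)
Q = torch.rand(2, 4, 16, 8)
K = torch.rand(2, 4, 16, 8)
V = torch.rand(2, 4, 16, 8)

# manual scaled dot-product attention (no mask, no dropout)
scale_factor = 1 / math.sqrt(Q.size(-1))
attn_weight = torch.softmax(Q @ K.transpose(-2, -1) * scale_factor, dim=-1)
manual_out = attn_weight @ V

sdpa_out = F.scaled_dot_product_attention(Q, K, V)
print(torch.allclose(manual_out, sdpa_out, atol=1e-5))  # expected: True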

This function has also been wired into PyTorch's existing Transformer APIs, which means that simply building your model with torch.nn.MultiheadAttention or torch.nn.TransformerEncoderLayer already gives you the speedup from SDPA. Of course, if you need customized behavior, you can call this function directly to build your own attention module, as in the sketch below.
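
Here is a minimal sketch of a multi-head self-attention module built on SDPA. The module name, layer layout, and hyperparameters are illustrative assumptions only; this is not how torch.nn.MultiheadAttention is implemented internally:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SDPASelfAttention(nn.Module):
    # A hypothetical self-attention block that delegates the core computation to SDPA.
    def __init__(self, embed_dim, num_heads, dropout=0.0):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.dropout = dropout
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, is_causal=False):
        B, L, E = x.shape
        # project to Q, K, V and split heads: (B, num_heads, L, head_dim)
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        q = q.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        # SDPA picks the best available kernel for the inputs under the hood
        out = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal
        )
        out = out.transpose(1, 2).reshape(B, L, E)
        return self.out_proj(out)

x = torch.rand(2, 128, 256)
attn = SDPASelfAttention(embed_dim=256, num_heads=8)
print(attn(x, is_causal=True).shape)  # torch.Size([2, 128, 256])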

The reason SDPA can deliver a speedup is that it is backed by optimized kernels. Currently SDPA supports three kernels:

  • sdpa_flash: FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

  • sdpa_mem_eff: Memory-Efficient Attention

  • sdpa_math: a PyTorch implementation defined in C++

Among these, sdpa_flash supports FP16 training and inference on GPUs with SM80+ architectures, while sdpa_mem_eff supports both FP16 and FP32 on most GPUs. If neither of those two kernels is usable, SDPA falls back to sdpa_math, a generic implementation written in C++. By default all three kernels are enabled, and when you call SDPA it selects the most suitable kernel for your inputs.
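
If you want to see which kernels are currently enabled on your setup, torch.backends.cuda exposes simple query functions. A quick sketch; note these flags only report whether a kernel is enabled, not whether it is actually usable for a particular input or GPU:

import torch

# Report the current enabled/disabled state of each SDPA kernel.
print("flash:        ", torch.backends.cuda.flash_sdp_enabled())
print("mem_efficient:", torch.backends.cuda.mem_efficient_sdp_enabled())
print("math:         ", torch.backends.cuda.math_sdp_enabled())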

In most cases you do not need to care which kernel is selected, since SDPA already makes the best choice behind the scenes. But if you want explicit control over the kernel, you can use torch.backends.cuda.sdp_kernel() to disable specific kernels. It is a context manager; for example, to disable sdpa_math you can call it like this:

import torch
import torch.nn.functional as F

query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
with torch.backends.cuda.sdp_kernel(enable_math=False):
    F.scaled_dot_product_attention(query, key, value)

With sdpa_math disabled, SDPA can only choose between sdpa_flash and sdpa_mem_eff. If you disable two kernels, you are effectively selecting the remaining one; for example, the code below amounts to explicitly using the sdpa_mem_eff kernel:

query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    F.scaled_dot_product_attention(query, key, value)

However, if your current platform does not support the chosen kernel, you will get an error:

RuntimeError: No available kernel.  Aborting execution.

We can use sdp_kernel to compare the runtime of the different kernels; the code is as follows:

import torch
import torch.utils.benchmark as benchmark
from torch.backends.cuda import sdp_kernel, SDPBackend
import torch.nn.functional as F

# Lets define a helpful benchmarking function:
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# Lets define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32

dtype = torch.float16
device = "cuda"

query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

# Lets explore the speed of each of the 3 implementations

# Helpful arg mapper
backend_map = {
    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {
        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True},
}

with sdp_kernel(**backend_map[SDPBackend.MATH]):
    print(f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
    try:
        print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")

Running this on a V100 machine gives the following output:

The default implementation runs in 6569.854 microseconds
The math implementation runs in 16091.686 microseconds
<timeit-src>:6: UserWarning: Memory efficient kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:527.)
<timeit-src>:6: UserWarning: Memory Efficient attention has been runtime disabled. (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:338.)
<timeit-src>:6: UserWarning: Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:529.)
<timeit-src>:6: UserWarning: Flash attention only supports sm75 and sm8x gpu architectures. Attempting to run on a sm 7.0 gpu. (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.h:352.)
FlashAttention is not supported. See warnings for reasons.
The memory efficient implementation runs in 6595.339 microseconds

Well, the V100 is an sm 7.0 GPU, so Flash attention is not supported. But we can see that the default kernel chosen is sdpa_mem_eff, which is much faster than sdpa_math (roughly 6 ms vs 16 ms). Switching to an A100, the results look like this:

The default implementation runs in 2831.521 microseconds
The math implementation runs in 7001.696 microseconds
The flash attention implementation runs in 2829.635 microseconds
The memory efficient implementation runs in 3011.410 microseconds

The A100 does support Flash attention, and the default implementation chosen is sdpa_flash, which gives the shortest runtime; the A100 is more than twice as fast as the V100 here.

Finally, let's look at a concrete example: using SDPA to accelerate stable diffusion in diffusers. diffusers already ships an AttnProcessor2_0 based on scaled_dot_product_attention:

class AttnProcessor2_0:
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        inner_dim = hidden_states.shape[-1]

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.cross_attention_norm:
            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states

Let's take stable diffusion 1.5 as an example. First, set the attention processor to the default CrossAttnProcessor:

import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import AttnProcessor2_0, CrossAttnProcessor

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
pipe.unet.set_attn_processor(CrossAttnProcessor())

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]

Running this takes about 3.6 s on a V100 (1.9 s on an A100), with a peak GPU memory usage of about 5.9 GB. Next, replace the attention processor with AttnProcessor2_0:

pipe.unet.set_attn_processor(AttnProcessor2_0())

After the switch, the runtime drops to about 3 s (1.6 s on an A100) and peak memory usage to 4.7 GB, so we get not only a speedup but also lower memory consumption.
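
If you want to reproduce rough timing and peak-memory numbers like these yourself, the sketch below is one way to do it. It reuses the pipe object from the snippet above; the exact measurement method is my assumption, not necessarily how the numbers in this post were obtained:

import time
import torch

def run_once(pipe, prompt):
    # Reset peak-memory statistics, then time one full generation.
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    t0 = time.time()
    image = pipe(prompt).images[0]
    torch.cuda.synchronize()
    elapsed = time.time() - t0
    peak_gb = torch.cuda.max_memory_allocated() / 1024 ** 3
    print(f"time: {elapsed:.2f}s, peak memory: {peak_gb:.2f} GB")
    return image

run_once(pipe, "a photo of an astronaut riding a horse on mars")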

In addition, PyTorch 2.0 introduces torch.compile() for accelerating models, and we can apply it on top of SDPA for a further speedup:

import torch
from diffusers import StableDiffusionPipeline
from diffusers.models.cross_attention import AttnProcessor2_0, CrossAttnProcessor

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to(
    "cuda"
)
pipe.unet.set_attn_processor(AttnProcessor2_0())  # in fact this is already the default
pipe.unet = torch.compile(pipe.unet)

batch_size = 8
steps = 50  # number of denoising steps (example value)
prompt = "A photo of an astronaut riding a horse on mars."
images = pipe(prompt, num_inference_steps=steps, num_images_per_prompt=batch_size).images

With batch_size=8, the compiled version runs in about 16 s on a V100 (6.6 s on an A100), compared with about 17 s for the SDPA-only version (7.3 s on an A100), so torch.compile does bring some additional speedup (although the V100 is clearly much weaker than the A100).

Note that the comparisons here are not rigorous. PyTorch has published a systematic evaluation in the blog post Accelerated Diffusers with PyTorch 2.0.

References

  • https://huggingface.co/docs/diffusers/v0.13.0/en/optimization/torch2.0

  • https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html

  • https://pytorch.org/blog/accelerated-diffusers-pt-20/

  • https://pytorch.org/blog/accelerated-pytorch-2/


