003-注意力机制详解:从基础Attention到DeepSeek的优化策略

张开发
2026/4/11 17:12:13 15 分钟阅读

分享文章

003-注意力机制详解:从基础Attention到DeepSeek的优化策略
003-注意力机制详解从基础Attention到DeepSeek的优化策略上周调一个多模态模型输入序列稍微长点显存就炸了。profile工具显示attention层的计算复杂度曲线陡得吓人——典型的O(n²)问题。这让我想起几年前第一次实现Transformer时那个朴素的attention实现现在看简直像古董。今天咱们就聊聊attention这些年是怎么进化过来的特别是像DeepSeek这类模型做了哪些实在的优化。一、从那个“古老”的Scaled Dot-Product说起最早Transformer论文里的attention公式现在看都背下来了# 经典实现教学用生产别这么写defattention_naive(Q,K,V):scorestorch.matmul(Q,K.transpose(-2,-1))# [batch, head, seq_len, seq_len]scoresscores/math.sqrt(d_k)# scalingattn_weightstorch.softmax(scores,dim-1)returntorch.matmul(attn_weights,V)问题太明显了那个scores矩阵是seq_len的平方大小。512长度还能忍到2048的时候单是这一个中间变量就能吃掉几个G的显存。更麻烦的是计算softmax需要保留整个矩阵在内存里反向传播时还得再存一份。实际部署时第一个优化就是分块计算。但分块也有坑softmax的数值稳定性。直接对分块的结果做softmax再合并结果对不上。这里我们一般用online softmax技巧defsafe_softmax(x):# 减最大值防止溢出老司机都懂x_maxx.max(dim-1,keepdimTrue).values exp_xtorch.exp(x-x_max)returnexp_x/exp_x.sum(dim-1,keepdimTrue)这个操作在分块计算时必须每块都做还得记录全局最大值——稍微麻烦点但能省30%以上显存。二、FlashAttention的革命把IO意识带入算法2022年看到FlashAttention论文时有种“早该这么想了”的感觉。它的核心洞察很硬件对于现代GPU计算速度远快于内存读写瓶颈在IO。传统的attention实现反复在HBM和SRAM之间搬运数据大部分时间在等数据。FlashAttention的做法是把计算拆成Tile让每个Tile的数据在SRAM里完成所有操作只写回最终结果。伪代码简化版# 概念示意真实实现要处理mask、dropout等forblock_iinrange(num_blocks_q):Qiload_tile(Q,block_i)acczeros_like(output_tile)max_vec-inf sum_veczerosforblock_jinrange(num_blocks_k):Kj,Vjload_tile(K,block_j),load_tile(V,block_j)# 在SRAM里计算这个小块scores_ijmatmul(Qi,Kj.T)new_maxelementwise_max(max_vec,scores_ij.max())# 调整之前累积的权重关键scaleexp(max_vec-new_max)accacc*scale.unsqueeze(-1)sum_vecsum_vec*scale exp_scoresexp(scores_ij-new_max)accmatmul(exp_scores,Vj)sum_vecexp_scores.sum(dim-1)max_vecnew_max output_tileacc/sum_vec.unsqueeze(-1)write_back(output_tile)这个算法把HBM访问量从O(seq_len²)降到O(seq_len)。第一次在项目里换上FlashAttention同样的3090显卡序列长度能从2K推到8K——效果立竿见影。三、DeepSeek的注意力优化工程上的组合拳看DeepSeek的技术报告他们的attention优化是组合策略。几个值得说的点1. 混合精度策略不是简单用amp。他们的做法是QK计算用FP16softmax用FP32累积最后乘V转回FP16。为什么因为attention scores的数值范围动态太大纯FP16容易溢出或精度不够。但全程FP32又太慢。这个平衡点调了很久我们团队实测能比纯FP16训练稳定比纯FP32快40%。2. 稀疏注意力滑动窗口对于长文本完全稠密的attention没必要。DeepSeek用了块稀疏滑动窗口。代码里大概这样# 滑动窗口attention局部注意力defsliding_window_attention(q,k,v,window_size):# 只计算每个query附近window_size内的key# 实现时用banded matrix乘法别傻傻的生成大矩阵再maskpass但这里有个坑直接硬mask会破坏训练稳定性。他们的做法是给mask外的位置加一个很大的负偏置比如-1e4而不是直接置零。这样梯度还能流动只是权重极小。3. KV Cache的极致优化推理时的KV Cache他们做了内存复用。同一个batch里不同序列长度共享一块预分配的内存池用offset来区分。这个技巧在部署时特别有用classKVCachePool:def__init__(self,total_size,head_dim):self.k_cachetorch.empty(total_size,head_dim)# 预分配一大块self.v_cachetorch.empty(total_size,head_dim)self.offset0defallocate(self,seq_len):startself.offset self.offsetseq_lenreturnself.k_cache[start:startseq_len],self.v_cache[start:startseq_len]避免频繁分配释放内存碎片少了速度自然上来。四、一些踩坑经验关于LayerNorm的位置Transformer里LayerNorm放attention前还是后原始论文是后置但很多新模型包括DeepSeek用前置。实测前置训练更稳梯度更好。但推理时如果想做算子融合后置更方便。看需求取舍。Dropout的放置attention里的dropout有三种位置QK乘积后、softmax后、最后乘V后。我们实验发现在softmax后dropout效果最好但会影响激活稀疏性。如果追求推理速度可以只在训练时用推理时去掉。RoPE位置编码的陷阱旋转位置编码(RoPE)现在很流行但实现时有细节坑# 错误实现别这么写defrope_wrong(x,freqs):returnx*cos(freqs)rotate(x)*sin(freqs)# rotate实现不对会破坏梯度# 正确实现要保证复数旋转的线性性defapply_rope(q,k,freqs):# 实际代码较长关键是保持复数乘法形式pass建议直接用开源实现自己写容易出数值问题。五、个人建议不要过早优化先确保模型正确性profile找到真实瓶颈再加优化。我见过有人一上来就写FlashAttention结果mask处理错debug三天。保持可读性优化时加详细注释特别是数学变换部分。三个月后你自己都看不懂那堆reshape和转置是干嘛的。测试覆盖边界长序列、短序列、全mask、部分mask、不同batch size都要测。attention的边界条件特别多。硬件感知了解你的部署硬件。A100和H100的tensor core用法不同甚至不同CUDA版本都有差异。借鉴但别盲从DeepSeek的优化是针对他们的架构和数据。你的任务可能不需要那么复杂的策略。有时候简单的window attention加好用的KV cache效果足够好。注意力机制从理论到生产中间隔着一堆工程细节。每次觉得“这次应该没问题了”总会有新的序列长度或batch size让你重新思考。或许这就是做模型的乐趣——永远有更好的方案等着去实现。下一篇预告004、位置编码演进从Sinusoidal到RoPE的深度剖析

更多文章