(논문 요약) FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision (Paper)
핵심 내용
- Nvidia Hopper H100 GPU 에서 효율적인 연산을 하기 위한 기법들 (최대한 parallel 하게 연산)
- Producer-Consumer asynchrony: split producers and consumers of data into separate warps
- Hiding softmax under asynchronous block-wise GEMMs (General Matrix Multiply): overlap low-throughput non-GEMM operations in softmax with asynchronous WGMMA (Warpgroup Matrix-Multiply-Accumulate) instructions for GEMM.
- Hardware-accelerated low-precision GEMM: make use of FP8 Tensor Cores for GEMM
배경 지식
- Memory hierarchy
- Global memory (GMEM): a.k.a HBM, 각 streaming multiprocessor (SM) 가 접근 가능한 off-chip DRAM
- shared memory (SMEM): 각 streaming multiprocessor 내부의 on-chip cache
- Register file (RMEM): 각 thread 내부에 존재
- Thread hierarchy
- warps: 32 threads
- warpgroups: 4 contiguous warps
- threadblocks: a.k.a cooperative thread arrays or CTAs
- Threads in the same CTA are co-scheduled on the same SM
- CTAs in the same cluster are co-scheduled on the same GPC
알고리즘
- Producer-Consumer asynchrony: pingpong 으로 서로 다른 warpgroup 이 softmax 와 GEMMs (General Matrix Multiply) 교차
- producer: HBM 에서 shared memory 로 데이터 복사
- consumer: 복사된 데이터로 연산
- consumer 의 warpgroup 내에서 추가적인 병렬화
- 위의 알고리즘에서 local softmax 연산 (line18-19) 는 $S^{(j)}_i$ 가 끝나야함
- 해당 local softmax 결과는 line21 에서 사용됨
- 위의 2개 operation 을 쪼갬
- softmax 연산과 WGMMA (Warpgroup Matrix-Multiply-Accumulate) 을 교차
실험 결과