[LLM 시간 및 메모리 최적화] 1. Flash Attention

2025. 2. 24. 14:03ML/NLP

0. Background: Tiling

Tiling 기법은 attention 연산 시 각 행렬($\mathbf Q, \mathbf K, \mathbf V$)을 블록 단위로 처리함으로써, 메모리 및 연산 효율성을 높이는 기법이다. 

1. Standard Attention

해당 알고리즘에서는 $O(N^2)$ matrix인 $\mathbf S$와 $\mathbf P$가 각각 한번 씩 총 2번 write/read된다. 

아래 순서도와 같이 $O(N^2)$ I/O 연산 없이 $\mathbf O$를 도출할 수 없을까?

기존 알고리즘에서 $\mathbf S$로 $\mathbf P$를 계산하기 위해서는, 각 행(row)별 최댓값(row-wise max value)이 필요하다. 즉, 각 행의 모든 요소가 필요하다. 그렇기 때문에, $\mathbf S$의 write/read는 불가피하다. 

2. Flash Attention

Flash Attention은 Tiling Softmax 기법을 활용해 $\mathbf S, \mathbf P$를 메모리에 저장하지 않고, $\mathbf O$를 바로 계산(업데이트)한다.

 

기존 softmax 연산 과정은 다음과 같다.

$$m(x) := \max_i x_i, \quad f(x) := \left[ e^{x_1 - m(x)} \ \dots \ e^{x_B - m(x)} \right], \quad \ell(x) := \sum_i f(x)_i, \quad \text{softmax}(x) := \frac{f(x)}{\ell(x)}.$$

Tiling softmax 연산은 현재 블록의 계산 결과 $x_i$과 이전 결과값($x_{1:i-1}$)로 계산해 놓은 것들을 이용해, $\mathbf O$를 업데이트 하는 방식으로 진행된다. 이때, 필요한 정보는 $m, \ell$이며, 크기는 각각 $O(N)$으로 $\mathbf S$보다 훨씬 작다.

$$m(x_{1:i}) = m \left( \begin{bmatrix} x_{1:i-1} & x_i \end{bmatrix} \right) = \max(m(x_{1:i-1}), m(x_i)) \\ \quad \\ f(x_{1:i}) = \begin{bmatrix} e^{m(x_{1:i-1}) - m(x_{1:i})} f(x_{1:i-1}) & e^{m(x_i) - m(x_{1:i})} f(x_i) \end{bmatrix}, \quad \left(f(x_i) = e^{x_i - m(x_i)}\right) \\ \quad \\ \ell(x_{1:i}) = \ell \left( \begin{bmatrix} x_{1:i-1} & x_i \end{bmatrix} \right) = e^{m(x_{1:i-1}) - m(x_{1:i})} \ell(x_{1:i-1}) + e^{m(x_i) - m(x_{1:i})} \ell(x_i), \quad \left(l(x_i) = e^{x_i - m(x_i)}\right) \\ \quad \\ \text{softmax}(x_{1:i}) = \frac{f(x_{1:i})}{\ell(x_{1:i})}. \\ \quad \\ \mathbf{O}_{1:i} = \ell(x_{1:i})^{-1} \left( \ell(x_{1:i-1}) e^{m(x_{1:i-1}) - m(x_{1:i})} \mathbf{O}_{1:i-1} + e^{m(x_i) - m(x_{1:i})} f(x_i) \mathbf V_i \right)$$

3. Flash Attention 2

3.1. Computational Efficiency

Flash Attention 1에서는 $\mathbf K, \mathbf V$가 외부 루프(outer loop)에 위치하고, $\mathbf Q$가 내부 루프(inner loop)에서 처리된다.

그렇기 때문에, 내부 루프가 한 번 실행될 때마다 $\mathbf O$의 모든 행(row)이 계산(업데이트)되므로, 내부 루프에서 역정규화($\text{diag}(\ell_i)$)와 정규화($\text{diag}(\ell_i^\text{new})^{-1}$)를 반복적으로 수행해야 한다.

 

하지만, Flash Attention 2에서는 $\mathbf Q$가 외부 루프(outer loop)에 위치하고, $\mathbf K, \mathbf V$가 내부 루프(inner loop)에서 처리된다.

그러면, $i$번째 내부 루프는 오직 $\mathbf O$의 $i$ 행(row)만 계산할 뿐만 아니라 계산이 즉시 완료된다. 이러면 굳이 역정규화와 정규화를 반복적으로 수행할 필요가 없다. 내부 루프가 끝날 때마다 해당 행의 정규화만 해주면 되므로, 연산 효율성을 높아졌다.

 

3.2. Memory Efficiency

$i$번째 내부 루프에는 오직 $\mathbf O$의 $i$ 행의 계산만 수행 및 완료하기 때문에, $m(x_i), \ell(x_i)$를 모두 저장하지 않고, 역전파를 위한 로그합지수(log sum exp)만 저장해 메모리 효율성을 개선했다.

$$L(x_i) = m(x_i) +\log(\ell(x_i))$$

 

 

'ML > NLP' 카테고리의 다른 글

[LLM] DeepSeek-V3  (0) 2025.02.28
[LLM 시간 및 메모리 최적화] 2. KV Cache & Paged Attention  (0) 2025.02.26
[NLP] Transformer  (2) 2024.08.19
[NLP] RNN, LSTM, Attention  (0) 2024.03.31