Kernel Fusion and Flash Attention #
Memory Hierarchy #
Modern accelerators (GPUs, TPUs, AWS Trainium/Inferentia) employ a multi-level memory hierarchy, with memories organized from fastest/smallest to slowest/largest:
- Registers / On-chip SRAM (e.g., GPU shared memory, NeuronCore SBUF): Highest bandwidth (~20x higher than HBM), lowest latency, but limited capacity (tens of MB per compute unit)
- Device Memory / HBM: Large capacity (tens of GB), high bandwidth compared to CPU memory, but still orders of magnitude slower than on-chip SRAM
- Host Memory (CPU DRAM): Largest capacity, lowest bandwidth from accelerator’s perspective
The memory wall refers to the growing gap between compute throughput and memory bandwidth. While compute capabilities have scaled dramatically (e.g., a single AWS Trn2 NeuronCore delivers ~335 TFLOPS of BF16 compute), memory bandwidth has not kept pace.
Example: NeuronCore (AWS Trainium) Programming Model #
AWS Trainium’s NeuronCore architecture provides a concrete example of software-managed memory hierarchy. Unlike GPUs with hardware-managed caches, NKI (Neuron Kernel Interface) requires programmers to explicitly control data movement between HBM and on-device memory buffers.
NeuronCore memory hierarchy showing capacity and bandwidth at each level
NeuronCore exposes a 4-level memory hierarchy:
| Memory Level | Capacity | Bandwidth |
|---|---|---|
| PSUM (on-chip) | ~2 MB | ~10 TB/sec |
| SBUF (on-chip) | ~25 MB | ~10 TB/sec |
| HBM (device) | ~50 GB | ~0.5 TB/sec per NC |
| Host DRAM | ~1 TB | ~16 GB/sec |
The key observation: on-chip SBUF provides 20x higher bandwidth than HBM (10 TB/s vs 0.5 TB/s). With no hardware cache to automatically manage data movement, the programmer must explicitly manage data transfers between memory tiers (shown as “Refill” and “Spill” in the diagram). This explicit memory model makes kernel fusion essential: fusing multiple operations keeps intermediate data in fast SBUF instead of incurring expensive HBM round-trips between each operation.
Arithmetic Intensity #
Due to the memory wall, many deep learning operations are memory-bound rather than compute-bound — the accelerator spends more time waiting for data than performing arithmetic. Arithmetic intensity measures how much computation is performed per byte of memory transferred:
$$\text{Arithmetic Intensity} = \frac{\text{FLOPs}}{\text{Bytes Transferred}}$$
The units are FLOPs/byte. This metric determines whether an operation is compute-bound or memory-bound via the roofline model. Every accelerator has a ridge point defined by:
$$\text{Ridge Point} = \frac{\text{Peak Compute (FLOPs/s)}}{\text{Peak Bandwidth (Bytes/s)}}$$
For example, an AWS Trn2 NeuronCore has ~335 TFLOPS (BF16) and ~1.45 TB/s HBM bandwidth, giving a ridge point of ~230 FLOPs/byte. Operations with arithmetic intensity below this threshold are memory-bound; those above are compute-bound. Most LLM inference operations—especially during the decode phase where batch sizes are small—have low arithmetic intensity and are memory-bound. This is why kernel fusion is critical: by keeping intermediate data in fast on-chip memory, we reduce memory traffic and shift operations closer to being compute-bound.
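As a quick sanity check, the ridge point and the memory-bound test can be computed directly. The sketch below uses the approximate Trn2 figures quoted above; treat them as illustrative, not official specifications.

```python
# Roofline ridge point from the approximate Trn2 figures above (illustrative values).
peak_compute = 335e12     # FLOP/s (BF16)
peak_bandwidth = 1.45e12  # bytes/s (HBM)

ridge_point = peak_compute / peak_bandwidth  # ~230 FLOPs/byte

def is_memory_bound(arithmetic_intensity: float) -> bool:
    """Ops below the ridge point cannot keep the compute units busy."""
    return arithmetic_intensity < ridge_point

print(f"ridge point ~= {ridge_point:.0f} FLOPs/byte")
print(is_memory_bound(62))  # True: standard attention at N=4096, d=128 (derived below)
```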
What is Kernel Fusion? #
Kernel fusion is a compiler/runtime optimization that combines multiple operations into a single kernel to minimize memory traffic between the accelerator’s on-chip memory and device memory (HBM).
Without fusion, each operation in a computation graph:
- Reads input tensors from HBM to on-chip memory
- Performs computation
- Writes output tensors back to HBM
With fusion, intermediate results are kept in fast on-chip memory (SRAM/registers), eliminating redundant HBM round-trips. This is particularly impactful because:
- On-chip SRAM provides ~20x higher bandwidth than HBM
- Reducing HBM accesses directly improves performance for memory-bound operations
- Fused kernels can leverage tiling to process large tensors in cache-friendly blocks
Standard attention without fusion: each operation (S = Q × Kᵀ, P = softmax(S), O = P × V) requires loading inputs from HBM and storing outputs back to HBM, resulting in multiple expensive memory round-trips.
Arithmetic intensity of standard attention. For sequence length $N$, head dimension $d$, and FP16 precision (2 bytes per element), we can compute the memory traffic for each step:
- $S = QK^T$: Read $Q$ ($2Nd$ bytes), read $K$ ($2Nd$ bytes), write $S$ ($2N^2$ bytes)
- $P = \mathrm{softmax}(S)$: Read $S$ ($2N^2$ bytes), write $P$ ($2N^2$ bytes)
- $O = PV$: Read $P$ ($2N^2$ bytes), read $V$ ($2Nd$ bytes), write $O$ ($2Nd$ bytes)
Total memory traffic: $8Nd + 8N^2$ bytes. The two matrix multiplications contribute $4N^2d$ FLOPs, while softmax adds only $\sim 5N^2$ FLOPs, which is negligible next to the matmuls and can be dropped. This gives:
$$\text{Arithmetic Intensity} = \frac{4N^2d}{8Nd + 8N^2} = \frac{Nd}{2(d + N)}$$
For $N = 4096$ and $d = 128$: arithmetic intensity $\approx 62$ FLOPs/byte — well below Trn2’s ridge point of ~230 FLOPs/byte. Standard attention is memory-bound, making it an ideal candidate for kernel fusion.
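The same accounting in a few lines of Python (a sketch for a single head; batch and multi-head dimensions are ignored):

```python
# Arithmetic intensity of standard (unfused) attention in FP16, per the counts above.
def standard_attention_intensity(N: int, d: int, bytes_per_elem: int = 2) -> float:
    flops = 4 * N**2 * d                                # QK^T and PV matmuls
    traffic = bytes_per_elem * (4 * N * d + 4 * N**2)   # 8Nd + 8N^2 bytes in FP16
    return flops / traffic

print(standard_attention_intensity(4096, 128))  # ~62 FLOPs/byte, well below ~230
```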
Fusion Techniques #
Operator Fusion #
On NeuronCore, each operation (matmul, activation, normalization) is typically implemented as a separate NKI kernel. Each kernel reads inputs from HBM into SBUF, computes, and writes outputs back to HBM. For a sequence of $k$ operations on a tensor of size $n$ bytes, this means $2kn$ bytes of HBM traffic — each intermediate result is spilled to HBM and then refilled.
Operator fusion combines multiple operations into a single kernel. The fused kernel loads inputs once, performs all computations while keeping intermediates in SBUF, and writes only the final output. This reduces HBM traffic to $2n$ bytes — a $k\times$ reduction.
Example: Consider computing $y = \mathrm{GELU}(Wx + b)$. Without fusion:
- Kernel 1: Refill $W, x$ → compute $z = Wx$ → spill $z$ to HBM
- Kernel 2: Refill $z, b$ → compute $z' = z + b$ → spill $z'$ to HBM
- Kernel 3: Refill $z'$ → compute $y = \mathrm{GELU}(z')$ → spill $y$ to HBM
With fusion, a single kernel computes $y = \mathrm{GELU}(Wx + b)$ directly: refill $W, x, b$ once, compute everything with $z, z'$ in SBUF, spill $y$ once. Memory traffic drops from $6n$ to $2n$ bytes.
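A minimal NumPy sketch of the two schedules. The kernel boundaries are only comments here; on real hardware each unfused step would be a separate kernel with its own spill and refill, while the fused version keeps $z$ and $z + b$ on chip. The `gelu` helper uses the tanh approximation and is purely illustrative.

```python
import numpy as np

def gelu(z):
    # tanh approximation of GELU (illustrative)
    return 0.5 * z * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (z + 0.044715 * z**3)))

def unfused(W, x, b):
    z = W @ x         # kernel 1: refill W, x -> spill z
    z2 = z + b        # kernel 2: refill z, b -> spill z2
    return gelu(z2)   # kernel 3: refill z2   -> spill y

def fused(W, x, b):
    # single kernel: refill W, x, b once; intermediates never leave "SBUF"; spill y once
    return gelu(W @ x + b)

W, x, b = np.random.randn(512, 256), np.random.randn(256), np.random.randn(512)
assert np.allclose(unfused(W, x, b), fused(W, x, b))
```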
The key constraint: all intermediate data must fit in SBUF. This works well when intermediates are the same size as inputs (element-wise ops, reductions), but fails when operations expand data size.
Tiling #
When intermediate data exceeds SBUF capacity, tiling partitions tensors into blocks that fit in SBUF, applying fusion within each tile.
Example: For $C = g(f(A))$ where $B = f(A)$ is too large for SBUF:
- Unfused: Compute all of $B$, spill to HBM, then compute all of $C$ by refilling $B$
- Tiled: For each tile $i$: refill $A_i$ → compute $B_i$ in SBUF → compute $C_i$ → spill $C_i$
The intermediate $B_i$ never touches HBM. This achieves the same $k\times$ memory traffic reduction as operator fusion, but requires that $C_i$ depends only on $A_i$ (the computation must be tileable).
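A sketch of the tiled schedule in NumPy, assuming the computation is tileable row-wise. `f` and `g` are placeholder element-wise ops and `tile_rows` stands in for whatever block size fits the SBUF budget; all names are illustrative.

```python
import numpy as np

def tiled_apply(A: np.ndarray, f, g, tile_rows: int) -> np.ndarray:
    """Compute C = g(f(A)) tile by tile so the intermediate B never touches HBM."""
    C = np.empty_like(A)
    for start in range(0, A.shape[0], tile_rows):
        A_i = A[start:start + tile_rows]      # refill A_i
        B_i = f(A_i)                          # intermediate stays in fast memory
        C[start:start + tile_rows] = g(B_i)   # spill only C_i
    return C

A = np.random.randn(1024, 256).astype(np.float32)
assert np.allclose(tiled_apply(A, np.exp, np.sqrt, tile_rows=128), np.sqrt(np.exp(A)), rtol=1e-5)
```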
Flash Attention and Mathematical Fusion #
Why Standard Fusion Fails for Attention #
Standard self-attention computes: $$S = QK^T, \quad P = \mathrm{softmax}(S), \quad O = PV$$
Without fusion, each step requires HBM round-trips:
- Load $Q, K$ from HBM → compute $S$ → store $S$ to HBM
- Load $S$ from HBM → compute $P$ → store $P$ to HBM
- Load $P, V$ from HBM → compute $O$ → store $O$ to HBM
The attention matrix $S$ has size $N \times N$. For sequence length $N = 8192$, this is 67M elements per head (~134 MB in FP16) — far exceeding SBUF capacity (~25 MB). Can we tile to avoid materializing the full matrix?
The softmax barrier. Consider computing one row of output $o_i = \sum_j p_{ij} v_j$ where $p_{ij} = [\mathrm{softmax}(s_i)]_j$. If we tile over $K$ and $V$ (processing columns in blocks), we compute partial scores $s_{ij}$ for $j \in B_k$ in each block. But softmax couples all columns:
$$[\mathrm{softmax}(s_i)]_j = \frac{e^{s_{ij}}}{\sum_{k=1}^{N} e^{s_{ik}}}$$
The denominator requires all $N$ scores in the row. We cannot compute any $p_{ij}$ until we have seen every $s_{ik}$.
Numerical stability makes it worse. Naive softmax overflows when scores are large. The numerically stable version subtracts the row maximum:
$$m_i = \max_j s_{ij}, \quad \ell_i = \sum_{j=1}^{N} e^{s_{ij} - m_i}, \quad p_{ij} = \frac{e^{s_{ij} - m_i}}{\ell_i}$$
This requires two passes: first to find $m_i$, then to compute exponentials and sum. Both $m_i$ and $\ell_i$ depend on all $N$ elements, so we must materialize the entire row of $S$ before computing any element of $P$.
The fusion blocker. This creates a dependency chain that prevents tiling:
- Cannot compute $p_{ij}$ without knowing $m_i$ (need full row of $S$)
- Cannot compute $m_i$ without all $s_{ij}$ (need to finish $S = QK^T$ first)
Result: we must write the full $N \times N$ matrix $S$ to HBM, then read it back for softmax. Standard fusion and tiling cannot break this logical dependency.
Online Softmax: The Mathematical Transform #
The key innovation is online softmax — a mathematically equivalent reformulation that computes numerically stable softmax incrementally in a single pass, removing the fusion blocker.
Recall the numerically stable softmax requires two passes: $$m = \max_j s_j, \quad \ell = \sum_j e^{s_j - m}, \quad [\mathrm{softmax}(s)]_j = \frac{e^{s_j - m}}{\ell}$$
Online softmax processes elements in blocks, maintaining running statistics $(m, \ell)$ that can be updated incrementally while preserving numerical stability.
Derivation: Suppose we have processed blocks $1, \ldots, i$ with running max $m^{(i)}$ and running sum $\ell^{(i)} = \sum_{j=1}^{i} \sum_{k \in B_j} e^{s_k - m^{(i)}}$. When we encounter block $i+1$ with elements $\lbrace s_k \rbrace_{k \in B_{i+1}}$:
- Update max: $$m^{(i+1)} = \max\left(m^{(i)}, \max_{k \in B_{i+1}} s_k\right)$$
- Update sum: The previous sum $\ell^{(i)}$ was computed with offset $m^{(i)}$, but we need all terms to use the new offset $m^{(i+1)}$ for numerical stability. Using $e^{s_k - m^{(i)}} = e^{s_k - m^{(i+1)}} \cdot e^{m^{(i+1)} - m^{(i)}}$, we rescale: $$\ell^{(i+1)} = e^{m^{(i)} - m^{(i+1)}} \cdot \ell^{(i)} + \sum_{k \in B_{i+1}} e^{s_k - m^{(i+1)}}$$
The rescaling factor $e^{m^{(i)} - m^{(i+1)}} \leq 1$ (since $m^{(i+1)} \geq m^{(i)}$) ensures all exponents remain non-positive, maintaining numerical stability throughout.
Correctness: After processing all blocks, $\ell^{(\mathrm{final})} = \sum_k e^{s_k - m^{(\mathrm{final})}}$, which is exactly the denominator needed for numerically stable softmax.
Extending to attention output: For $o = \sum_k [\mathrm{softmax}(s)]_k \cdot v_k$, we maintain a running (unnormalized) output: $$o^{(i)} = \sum_{j=1}^{i} \sum_{k \in B_j} e^{s_k - m^{(i)}} \cdot v_k$$
When processing block $i+1$, rescale the previous output to use the new max and add new contributions: $$o^{(i+1)} = e^{m^{(i)} - m^{(i+1)}} \cdot o^{(i)} + \sum_{k \in B_{i+1}} e^{s_k - m^{(i+1)}} \cdot v_k$$
After all blocks: $o^{(\mathrm{final})} / \ell^{(\mathrm{final})} = \sum_k [\mathrm{softmax}(s)]_k \cdot v_k$ — mathematically identical to standard attention, but computed in a single pass with guaranteed numerical stability.
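The recurrences above translate directly into code. Below is a minimal NumPy sketch for a single attention row, processing the scores $s$ and values $v$ in blocks and carrying the running $(m, \ell, o)$ statistics; the block size and variable names are illustrative.

```python
import numpy as np

def online_attention_row(s: np.ndarray, v: np.ndarray, block: int = 128) -> np.ndarray:
    """One-pass softmax-weighted sum: o = softmax(s) @ v, computed block by block."""
    m = -np.inf                   # running max m^(i)
    l = 0.0                       # running sum  l^(i)
    o = np.zeros(v.shape[1])      # running unnormalized output o^(i)
    for start in range(0, len(s), block):
        s_blk, v_blk = s[start:start + block], v[start:start + block]
        m_new = max(m, s_blk.max())
        p_blk = np.exp(s_blk - m_new)        # exponentials against the new max
        scale = np.exp(m - m_new)            # <= 1: rescale old terms to the new max
        l = scale * l + p_blk.sum()
        o = scale * o + p_blk @ v_blk
        m = m_new
    return o / l                  # normalize once at the end

s, v = np.random.randn(1000) * 10, np.random.randn(1000, 64)
ref = (np.exp(s - s.max()) / np.exp(s - s.max()).sum()) @ v
assert np.allclose(online_attention_row(s, v), ref)
```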
Flash Attention: Fused Attention Kernel #
Flash Attention combines online softmax with tiling to fuse all attention operations into a single kernel:
Algorithm: Flash Attention
Input: $Q, K, V \in \mathbb{R}^{N \times d}$, block sizes $B_r, B_c$ chosen so that $B_r \cdot d + B_c \cdot d + B_r \cdot B_c \leq M$, where $M$ is the on-chip (SBUF) capacity in elements
Output: $O = \mathrm{softmax}(QK^T)V$
- Divide $Q$ into row blocks $Q_1, \ldots, Q_{N/B_r}$ of size $B_r \times d$
- Divide $K, V$ into row blocks of size $B_c \times d$
- for $i = 1$ to $N/B_r$ do
- Load $Q_i$ from HBM to SBUF
- Initialize: $m_i \leftarrow -\infty$, $\ell_i \leftarrow 0$, $O_i \leftarrow 0$
- for $j = 1$ to $N/B_c$ do
- Load $K_j, V_j$ from HBM to SBUF
- Compute $S_{ij} \leftarrow Q_i K_j^T$
- Compute $\tilde{m}_{ij} \leftarrow \mathrm{rowmax}(S_{ij})$, $m_i^{\mathrm{new}} \leftarrow \max(m_i, \tilde{m}_{ij})$
- Compute $\tilde{P}_{ij} \leftarrow \exp(S_{ij} - m_i^{\mathrm{new}})$
- Update $\ell_i \leftarrow e^{m_i - m_i^{\mathrm{new}}} \ell_i + \mathrm{rowsum}(\tilde{P}_{ij})$
- Update $O_i \leftarrow e^{m_i - m_i^{\mathrm{new}}} O_i + \tilde{P}_{ij} V_j$
- Set $m_i \leftarrow m_i^{\mathrm{new}}$
- end for
- Write $O_i \leftarrow O_i / \ell_i$ to HBM
- end for
- return $O$
Flash Attention fuses all operations into a single kernel: Q, K, V are refilled once from HBM, intermediate results S and P stay in fast on-chip SBUF, and only the final output O is spilled back to HBM.
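For concreteness, here is a NumPy transcription of the algorithm above (single head, no $1/\sqrt{d}$ scaling or masking, and no real memory management; it only mirrors the loop structure and the online-softmax updates). Block sizes and names are illustrative.

```python
import numpy as np

def flash_attention(Q, K, V, B_r=64, B_c=64):
    """NumPy mirror of the tiled Flash Attention loop (single head, unscaled, unmasked)."""
    N, d = Q.shape
    O = np.zeros((N, d))
    for i in range(0, N, B_r):
        Qi = Q[i:i + B_r]                          # load Q_i once per outer iteration
        m = np.full(Qi.shape[0], -np.inf)          # running row maxima
        l = np.zeros(Qi.shape[0])                  # running row sums
        Oi = np.zeros((Qi.shape[0], d))            # running unnormalized output
        for j in range(0, N, B_c):
            Kj, Vj = K[j:j + B_c], V[j:j + B_c]    # load K_j, V_j
            Sij = Qi @ Kj.T                        # B_r x B_c score tile, never spilled
            m_new = np.maximum(m, Sij.max(axis=1))
            Pij = np.exp(Sij - m_new[:, None])
            scale = np.exp(m - m_new)
            l = scale * l + Pij.sum(axis=1)
            Oi = scale[:, None] * Oi + Pij @ Vj
            m = m_new
        O[i:i + B_r] = Oi / l[:, None]             # write normalized O_i once
    return O

Q, K, V = (np.random.randn(256, 64) for _ in range(3))
S = Q @ K.T
P = np.exp(S - S.max(axis=1, keepdims=True))
assert np.allclose(flash_attention(Q, K, V), (P / P.sum(axis=1, keepdims=True)) @ V)
```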
Performance Analysis #
IO Complexity of Standard Attention: Reading $Q, K, V$ costs $O(Nd)$. Materializing and reading back the $N \times N$ matrices $S$ and $P$ costs $O(N^2)$. Total: $O(Nd + N^2)$.
IO Complexity of Flash Attention: We count HBM accesses for each tensor:
- $Q$: Each block $Q_i$ ($B_r \times d$ elements) is loaded once per outer loop iteration → $(N/B_r) \cdot B_r d = Nd$ total
- $K, V$: Each block $K_j, V_j$ ($B_c \times d$ elements) is loaded once per inner loop, repeated for all outer iterations → $(N/B_r) \cdot (N/B_c) \cdot B_c d = N^2 d / B_r$ each
- $O$: Each block $O_i$ is written once → $Nd$ total
Total HBM accesses: $Nd + 2 \cdot \frac{N^2 d}{B_r} + Nd = O\left(Nd + \frac{N^2 d}{B_r}\right)$
From the SRAM constraint $B_r \cdot d + B_c \cdot d + B_r \cdot B_c \leq M$, we have $B_r = O(M/d)$. Substituting:
$$\text{Flash Attention IO} = O\left(Nd + \frac{N^2 d}{M/d}\right) = O\left(Nd + \frac{N^2 d^2}{M}\right)$$
For long sequences where $N^2 d^2 / M \gg Nd$, this simplifies to $O(N^2 d^2 / M)$ — a factor of $M/d^2$ improvement over standard attention’s $O(N^2)$ term.
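A quick numeric check of these bounds, counting HBM accesses in elements. The ~25 MB SBUF figure is taken from the table earlier and converted to FP16 elements; the numbers are illustrative.

```python
# HBM accesses (in elements) for standard vs. Flash Attention, per the bounds above.
N, d = 8192, 128
M = 25 * 2**20 // 2                      # ~25 MB of SBUF at 2 bytes per FP16 element

standard_io = N * d + N**2               # O(Nd + N^2)
flash_io = N * d + N**2 * d**2 // M      # O(Nd + N^2 d^2 / M)

print(M // d**2)                         # ~800: reduction factor on the N^2 term
print(standard_io // flash_io)           # ~60x fewer total HBM accesses at these sizes
```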
Memory: Standard attention requires $O(N^2)$ memory to store $S$ and $P$. Flash Attention avoids materializing these matrices, requiring only $O(Nd)$ for inputs/outputs — enabling much longer sequences.
References #
- Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022. https://arxiv.org/abs/2205.14135
- AWS Neuron Documentation. Neuron Kernel Interface (NKI) Programming Guide. https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/nki/index.html