FlashAttention, Intuitively
Exact Attention Without the Giant Matrix
Most explanations of FlashAttention start too late.
They begin with GPU memory hierarchies, tiled matrix multiplication, SRAM, HBM, CUDA kernels, and benchmark charts. All of that matters, but it hides the central idea behind a wall of implementation detail.
The best way to understand FlashAttention is to start with a simpler question:
What if the attention matrix is not the thing we want, but only an expensive temporary object we accidentally learned to store?
That question is the thread that unravels the entire design.
FlashAttention does not invent a new attention mechanism. It does not approximate softmax attention. It does not make attention linear. For the same inputs, it computes the same result as ordinary scaled dot-product attention, up to floating-point rounding.
The difference is physical, not mathematical.
Ordinary attention makes the huge N x N attention matrix real in memory. FlashAttention keeps that matrix as a logical object, touches it block by block, and refuses to store it as a giant tensor.
The attention matrix still exists in the math.
It just stops existing as a huge intermediate allocation.
That is the core intuition.
If there is one sentence to keep in your head, it is this:
FlashAttention does not reduce the number of attention relationships. It reduces how long the intermediate representation of those relationships has to live in memory.
The Misleading Phrase: “Faster Attention”
“Faster attention” sounds like FlashAttention changed the formula.
That is the wrong mental model.
Standard attention is usually written as (leaving out the 1/sqrt(d) scaling factor for brevity):
S = QK^T
P = softmax(S)
O = PV
FlashAttention still computes this same O.
So the real question is sharper:
If the math is the same, where did the speed come from?
The answer is not hidden in the algebra. It is hidden in the lifetime of intermediate data.
Ordinary attention creates two large temporary matrices:
S: attention scores N x N
P: attention probabilities N x N
For short sequences, this is fine. For long sequences, these matrices dominate memory traffic.
And on modern GPUs, memory traffic is often the real bottleneck.
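A rough back-of-the-envelope calculation makes "large" concrete. The sketch below assumes 2-byte elements (fp16 or bf16) and a head dimension of 64; both numbers are illustrative assumptions, not figures from this article, and the helper names are made up for the example.

```python
def temp_matrix_bytes(n_tokens, bytes_per_elem=2):
    """Bytes needed to hold ONE N x N temporary (S or P) for a single head."""
    return n_tokens * n_tokens * bytes_per_elem

def qkv_bytes(n_tokens, head_dim=64, bytes_per_elem=2):
    """Bytes needed for Q, K, and V together for a single head."""
    return 3 * n_tokens * head_dim * bytes_per_elem

for n in (1024, 8192, 65536):
    temp_mib = temp_matrix_bytes(n) / 2**20
    inputs_mib = qkv_bytes(n) / 2**20
    print(f"N={n}: one N x N temporary = {temp_mib:.0f} MiB, Q+K+V = {inputs_mib:.3f} MiB")
```

The inputs grow linearly with N. Each temporary grows with N squared, and ordinary attention creates two of them per head.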
Ordinary Attention Builds a Hidden Table
To see the problem, stop reading QK^T as a compact formula.
Read it as a table-building operation.
Suppose a sequence has N tokens. Each token has a query vector and a key vector. Attention asks:
For each token, how much should it attend to every other token?
For query token i and key token j, the score is:
score(i, j) = Q[i] dot K[j]
That gives one score for every pair of tokens:
         key_1  key_2  key_3  ...  key_N
query_1    .      .      .           .
query_2    .      .      .           .
query_3    .      .      .           .
...
query_N    .      .      .           .
This is the N x N score matrix.
Then softmax turns each row of scores into probabilities:
         key_1  key_2  key_3  ...  key_N
query_1    p      p      p           p
query_2    p      p      p           p
query_3    p      p      p           p
...
query_N    p      p      p           p
This is the N x N probability matrix.
Only after that do we multiply by V to get the final output:
O = P V
Here is the first major point of departure:
Attention logically requires N x N relationships, but ordinary implementations physically store N x N intermediate results.
Those are not the same thing.
The N x N table is a mathematical dependency pattern. It does not have to be a physical storage plan.
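To see the storage plan in code, here is a minimal reference sketch of ordinary attention in plain Python. It is not how any library implements it. The function name is hypothetical, the 1/sqrt(d) scaling is left out to match the formulas above, and the row maximum is subtracted inside softmax purely for numerical safety.

```python
import math

def naive_attention(Q, K, V):
    """Ordinary attention that physically stores both N x N tables.

    Q, K, V are lists of N rows; each row is a list of floats.
    """
    n = len(Q)
    # S[i][j] = Q[i] dot K[j]: the full N x N score table.
    S = [[sum(q * k for q, k in zip(Q[i], K[j])) for j in range(n)]
         for i in range(n)]
    # P: row-wise softmax of S, a second full N x N table.
    P = []
    for row in S:
        row_max = max(row)
        exps = [math.exp(s - row_max) for s in row]
        total = sum(exps)
        P.append([e / total for e in exps])
    # O[i] = sum over j of P[i][j] * V[j]
    d_v = len(V[0])
    O = [[sum(P[i][j] * V[j][c] for j in range(n)) for c in range(d_v)]
         for i in range(n)]
    return S, P, O

Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
S, P, O = naive_attention(Q, K, V)
stored_temporaries = len(S) * len(S[0]) + len(P) * len(P[0])
print("temporary entries stored:", stored_temporaries)
print("output:", O)
```

Only O is the answer. S and P are the two tables that exist just to get there.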
The Real Bottleneck Is Memory Traffic
It is tempting to say ordinary attention is slow because QK^T has a lot of multiplications.
That is partly true, but not the most useful explanation.
Modern GPUs are extremely good at matrix multiplication. Their compute units are fast, especially when the work maps cleanly to tensor cores. What they are much worse at is repeatedly moving huge tensors between slow global memory and the compute units.
A simplified ordinary attention pipeline looks like this:
1. Read Q and K
2. Compute S = QK^T
3. Write S to HBM
4. Read S from HBM
5. Compute P = softmax(S)
6. Write P to HBM
7. Read P and V from HBM
8. Compute O = PV
9. Write O to HBM
HBM is high-bandwidth GPU memory, but it is still slow compared with on-chip memory such as shared memory (SRAM) and registers.
The painful part is not merely that S and P are large. It is that they are large and temporary.
They are not the final answer. They are scaffolding.
Ordinary attention builds the scaffolding in global memory, then keeps walking back to it.
FlashAttention asks:
Can we compute the final output without ever writing the scaffolding down?
The Key Reframe: We Need the Output, Not the Matrix
Look at one row of attention.
For a single query token, attention computes scores against all keys:
scores = [s1, s2, s3, ..., sN]
Then it computes:
output = softmax(scores) @ V
If we expand the softmax, ignoring numerical stability for a moment, this becomes:
output =
(exp(s1) * V1 + exp(s2) * V2 + ... + exp(sN) * VN)
/
(exp(s1) + exp(s2) + ... + exp(sN))
This expression reveals something important.
To get the final output, we do not necessarily need to store every probability:
p1, p2, p3, ..., pN
We need two accumulated quantities:
weighted_sum = exp(s1) * V1 + exp(s2) * V2 + ... + exp(sN) * VN
normalizer = exp(s1) + exp(s2) + ... + exp(sN)
Then:
output = weighted_sum / normalizer
This is the second major realization:
Attention probabilities are useful mathematically, but they are a liability physically. We do not need to store them.
If we can accumulate the numerator and denominator correctly, the probability vector can remain implicit.
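A small sketch can confirm this for a single row. It deliberately ignores numerical stability, exactly as the expansion above does. The function names and the sample scores and values are invented for illustration.

```python
import math

def row_attention_via_probabilities(scores, values):
    """Materialise every probability, then take the weighted sum."""
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    probs = [e / total for e in exps]  # the stored probability vector
    d = len(values[0])
    return [sum(p * v[c] for p, v in zip(probs, values)) for c in range(d)]

def row_attention_via_accumulators(scores, values):
    """Keep only a running numerator and denominator."""
    d = len(values[0])
    weighted_sum = [0.0] * d
    normalizer = 0.0
    for s, v in zip(scores, values):
        w = math.exp(s)  # unnormalised weight, no stability fix yet
        normalizer += w
        for c in range(d):
            weighted_sum[c] += w * v[c]
    return [x / normalizer for x in weighted_sum]

scores = [0.5, -1.0, 2.0]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
with_probs = row_attention_via_probabilities(scores, values)
with_accumulators = row_attention_via_accumulators(scores, values)
print(with_probs)
print(with_accumulators)
```

The second function never holds more than one weight at a time, yet both produce the same output up to rounding.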
That sounds simple. But softmax has one serious obstacle.
Why Softmax Seems to Break Blocking
In real implementations, we do not evaluate the exponentials directly as:
exp(score)
Large scores can overflow. For numerical stability, softmax subtracts the maximum score in the row:
softmax(score_i) =
exp(score_i - max_score)
/
sum_j exp(score_j - max_score)
This does not change the result because softmax only cares about relative differences. Subtracting the same value from every score preserves those differences.
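Both claims are easy to check in a few lines. This is a minimal sketch in plain Python, where math.exp raises OverflowError for large arguments; the function names and test scores are made up for the example.

```python
import math

def softmax_naive(scores):
    exps = [math.exp(s) for s in scores]  # overflows for large scores
    total = sum(exps)
    return [e / total for e in exps]

def softmax_stable(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]  # every exponent is <= 0
    total = sum(exps)
    return [e / total for e in exps]

small = [1.0, 2.0, 3.0]
large = [1000.0, 1001.0, 1002.0]  # same relative differences as `small`

try:
    softmax_naive(large)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

print("naive version overflowed:", naive_overflowed)
print(softmax_stable(small))
print(softmax_stable(large))
```

The two stable calls agree because only the differences between scores matter. The max subtraction is what makes the second call safe.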
But it creates a problem for blocking.
If we process the row in chunks, we do not know the final row maximum at the beginning.
For example:
block 1: [1000, 1001]
block 2: [1002, 999]
After block 1, the maximum seen so far is 1001.
After block 2, the maximum becomes 1002.
So the old block was accumulated relative to the wrong reference point. We used 1001, but the final stable softmax should use 1002.
This is the exact point where naive blocking fails.
Softmax appears to require the whole row.
FlashAttention needs a way to scan the row block by block while still computing the same result as if it had seen the whole row at once.
That way is online softmax.
Online Softmax: Rescaling the Past
Online softmax works because of one simple property:
exp(a + b) = exp(a) * exp(b)
Suppose we accumulated some old scores using old_m as the maximum:
exp(score - old_m)
Later, we discover a larger maximum new_m.
Now the old scores should have been measured relative to new_m:
exp(score - new_m)
Rewrite it:
score - new_m = score - old_m + old_m - new_m
Therefore:
exp(score - new_m)
= exp(score - old_m) * exp(old_m - new_m)
The crucial part is that:
exp(old_m - new_m)
does not depend on score.
It is a shared scaling factor for all previously accumulated terms.
So when the maximum changes, we do not need to revisit every old score. We rescale the accumulated past.
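Here is the identity checked numerically on the two blocks from the earlier example. It is a minimal sketch; the variable names are chosen to mirror the text.

```python
import math

block_1 = [1000.0, 1001.0]
block_2 = [1002.0, 999.0]

# After block 1: accumulate relative to the max seen so far.
old_m = max(block_1)
old_sum = sum(math.exp(s - old_m) for s in block_1)

# Block 2 raises the max. Rescale the accumulated sum with one shared factor.
new_m = max(old_m, max(block_2))
scale = math.exp(old_m - new_m)
rescaled_sum = old_sum * scale

# What we would have got by revisiting every old score with the new max.
direct_sum = sum(math.exp(s - new_m) for s in block_1)

print(rescaled_sum, direct_sum)
```

One multiplication stands in for a second pass over block 1.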
For one row, FlashAttention keeps three running states:
m = running max
l = running normalizer
o = running weighted value sum
These are sufficient statistics for the row. They are not the probabilities themselves, but they contain enough information to produce the same final output.
Here:
l = sum exp(score_i - m)
o = sum exp(score_i - m) * V_i
When a new block arrives:
new_m = max(old_m, block_max)
old_scale = exp(old_m - new_m)
block_weights = exp(block_scores - new_m)
new_l = old_l * old_scale + sum(block_weights)
new_o = old_o * old_scale + sum(block_weights * block_values)
At the end:
output = o / l
This is the mathematical key that unlocks FlashAttention.
But it is important to say it precisely:
Online softmax is not “doing softmax separately on each block and stitching the blocks together.”
That would be wrong, because each block would have its own local normalization.
Online softmax is different:
It scans blocks one at a time while maintaining the same global softmax result that a full-row computation would have produced.
The past is not forgotten.
It is rescaled.
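The update rules can be written out as a runnable sketch for a single row. This assumes the running state starts at m = -infinity, l = 0, o = 0, which is a natural choice but is not spelled out above. The function names and sample values are invented for illustration. A third function shows the wrong "softmax each block, then stitch" approach for contrast.

```python
import math

def full_row_attention(scores, values):
    """Stable softmax-weighted sum, seeing the whole row at once."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    l = sum(weights)
    d = len(values[0])
    return [sum(w * v[c] for w, v in zip(weights, values)) / l for c in range(d)]

def online_row_attention(score_blocks, value_blocks):
    """Scan the row block by block, keeping only m, l, and o."""
    d = len(value_blocks[0][0])
    m, l, o = float("-inf"), 0.0, [0.0] * d  # assumed initial state
    for block_scores, block_values in zip(score_blocks, value_blocks):
        new_m = max(m, max(block_scores))
        old_scale = math.exp(m - new_m)  # exp(-inf) is 0.0 on the first block
        block_weights = [math.exp(s - new_m) for s in block_scores]
        l = l * old_scale + sum(block_weights)
        o = [o[c] * old_scale
             + sum(w * v[c] for w, v in zip(block_weights, block_values))
             for c in range(d)]
        m = new_m
    return [x / l for x in o]

def stitched_local_softmax(score_blocks, value_blocks):
    """The WRONG approach: normalise each block locally, then average."""
    outs = [full_row_attention(s, v) for s, v in zip(score_blocks, value_blocks)]
    return [sum(out[c] for out in outs) / len(outs) for c in range(len(outs[0]))]

score_blocks = [[1000.0, 1001.0], [1002.0, 999.0]]
value_blocks = [[[1.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [2.0, 0.0]]]
flat_scores = [s for blk in score_blocks for s in blk]
flat_values = [v for blk in value_blocks for v in blk]

exact = full_row_attention(flat_scores, flat_values)
online = online_row_attention(score_blocks, value_blocks)
wrong = stitched_local_softmax(score_blocks, value_blocks)
print("full row:", exact)
print("online:  ", online)
print("stitched:", wrong)
```

The online result matches the full-row result up to rounding. The stitched version does not, because each block was normalized against its own local total.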
From One Row to Blocks of Rows
So far we have described one query token.
Real FlashAttention processes blocks of query tokens.
Instead of taking one query row:
q_i
it takes a block:
Q_block
Then it scans blocks of K and V:
K_block, V_block
For each pair of blocks, it computes:
S_block = Q_block @ K_block^T
This is a small score matrix, not the full N x N matrix.
Each row in S_block has its own running state:
row 1: m1, l1, o1
row 2: m2, l2, o2
row 3: m3, l3, o3
...
The algorithmic shape is:
for each Q_block:
    initialize m, l, o for rows in Q_block
    for each K_block, V_block:
        compute S_block = Q_block @ K_block^T
        update m, l, o using online softmax
        discard S_block
    write final O_block
The phrase “discard S_block” is doing a lot of work.
It means the temporary scores are used while they are hot, then thrown away before they ever become a giant global tensor.
FlashAttention turns the attention matrix from an object into a stream.
This points to the underlying core principle:
The algorithm wins by collapsing the lifetime of the largest intermediate data, not by pretending the math is simpler than it is.
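The loop structure above can be sketched in plain Python and checked against ordinary attention. This is only an illustration of the algorithmic shape, not a GPU kernel: nothing here controls where data physically lives. The function names, block sizes, and test inputs are all assumptions made for the example, and the 1/sqrt(d) scaling is left out to match the formulas in this article.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def reference_attention(Q, K, V):
    """Ordinary attention: one full row of scores at a time."""
    out = []
    for q in Q:
        scores = [dot(q, k) for k in K]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        l = sum(w)
        out.append([sum(wi * v[c] for wi, v in zip(w, V)) / l
                    for c in range(len(V[0]))])
    return out

def blocked_attention(Q, K, V, q_block=2, kv_block=2):
    """Same output, but only one small score tile exists at any moment."""
    d_v = len(V[0])
    O = []
    for qs in range(0, len(Q), q_block):
        Q_blk = Q[qs:qs + q_block]
        rows = len(Q_blk)
        m = [float("-inf")] * rows          # running max per row
        l = [0.0] * rows                    # running normalizer per row
        o = [[0.0] * d_v for _ in range(rows)]  # running weighted sum per row
        for ks in range(0, len(K), kv_block):
            K_blk, V_blk = K[ks:ks + kv_block], V[ks:ks + kv_block]
            S_blk = [[dot(q, k) for k in K_blk] for q in Q_blk]  # small tile
            for r in range(rows):
                new_m = max(m[r], max(S_blk[r]))
                scale = math.exp(m[r] - new_m)
                w = [math.exp(s - new_m) for s in S_blk[r]]
                l[r] = l[r] * scale + sum(w)
                o[r] = [o[r][c] * scale
                        + sum(wi * v[c] for wi, v in zip(w, V_blk))
                        for c in range(d_v)]
                m[r] = new_m
            # S_blk is dropped here; the next iteration overwrites it.
        O.extend([[x / l[r] for x in o[r]] for r in range(rows)])
    return O

# Five tokens with block size 2, so the last block is ragged.
Q = [[math.sin(i + j) for j in range(3)] for i in range(5)]
K = [[math.cos(i * j + 1.0) for j in range(3)] for i in range(5)]
V = [[i + 0.5 * j for j in range(2)] for i in range(5)]

ref = reference_attention(Q, K, V)
blk = blocked_attention(Q, K, V, q_block=2, kv_block=2)
max_diff = max(abs(a - b) for ra, rb in zip(ref, blk) for a, b in zip(ra, rb))
print("max abs difference:", max_diff)
```

The largest score object that ever exists in blocked_attention is q_block by kv_block, regardless of N.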
The IO-Aware Kernel
Now we can finally talk about the GPU story without drowning in it.
FlashAttention combines three ideas:
1. Tiling
2. Online softmax
3. Kernel fusion
Tiling means the computation is broken into blocks that fit in fast on-chip memory.
Online softmax means those blocks can be processed one at a time while still producing exact softmax attention.
Kernel fusion means the separate stages:
QK^T -> softmax -> PV
are fused into one IO-aware kernel, so the intermediate score and probability blocks do not need to be written to HBM.
This is the right hierarchy of ideas:
online softmax makes block-wise exact attention possible
tiling keeps each block in fast memory
fusion prevents temporary blocks from becoming global memory traffic
Ordinary attention is HBM-heavy:
HBM:
    read Q, K
    write S
    read S
    write P
    read P, V
    write O
FlashAttention is IO-aware:
HBM:
    read Q, K, V blocks
    write final O
on-chip memory:
    compute S_block
    update m, l, o
    accumulate output
    discard S_block
The speedup comes from reducing the amount of high-cost memory movement.
Not from changing what attention means.
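A crude element count shows the scale of the difference. The sketch below simply tallies the reads and writes in the two simplified pipelines described in this article. It is an idealized count, not a measurement. It assumes the loop order used earlier (outer loop over Q blocks, inner loop over K and V blocks), so K and V are re-read once per Q block. It also assumes the running m, l, and o stay on-chip. The sizes N = 8192, d = 64, and a Q block of 128 rows are hypothetical.

```python
def ordinary_hbm_elements(n, d):
    """Elements moved through HBM by the simplified ordinary pipeline."""
    read_q_k = 2 * n * d
    write_s = n * n
    read_s = n * n
    write_p = n * n
    read_p_v = n * n + n * d
    write_o = n * d
    return read_q_k + write_s + read_s + write_p + read_p_v + write_o

def blocked_hbm_elements(n, d, q_block):
    """Elements moved through HBM by the blocked pipeline (idealized)."""
    num_q_blocks = -(-n // q_block)          # ceiling division
    read_q = n * d                           # each Q block is read once
    read_k_v = num_q_blocks * 2 * n * d      # K and V re-read per Q block
    write_o = n * d                          # final output written once
    return read_q + read_k_v + write_o

n, d, q_block = 8192, 64, 128                # assumed sizes for illustration
ordinary = ordinary_hbm_elements(n, d)
blocked = blocked_hbm_elements(n, d, q_block)
print("ordinary:", ordinary)
print("blocked: ", blocked)
print("ratio:   ", ordinary / blocked)
```

Under these assumptions the blocked plan still re-reads K and V many times, yet it moves far fewer elements because the N x N terms are gone. Larger Q blocks shrink the count further, which is why the tile size is tied to how much on-chip memory is available.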
The Backward Pass: Recomputation Over Retrieval
But this forward-pass optimization raises an immediate crisis for training:
If FlashAttention refuses to save the N x N probability matrix, how does it compute gradients during backpropagation?
During the backward pass, standard backpropagation mathematically requires those exact N x N probabilities to compute the gradients for Q, K, and V. Usually, deep learning frameworks keep these huge intermediate tensors sitting in global memory (HBM) from the forward pass.
If we throw them away, how do we backpropagate?
The answer is one of the most elegant, counter-intuitive design choices in modern systems programming:
Recompute the attention blocks on-the-fly during the backward pass.
To a programmer trained in classic algorithmic optimization, this feels like sacrilege. Why spend compute cycles to re-evaluate the forward pass during the backward pass?
Because of the massive, growing imbalance in modern hardware: compute is incredibly cheap, while memory access is incredibly expensive.
We have entered a regime where:
Recomputing a block on-the-fly in fast SRAM is often faster than reading a pre-computed block from slow HBM.
Instead of saving the giant N x N matrix, FlashAttention saves only the output O and the extremely compact per-row normalization statistics (m and l). During the backward pass, it loads the original blocks of Q, K, and V, recomputes each local score block from Q and K, uses the saved statistics to turn those scores back into the exact probabilities, computes the gradients, and immediately discards the blocks again.
The backward pass follows the same spirit as the forward pass:
do not store the giant matrix
recreate small blocks on demand
use them immediately
discard them again
This is not a hack. It is a deliberate trade:
spend extra local computation
to avoid massive global memory traffic
That trade is often profitable because GPUs are built to do a lot of arithmetic, but moving giant tensors through memory remains expensive.
This is where FlashAttention becomes more than an attention optimization.
It becomes a lesson in modern performance engineering:
The fastest program is not always the one that does the fewest operations. It is often the one that moves the least data.
Backward recomputation is the strongest version of that lesson. It says that, on the right hardware, throwing something away and rebuilding it later can be cheaper than preserving it.
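The recomputation step can be illustrated without writing a full backward pass. The sketch below shows only the reconstruction: given Q, K, and the saved per-row m and l, any block of the probability matrix can be rebuilt on demand. The function names and inputs are hypothetical. For brevity, the statistics are computed here from full rows, whereas a real forward pass would obtain them from the online update. The gradient computation itself is not shown.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def row_stats(Q, K):
    """Per-row (m, l): the only softmax state saved by the forward pass."""
    stats = []
    for q in Q:
        scores = [dot(q, k) for k in K]
        m = max(scores)
        l = sum(math.exp(s - m) for s in scores)
        stats.append((m, l))
    return stats

def recompute_prob_block(Q, K, stats, q_rows, k_cols):
    """Rebuild P[q_rows, k_cols] from Q, K, and the saved statistics."""
    block = []
    for i in q_rows:
        m, l = stats[i]
        block.append([math.exp(dot(Q[i], K[j]) - m) / l for j in k_cols])
    return block

def full_probabilities(Q, K):
    """The full N x N matrix, built here only to check the block."""
    P = []
    for q in Q:
        scores = [dot(q, k) for k in K]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        total = sum(w)
        P.append([x / total for x in w])
    return P

Q = [[math.sin(i + j) for j in range(3)] for i in range(6)]
K = [[math.cos(i * j + 1.0) for j in range(3)] for i in range(6)]

stats = row_stats(Q, K)                      # 2 numbers per row, not N
P_block = recompute_prob_block(Q, K, stats, q_rows=range(2, 4), k_cols=range(1, 3))
P_full = full_probabilities(Q, K)
expected = [[P_full[i][j] for j in range(1, 3)] for i in range(2, 4)]
block_diff = max(abs(a - b) for ra, rb in zip(P_block, expected)
                 for a, b in zip(ra, rb))
print("max abs difference:", block_diff)
```

The saved state is two numbers per row. Everything else in the block is rebuilt from inputs that had to be kept anyway.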
What FlashAttention Is Not
FlashAttention is easy to overstate.
It does not remove the N x N pairwise relationships in exact attention.
Each query still attends to keys. The algorithm still has quadratic structure in sequence length. If the sequence length doubles, the number of query-key interactions roughly quadruples.
So FlashAttention is not the same kind of idea as sparse attention, linear attention, or low-rank approximations. Those methods change the mathematical structure of attention to reduce the number of interactions or approximate the result.
FlashAttention does something narrower and, in some ways, more surprising:
It keeps exact attention, but changes how the intermediate data lives and dies.
That boundary matters.
If someone says FlashAttention makes attention linear, they have misunderstood it.
If someone says FlashAttention is just a CUDA trick, they have also missed the point.
It is a mathematical trick and a systems trick fitting together:
online softmax gives correctness
tiling gives locality
fusion gives low IO
recomputation gives memory-efficient training
The Mental Model to Keep
Here is the whole idea in one picture:
Ordinary attention:
Q, K, V
|
v
build full N x N scores
|
v
build full N x N probabilities
|
v
multiply by V
|
v
O
FlashAttention:
Q, K, V
|
v
stream through score blocks
|
v
update m, l, o with online softmax
|
v
discard each block immediately
|
v
O
The giant matrix has changed status.
In ordinary attention, it is a resident object.
In FlashAttention, it is a transient event.
That is the intuition worth keeping.
The Core Insight
FlashAttention works because it separates two ideas that are easy to confuse:
The attention matrix must exist mathematically.
The attention matrix does not have to exist physically.
Once that distinction is clear, the rest of the algorithm becomes inevitable.
If the matrix does not need to exist physically, we can compute it in blocks.
If we compute it in blocks, softmax becomes the obstacle.
If softmax is the obstacle, online softmax is the key.
If online softmax gives exact block-wise computation, tiling and fusion can keep the intermediate blocks in fast memory.
If backward needs the matrix again, we can recompute the blocks instead of reading a saved giant tensor.
The result is exact attention with a very different physical execution plan.
FlashAttention is not magic.
It is the moment you stop treating an intermediate matrix as a thing, and start treating it as a stream.
Once you see that, the algorithm stops feeling like a bag of GPU tricks. It becomes a disciplined answer to one question:
Which tensors deserve to live in memory, and which should only pass through?

