
Masked multi-head attention in the transformer decoder: parallel training without future leakage

A detailed, beginner-friendly explanation of masked self-attention in the transformer decoder, why inference is autoregressive, how teacher forcing enables parallel training, and how the causal mask prevents cheating.


The decoder problem that looks impossible at first

The transformer decoder has to satisfy two goals that seem to clash.

First, during inference, it must generate text autoregressively:

Text
predict one token
append it to the sequence
use that sequence to predict the next token
repeat

That is mandatory for language generation, translation, summarization, and next-token prediction. When the model is about to predict the next word, the future words do not exist yet. So it cannot look ahead.

Second, during training, we want speed. Waiting for the model to generate one token at a time would be painfully slow. GPUs are good at parallel computation, so we want to process many positions at once.

This creates the central decoder question:

Text
How can training be parallel while generation remains causally correct?

The answer is masked multi-head self-attention, often called causal self-attention.

It gives us the best of both worlds:

  • the decoder still learns the left-to-right rule of language generation
  • training can process all target positions in parallel
  • future tokens are hidden, so there is no data leakage

First, what "autoregressive" really means

An autoregressive model predicts the next token using only the tokens that came before it.

In probability form, a target sequence:

Text
y_1, y_2, y_3, ..., y_T

is modeled as:

Text
P(y_1, y_2, ..., y_T) = P(y_1) P(y_2 | y_1) P(y_3 | y_1, y_2) ... P(y_T | y_1, ..., y_{T-1})

That long equation says something simple:

Text
each new token depends only on the past

If the decoder is generating:

Text
The weather is

then it can try to predict:

Text
nice

But it cannot peek at the rest of the sentence first.

That left-to-right dependency is what makes the decoder autoregressive.

Where masked attention sits inside the decoder

A standard transformer decoder layer has three main parts:

  1. Masked multi-head self-attention
  2. Encoder-decoder cross-attention
  3. Feed-forward network

The first block is the one we care about here.

Why does the first block need masking?

Because it attends over the target sequence itself. If we let each target token look at all other target tokens during training, then position 2 could see position 3, position 4, and beyond. That would let the model cheat.

By contrast, the cross-attention block reads from the encoder output. That is allowed, because the source sentence is fully known.

So the mask belongs specifically to the decoder's self-attention over the target-side tokens.

Inference: why the decoder must generate one token at a time

Let us use a translation example:

Text
English source: How are you?
Nepali target: तपाईं कस्तो हुनुहुन्छ

During inference, the encoder reads the entire source sentence:

Text
How | are | you | ?

Then the decoder begins with a special start token:

Text
<START>

Now the process looks like this:

Step 1

Decoder input:

Text
<START>

Prediction:

Text
तपाईं

Step 2

Decoder input:

Text
<START> तपाईं

Prediction:

Text
कस्तो

Step 3

Decoder input:

Text
<START> तपाईं कस्तो

Prediction:

Text
हुनुहुन्छ

Step 4

Decoder input:

Text
<START> तपाईं कस्तो हुनुहुन्छ

Prediction:

Text
<END>

That is true autoregressive generation.

At every step, the model uses only the already generated target prefix plus the encoder's source representation.

There is no way around this during inference, because the future target tokens do not exist yet.
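The step-by-step loop above can be sketched in a few lines. Here `toy_model` is a hypothetical stand-in for a trained decoder (plus its encoder context): it simply looks up the next token for each prefix, so the focus stays on the autoregressive loop itself rather than on any real model.

```python
def toy_model(prefix):
    # Hypothetical lookup table standing in for a trained decoder.
    # A real model would compute attention over the prefix and the
    # encoder output; here we hard-code the translation example.
    table = {
        ("<START>",): "तपाईं",
        ("<START>", "तपाईं"): "कस्तो",
        ("<START>", "तपाईं", "कस्तो"): "हुनुहुन्छ",
        ("<START>", "तपाईं", "कस्तो", "हुनुहुन्छ"): "<END>",
    }
    return table[tuple(prefix)]

def generate(model, max_len=10):
    seq = ["<START>"]
    for _ in range(max_len):
        nxt = model(seq)    # predict one token from the prefix only
        seq.append(nxt)     # append it to the sequence
        if nxt == "<END>":  # stop once the end token appears
            break
    return seq

print(generate(toy_model))
```

Note that each iteration can use only what has already been generated: exactly the constraint that inference imposes.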

Why naive autoregressive training would be too slow

Suppose we trained exactly the same way as inference:

  1. feed <START>
  2. predict तपाईं
  3. append it
  4. predict कस्तो
  5. append it
  6. predict हुनुहुन्छ
  7. repeat

That would work, but it would waste the transformer's biggest strength:

Text
parallel computation across sequence positions

If we forced training to wait for one predicted token before computing the next, we would make the decoder feel much more like an old recurrent model during optimization.

That is why training usually uses teacher forcing.

Teacher forcing: the trick that makes training fast

Teacher forcing means we do not feed the model its own previous predictions during training.

Instead, we feed the correct target tokens from the dataset.

For the target sequence:

Text
तपाईं कस्तो हुनुहुन्छ <END>

we build a shifted decoder input:

Text
<START> तपाईं कस्तो हुनुहुन्छ

and the expected outputs become:

Text
तपाईं कस्तो हुनुहुन्छ <END>

So position by position:

Decoder input token | Model should predict
<START> | तपाईं
तपाईं | कस्तो
कस्तो | हुनुहुन्छ
हुनुहुन्छ | <END>

This is the key training trick.

We already know the full correct target sentence from the dataset, so we can prepare all decoder inputs at once. That means we can compute the logits for all positions in parallel.

But there is a subtle correction worth making here:

Text
training is parallelized, but the objective is still autoregressive

The model is still learning:

Text
predict the next token from the past

Teacher forcing only changes how efficiently we compute the training pass.
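The shift described above is easy to sketch. Nothing here is model-specific; it is just list slicing over the known target sentence:

```python
# The full correct target sentence from the dataset.
target = ["तपाईं", "कस्तो", "हुनुहुन्छ", "<END>"]

# Shift right and prepend the start token to build the decoder input.
decoder_input = ["<START>"] + target[:-1]

# Each position's label is simply the next token of the target.
labels = target

for inp, lab in zip(decoder_input, labels):
    print(inp, "->", lab)
```

Because `decoder_input` and `labels` are both fully known before the forward pass, every position's prediction can be computed and scored in one batch.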

The cheating problem if we use ordinary self-attention

Now the danger appears.

Suppose we send the full decoder input:

Text
<START> तपाईं कस्तो हुनुहुन्छ

through a normal self-attention layer with no mask.

Self-attention compares every token with every other token.

That means:

  • <START> could attend to तपाईं, कस्तो, and हुनुहुन्छ
  • तपाईं could attend to कस्तो and हुनुहुन्छ
  • कस्तो could attend to हुनुहुन्छ

But that violates the autoregressive rule.

When the model is trying to predict कस्तो, it should only have access to:

Text
<START> तपाईं

not to:

Text
हुनुहुन्छ

If it can see future tokens during training, the loss becomes artificially easy. The model is effectively reading the answer sheet.

That is exactly what people mean by future-token leakage or data leakage in the decoder.

The fix: masked self-attention

Masked self-attention says:

Text
process all positions in parallel
but block attention to the future

So if the target length is 4, then:

  • position 1 can attend to position 1
  • position 2 can attend to positions 1 and 2
  • position 3 can attend to positions 1, 2, and 3
  • position 4 can attend to positions 1, 2, 3, and 4

but never the other way around.

This is why the mask is often drawn as a lower-triangular matrix.
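That visibility pattern is literally a lower-triangular matrix. A one-line NumPy sketch, where row t marks with 1 the positions token t may attend to:

```python
import numpy as np

# Lower-triangular visibility pattern for a length-4 sequence:
# row t has 1s at positions 1..t and 0s for every future position.
allowed = np.tril(np.ones((4, 4), dtype=int))
print(allowed)
```

Row 2 comes out as `[1, 1, 0, 0]`: positions 1 and 2 are visible, positions 3 and 4 are not.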

A quick recap of attention scores

Inside self-attention, we compute:

Text
Attention(Q, K, V) = softmax((QK^T) / sqrt(d_k)) V

The matrix:

Text
QK^T

contains one score for every query-key pair.

If the sequence length is 4, the score matrix has shape:

Text
4 x 4

Before masking, it might look like this:

Text
[
  [2.1, 1.3, 0.7, 1.8],
  [0.5, 2.4, 1.9, 0.8],
  [1.1, 0.9, 2.7, 1.6],
  [0.4, 1.0, 1.5, 2.2]
]

Interpret row 2 like this:

Text
when token 2 queries the sequence,
how much should it attend to tokens 1, 2, 3, and 4?

Without masking, token 2 can see tokens 3 and 4, which is illegal for decoder self-attention.

How the mask is applied

We create a mask of the same shape:

Text
[
  [0,    -inf, -inf, -inf],
  [0,    0,    -inf, -inf],
  [0,    0,    0,    -inf],
  [0,    0,    0,    0   ]
]

Then we add it to the raw attention scores:

Text
masked_scores = scores + mask

So the previous score matrix becomes:

Text
[
  [2.1,  -inf, -inf, -inf],
  [0.5,  2.4,  -inf, -inf],
  [1.1,  0.9,  2.7,  -inf],
  [0.4,  1.0,  1.5,  2.2]
]

Now apply softmax row by row.

Remember what softmax does:

Text
softmax([a, b, c]) -> positive weights that sum to 1

But:

Text
e^(-inf) = 0

So any masked position gets probability zero after softmax.

That means row 2 becomes something like:

Text
softmax([0.5, 2.4, -inf, -inf]) ~= [0.13, 0.87, 0.00, 0.00]

The future has been hidden without breaking parallel computation.

That is the heart of masked self-attention.
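The whole mask-then-softmax step can be reproduced numerically. This sketch uses NumPy and the same 4 x 4 score matrix from above:

```python
import numpy as np

# The raw attention scores from the example above.
scores = np.array([
    [2.1, 1.3, 0.7, 1.8],
    [0.5, 2.4, 1.9, 0.8],
    [1.1, 0.9, 2.7, 1.6],
    [0.4, 1.0, 1.5, 2.2],
])

# Causal mask: -inf strictly above the diagonal, 0 on and below it.
mask = np.triu(np.full((4, 4), -np.inf), k=1)

masked = scores + mask

# Row-wise softmax; e^(-inf) evaluates to exactly 0.
weights = np.exp(masked - masked.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True)

print(np.round(weights[1], 2))  # row 2 ~= [0.13, 0.87, 0, 0]
```

Every masked entry ends up with attention weight exactly zero, and each row still sums to 1, matching the hand-computed example.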

What each position is allowed to know

For the decoder input:

Text
<START> तपाईं कस्तो हुनुहुन्छ

the mask enforces this visibility:

Position | Current token | Allowed context
1 | <START> | <START>
2 | तपाईं | <START>, तपाईं
3 | कस्तो | <START>, तपाईं, कस्तो
4 | हुनुहुन्छ | <START>, तपाईं, कस्तो, हुनुहुन्छ

So when the model produces logits at position 3, it can use:

Text
<START> तपाईं कस्तो

but not:

Text
हुनुहुन्छ

That keeps the training setup faithful to inference-time causality.

Why this still counts as parallel training

A common confusion is:

Text
If the decoder is autoregressive, how can training possibly be parallel?

Because all positions are computed in one matrix operation.

We do not run four separate forward passes for the four target positions. We run one forward pass over the shifted target sequence:

Text
<START> | तपाईं | कस्तो | हुनुहुन्छ

The mask makes each row behave as if it only had past context, even though the GPU computes all rows together.

So the training loop is:

  1. build the full shifted target sequence
  2. run masked self-attention over all positions together
  3. produce logits for all positions together
  4. compare against all next-token labels together

This is why transformers train so much faster than an architecture that must step through the sequence one token at a time.
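One way to convince yourself that the masked parallel pass really is causally correct: compute masked attention over the whole sequence in one matrix operation, then compare each row against plain attention run on just that position's prefix. A sketch with random NumPy data, where `Q`, `K`, `V` stand in for the projected decoder states:

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 4, 8
Q = rng.normal(size=(T, d))
K = rng.normal(size=(T, d))
V = rng.normal(size=(T, d))

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Parallel pass: all rows computed together, future blocked by the mask.
scores = Q @ K.T / np.sqrt(d)
mask = np.triu(np.full((T, T), -np.inf), k=1)
parallel_out = softmax(scores + mask) @ V

# Sequential pass: position t attends only to the prefix 0..t,
# exactly as it would during autoregressive inference.
sequential_out = np.stack([
    softmax(Q[t] @ K[: t + 1].T / np.sqrt(d)) @ V[: t + 1]
    for t in range(T)
])

print(np.allclose(parallel_out, sequential_out))  # True: identical results
```

The two outputs agree to numerical precision: the mask makes one big matrix operation behave exactly like T prefix-limited steps.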

Why it is called masked multi-head attention, not just masked attention

The decoder does not use one single attention map.

It uses multiple heads:

Text
head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)

Each head gets its own learned projections and its own masked attention calculation. The causal mask is applied inside every head.

That means each head can learn a different pattern while still respecting the no-future rule.

For example, different heads may specialize in:

  • local phrase structure
  • subject-verb agreement
  • long-range dependencies
  • punctuation or sentence boundaries
  • alignment patterns useful for translation

Then the head outputs are concatenated and projected back:

Text
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O

So the full idea is:

Text
multiple attention heads
each head is causal
all heads run in parallel
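A minimal sketch of that structure, with illustrative dimensions and random weights (a real layer would learn `Wq`, `Wk`, `Wv`, `Wo`). The causal mask is applied independently inside every head:

```python
import numpy as np

rng = np.random.default_rng(1)
T, d_model, h = 4, 16, 4
d_k = d_model // h  # per-head dimension

X = rng.normal(size=(T, d_model))
Wq, Wk, Wv, Wo = (rng.normal(size=(d_model, d_model)) * 0.1 for _ in range(4))

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def causal_head(Q, K, V):
    # Each head gets its own masked attention calculation.
    mask = np.triu(np.full((len(Q), len(Q)), -np.inf), k=1)
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1]) + mask) @ V

# Project once, then split the projections into h heads of width d_k.
Q, K, V = X @ Wq, X @ Wk, X @ Wv
heads = [
    causal_head(Q[:, i*d_k:(i+1)*d_k],
                K[:, i*d_k:(i+1)*d_k],
                V[:, i*d_k:(i+1)*d_k])
    for i in range(h)
]

# Concat(head_1, ..., head_h) W^O
out = np.concatenate(heads, axis=-1) @ Wo
print(out.shape)  # (4, 16)
```

Each head sees a different learned subspace, but no head can ever look past the diagonal.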

Masked self-attention vs padding masks

Another easy confusion is mixing up two different masks.

1. Causal mask

This blocks future positions.

It enforces:

Text
token t cannot attend to tokens after t

2. Padding mask

This blocks padding tokens.

It enforces:

Text
do not treat <PAD> as real content

In real implementations, decoder attention often uses both ideas together:

  • a causal mask for left-to-right generation
  • a padding mask for ignoring padded positions

They solve different problems.
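The two masks combine by simple addition in score space. A sketch for a hypothetical length-5 batch entry whose last two positions are `<PAD>`:

```python
import numpy as np

T = 5
valid = np.array([1, 1, 1, 0, 0], dtype=bool)  # last two positions are <PAD>

# Causal mask: block every key position after the query position.
causal = np.triu(np.full((T, T), -np.inf), k=1)

# Padding mask: block the padded key positions for every query.
padding = np.where(valid, 0.0, -np.inf)[None, :]

# Adding the two enforces both rules at once.
combined = causal + padding

print(np.isneginf(combined))
```

After softmax, a position therefore attends only to real, non-future tokens; both kinds of illegal positions receive weight zero.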

Why masking matches inference so well

During inference, future tokens are unavailable because they have not been generated yet.

During training, future tokens are physically present in the batch tensor because the full sentence is known.

The causal mask makes training behave as if those future tokens were not available.

That is the whole trick.

So the relationship is:

Setting | Are future target tokens present in memory? | Can the model use them?
Inference | No | No
Training without mask | Yes | Yes, which is wrong
Training with causal mask | Yes | No, which is correct

This is why masked self-attention is not a minor implementation detail. It is what makes the decoder training objective faithful to the generation process.

A very practical mental model

Imagine four students taking a sequential exam.

Each student is supposed to answer only using:

  • the question in front of them
  • the notes from earlier solved questions

Training without a mask is like letting student 2 glance at student 3's answer sheet.

Training with a mask is like seating them behind one-way partitions:

  • student 1 sees only station 1
  • student 2 sees stations 1 and 2
  • student 3 sees stations 1, 2, and 3
  • student 4 sees all previous stations

All students can still work at the same time, but no one can copy from the future.

That is exactly what the causal mask does inside the decoder.

The part many beginners miss

People sometimes say:

Text
the decoder is autoregressive at inference but non-autoregressive at training

That sentence is useful informally, but it is slightly imprecise.

The better statement is:

Text
the decoder learns an autoregressive objective,
but training computes all positions in parallel using teacher forcing and a causal mask

Why does this distinction matter?

Because the model is never allowed to break the left-to-right rule. The rule stays intact. Only the computation becomes more efficient.

Decoder self-attention and cross-attention are doing different jobs

In the decoder, the two attention blocks play different roles.

Masked self-attention

This answers:

Text
What have I generated so far on the target side?

Cross-attention

This answers:

Text
Which source-side information from the encoder is relevant right now?

For translation, that means:

  • masked self-attention tracks the already generated Nepali prefix
  • cross-attention looks back at the English source representation

Both are necessary, but only the first one needs the causal mask.

Why the decoder would fail without this design

Without masked self-attention, training would reward the model for using information it will never have during inference.

That creates a mismatch:

  • training becomes artificially easy
  • evaluation during generation becomes much harder
  • the model learns dependencies it cannot rely on at test time

This kind of train-test mismatch is exactly what good architecture design tries to avoid.

Masked multi-head attention fixes it cleanly.

The shortest possible summary

If you want the entire idea in five lines, it is this:

  1. The decoder must generate text left to right.
  2. Inference is therefore inherently autoregressive.
  3. Training uses teacher forcing so all target positions can be computed together.
  4. A causal mask blocks every position from seeing future target tokens.
  5. After softmax, masked future positions get zero attention weight.

That is how transformers train fast without leaking the answer from the future.

Final takeaway

Masked multi-head attention is the mechanism that makes the transformer decoder practical.

It preserves the causal rule required for generation:

Text
use only the past to predict the next token

while still letting modern hardware process whole sequences efficiently during training.

So whenever you see the triangular decoder mask, remember what it is really doing:

Text
parallelizing the computation
without relaxing the logic of autoregressive generation

That is why the decoder can be fast in training, correct in inference, and faithful to the next-token prediction objective in both.