Resources
The content in this post is based on my notes from the following resources. I highly recommend exploring these directly for a deeper understanding.
- Understanding and Coding Self-Attention — Sebastian Raschka
- Illustrated Guide to Attention Mechanism — AI Summer
- Attention Is All You Need — Vaswani et al., 2017
- Neural Machine Translation by Jointly Learning to Align and Translate — Bahdanau et al., 2014
- The Illustrated Transformer — Jay Alammar, 2018
Overview
Starting with self-attention, we learn how an input is projected into query, key, and value components using separate learned weight matrices. Queries are then multiplied by the transposed keys to obtain attention scores. These scores are scaled by 1/√d_k (where d_k is the key dimension) and normalized row-wise with the softmax function to get attention weights.
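The steps above can be sketched in NumPy as follows. This is a minimal sketch that assumes a single unbatched sequence and uses random matrices as stand-ins for learned weights. The names `attention_weights`, `W_q`, and `W_k` are illustrative and are not taken from any of the resources above. The value projection is left out because it only comes into play once the weights are applied.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_weights(X, W_q, W_k):
    """Project inputs to queries/keys and return scaled, softmax-normalized weights."""
    Q = X @ W_q                      # (seq_len, d_k) queries
    K = X @ W_k                      # (seq_len, d_k) keys
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)  # (seq_len, seq_len) scaled attention scores
    return softmax(scores, axis=-1)  # normalize each row into attention weights

rng = np.random.default_rng(0)
seq_len, d_model, d_k = 4, 8, 8
X = rng.normal(size=(seq_len, d_model))   # hypothetical embedded input tokens
W_q = rng.normal(size=(d_model, d_k))     # random stand-in for a learned weight matrix
W_k = rng.normal(size=(d_model, d_k))     # random stand-in for a learned weight matrix

weights = attention_weights(X, W_q, W_k)
print(weights.shape)         # (4, 4)
print(weights.sum(axis=-1))  # every row sums to (numerically) 1.0
```

Row i of `weights` describes how strongly token i attends to each token in the sequence, itself included.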
These weights are multiplied by the value matrix computed from the input. The result is one context vector per input token, which is a weighted sum of the value vectors. A single head performs this computation once, using one set of query, key, and value weight matrices. Multi-head attention runs several heads in parallel, each with its own weight matrices, and each computes its context vectors independently. The heads' outputs are then concatenated and, in the original Transformer, passed through a final linear projection.
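Here is a NumPy sketch of a single head producing context vectors, followed by multi-head attention that combines several heads. It assumes an unbatched input, random matrices in place of learned weights, and a per-head size of `d_model // n_heads`. The final projection `W_o` follows Vaswani et al. (2017). The function names `single_head` and `multi_head` are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def single_head(X, W_q, W_k, W_v):
    """One attention head: returns one context vector per input token."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    weights = softmax(Q @ K.T / np.sqrt(K.shape[-1]), axis=-1)
    return weights @ V  # (seq_len, d_v): weighted sums of value vectors

def multi_head(X, heads, W_o):
    """Run each head independently, concatenate, then apply the output projection."""
    contexts = [single_head(X, *w) for w in heads]  # list of (seq_len, d_v)
    concat = np.concatenate(contexts, axis=-1)      # (seq_len, n_heads * d_v)
    return concat @ W_o                             # (seq_len, d_model)

rng = np.random.default_rng(0)
seq_len, d_model, n_heads = 4, 8, 2
d_head = d_model // n_heads
X = rng.normal(size=(seq_len, d_model))
# Each head gets its own (W_q, W_k, W_v), all random stand-ins for learned weights
heads = [tuple(rng.normal(size=(d_model, d_head)) for _ in range(3)) for _ in range(n_heads)]
W_o = rng.normal(size=(n_heads * d_head, d_model))

out = multi_head(X, heads, W_o)
print(out.shape)  # (4, 8)
```

Splitting `d_model` across heads keeps the concatenated output the same width as the input. As a result, the total cost stays roughly comparable to a single full-width head.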
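Finally, we can apply a causal mask so that each token's context vector does not incorporate information from future tokens (i.e., each token depends only on itself and the tokens before it). This amounts to zeroing out the attention weights strictly above the diagonal, which form the upper-right triangle of the attention weight matrix. The mask can be applied before the softmax, by setting those scores to negative infinity. Alternatively, it can be applied after the softmax, by zeroing those weights and renormalizing so that each row still sums to 1. Masking is essential when generating output sequences autoregressively, because it prevents future tokens from leaking into each prediction.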
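Both masking orders described above can be sketched in NumPy, and the sketch confirms that they yield the same weights. The random matrix below is a hypothetical stand-in for real attention scores, and the function names are illustrative.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for stability; exp(-inf) evaluates to 0
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def causal_weights_premask(scores):
    """Mask before softmax: set scores strictly above the diagonal to -inf."""
    n = scores.shape[0]
    future = np.triu(np.ones((n, n), dtype=bool), k=1)  # True for future positions
    return softmax(np.where(future, -np.inf, scores), axis=-1)

def causal_weights_postmask(scores):
    """Mask after softmax: zero out future positions, then renormalize rows."""
    n = scores.shape[0]
    weights = softmax(scores, axis=-1)
    weights = weights * np.tril(np.ones((n, n)))  # keep the diagonal and below
    return weights / weights.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
scores = rng.normal(size=(4, 4))  # hypothetical raw attention scores
a = causal_weights_premask(scores)
b = causal_weights_postmask(scores)
print(np.allclose(a, b))  # True: both orders give the same weights
print(np.round(a, 2))     # entries above the diagonal are all 0
```

Masking before the softmax is usually preferred because it avoids the extra renormalization step. The first token can attend only to itself, so its row is `[1, 0, 0, 0]`.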
History
Sequence-to-sequence learning was originally dominated by encoder-decoder architectures linked by a single context vector. The encoder and decoder were typically recurrent networks, such as vanilla RNNs or LSTMs. These models struggled with long sequences because of problems such as vanishing gradients and the tendency to forget earlier parts of the input.
A key limitation was the fixed-length context vector. It forced the encoder to compress the entire input into a single vector, which often failed to capture all the relevant information. Attention was introduced to connect each output directly to the entire input sequence, so that the model can weight different parts of the input as needed.
Intuition
- Attention is about finding alignment between pieces of an input. Attention scores reflect how “aligned” two tokens are.
- A single token can align with multiple other tokens — attention allows for one-to-many relationships.
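- Because every token is scored against every other token, attention's compute and memory costs grow quadratically with sequence length: an n-token input produces an n × n score matrix.
- Computing each context vector from the full input sequence, rather than from a single fixed-length summary, is key to avoiding information loss in long sequences.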
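The quadratic growth mentioned in the list above can be checked directly. Each entry of the score matrix is one token-pair alignment, so the matrix has n² entries. This NumPy sketch uses random queries and keys and an assumed key dimension of 64. The helper `score_matrix` is hypothetical.

```python
import numpy as np

def score_matrix(n, d_k=64, seed=0):
    """Compute the (n, n) scaled score matrix for n random query/key vectors."""
    rng = np.random.default_rng(seed)
    Q = rng.normal(size=(n, d_k))
    K = rng.normal(size=(n, d_k))
    return Q @ K.T / np.sqrt(d_k)

for n in (128, 256, 512):
    S = score_matrix(n)
    # Every token is scored against every token: n * n entries
    print(n, S.shape, S.size)
# 128 (128, 128) 16384
# 256 (256, 256) 65536
# 512 (512, 512) 262144
```

Doubling the sequence length quadruples the number of scores that must be computed and stored.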