Resources
The content presented in this blog is based on my notes from the following resources. I highly recommend exploring these articles and papers directly for a deeper understanding, especially if you're diving into self-attention and transformer-based architectures.
Overview
Starting with self-attention, we learn how an input is transformed into query, key, and value components using their respective weight matrices. Next, the queries are multiplied by the transposed keys to obtain attention scores, which are then scaled (typically by the square root of the key dimension) and normalized with the softmax function to get attention weights.
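To make this concrete, here is a minimal NumPy sketch of the projection and scoring steps. The dimensions and the weight matrices (W_q, W_k, W_v) are illustrative assumptions, randomly initialized here rather than learned:

```python
import numpy as np

rng = np.random.default_rng(0)

seq_len, d_model, d_k = 4, 8, 8   # illustrative sizes
X = rng.normal(size=(seq_len, d_model))   # input token embeddings

# Learned projection matrices (random placeholders for illustration)
W_q = rng.normal(size=(d_model, d_k))
W_k = rng.normal(size=(d_model, d_k))
W_v = rng.normal(size=(d_model, d_k))

Q, K, V = X @ W_q, X @ W_k, X @ W_v

# Attention scores: each query dotted with every key, scaled by sqrt(d_k)
scores = Q @ K.T / np.sqrt(d_k)           # shape: (seq_len, seq_len)

# Softmax over each row turns scores into attention weights summing to 1
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
```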
These weights are multiplied by the values matrix computed from the same input, and the result is a context vector for each token. This computation describes a single attention head, and it extends naturally to multiple heads: multi-head attention simply computes context vectors independently in each head, then concatenates the per-head results (typically followed by a final output projection).
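Continuing the sketch above (reusing X, rng, d_model, d_k, weights, and V), the context vectors and a simple multi-head extension might look like the following. The loop over heads is an illustrative choice that makes their independence explicit; real implementations batch the heads into a single tensor operation:

```python
# Context vectors: weighted sum of the value vectors (single head)
context = weights @ V                      # shape: (seq_len, d_k)

# One attention head, packaged as a function
def attention_head(X, W_q, W_k, W_v):
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    w = np.exp(scores - scores.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ V

# Multi-head attention: run each head independently, then concatenate
n_heads, d_head = 2, d_k // 2
heads = []
for _ in range(n_heads):
    W_q_h = rng.normal(size=(d_model, d_head))
    W_k_h = rng.normal(size=(d_model, d_head))
    W_v_h = rng.normal(size=(d_model, d_head))
    heads.append(attention_head(X, W_q_h, W_k_h, W_v_h))

multi_head = np.concatenate(heads, axis=-1)  # shape: (seq_len, n_heads * d_head)
```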
Finally, to ensure that each context vector does not incorporate information from future tokens in the input (i.e., each token attends only to itself and preceding tokens), we can apply a causal mask. This amounts to masking out the upper triangle of the attention score matrix: either setting those entries to negative infinity before the softmax, or zeroing the corresponding attention weights after the softmax and renormalizing so that each row still sums to 1. Masking is essential when generating output sequences autoregressively, as it prevents future-token leakage.
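Here is a sketch of the pre-softmax variant, again continuing the example above; building the mask with np.triu is one common way to do it:

```python
# Causal mask: True above the diagonal marks "future" positions
mask = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)

# Set future positions to -inf so softmax sends them to zero
masked_scores = np.where(mask, -np.inf, scores)

# Rows still sum to 1, but future tokens get zero weight
causal_weights = np.exp(masked_scores - masked_scores.max(axis=-1, keepdims=True))
causal_weights /= causal_weights.sum(axis=-1, keepdims=True)

causal_context = causal_weights @ V
```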
History
Sequence-to-sequence learning was originally dominated by encoder-decoder architectures linked by a single context vector. The encoder and decoder were typically RNNs or LSTMs, and these models struggled with long sequences due to issues like vanishing gradients and the gradual forgetting of early inputs.
A key limitation was the fixed-length context vector, which could fail to capture all relevant input information. Attention was introduced to directly connect each output to the entire input sequence, enabling the model to weight different parts of the input as needed.
Intuition
- Attention is about finding alignment between pieces of an input. Attention scores reflect how "aligned" two tokens are.
- A single token can align with multiple other tokens — attention allows for one-to-many relationships.
- As sequence length increases, attention's computational cost grows quadratically, since every token must be scored against every other token (see the sketch after this list).
- Connecting the context vector to the full input sequence is critical to ensure no information is lost in long sequences.
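The quadratic growth mentioned above follows directly from the shape of the score matrix: one score per pair of tokens. A small standalone illustration (sizes chosen arbitrarily):

```python
import numpy as np

d_model = 64
for seq_len in (128, 256, 512):
    X = np.random.default_rng(0).normal(size=(seq_len, d_model))
    scores = X @ X.T   # stand-in for Q @ K.T: one score per token pair
    print(seq_len, scores.shape, scores.size)

# Doubling seq_len quadruples the number of attention scores:
# 128 (128, 128) 16384
# 256 (256, 256) 65536
# 512 (512, 512) 262144
```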