A transformer walk-through, with Gemma
Abstract
The article provides a detailed walkthrough of the inner workings of the Gemma 2B transformer language model, including explanations of key components such as tokenization, embedding lookup, transformer layers, attention, rotary positional encoding, multi-layer perceptron, and output projection. The author aims to demystify transformer-based large language models (LLMs) and provide intuition for why each step is necessary.
Q&A
[01] Tokenization
1. What are the key characteristics of the tokenization process used in Gemma?
- Gemma uses subword tokenization, which can split long words into smaller tokens.
- The tokenizer includes whitespace and does not normalize capitalization, passing these details to the model to figure out.
- Gemma has a very large vocabulary of 256,000 tokens, much larger than that of many other models, which typically use 30-60k-token vocabularies.
2. Why does Gemma use such a large vocabulary size?
- A larger vocabulary can help with coverage of the long tail of language, though there is a trade-off in efficiency between larger vocabularies and shorter sequences versus smaller vocabularies and longer sequences.
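As a concrete illustration of these points, here is a minimal sketch of inspecting the tokenizer through the Hugging Face `transformers` library. It assumes `transformers` is installed and that you have access to the gated `google/gemma-2b` checkpoint (the Gemma license must be accepted on Hugging Face); the exact splits printed depend on the tokenizer, so the comments are indicative only.

```python
# Sketch: inspecting Gemma's subword tokenizer.
# Assumes `transformers` is installed and the gated google/gemma-2b checkpoint
# is accessible (requires accepting the Gemma license on Hugging Face).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
print(tokenizer.vocab_size)  # 256000 -- far larger than the 30-60k vocabularies of many models

text = "The Quick brown unicyclist"
tokens = tokenizer.tokenize(text)
# Rare or long words are split into subword pieces; leading whitespace is kept
# as part of the token and capitalization is preserved, not normalized away.
print(tokens)

ids = tokenizer.encode(text, add_special_tokens=False)
print(ids)                                   # integer ids into the 256k-entry vocabulary
print(tokenizer.convert_ids_to_tokens(ids))  # round-trip back to token strings
```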
[02] Embedding Lookup
1. What is the purpose of the post-embedding rescaling step?
- The post-embedding rescaling step is necessary because the embedding parameters are shared between the input embedding lookup and the output embedding projection. The scaling factor of sqrt(hidden_size) ensures the appropriate scale for the inputs without interfering with the output projection.
2. How does the embedding similarity metric illustrate the model's understanding of language?
- The cosine similarity between embeddings of related words (e.g. colors, place names, adjectives) shows that the model has learned to associate similar words based on their contexts in the training data.
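Both points can be sketched with a toy tied embedding table in numpy. The sizes, weights, and function names below are illustrative assumptions (random weights rather than Gemma's trained 256,000 x 2048 table), so only the mechanics are meaningful, not the similarity values.

```python
import numpy as np

vocab_size, hidden_size = 1_000, 64        # toy sizes; Gemma 2B's table is 256,000 x 2048
rng = np.random.default_rng(0)
embed = rng.normal(size=(vocab_size, hidden_size)).astype(np.float32)  # shared (tied) table

def embed_lookup(token_ids: np.ndarray) -> np.ndarray:
    """Input embedding: table lookup rescaled by sqrt(hidden_size).
    The rescale sets an appropriate scale for the residual stream without
    modifying the shared table, so the output projection (which reuses
    `embed` unscaled) is unaffected."""
    return embed[token_ids] * np.sqrt(hidden_size)

def cosine_similarity(token_a: int, token_b: int) -> float:
    """Similarity of two embedding rows; any common rescale cancels out."""
    a, b = embed[token_a], embed[token_b]
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

print(embed_lookup(np.array([17, 42, 99])).shape)  # (3, 64)
print(cosine_similarity(17, 42))  # ~0 for random weights; trained embeddings score
                                  # related words (e.g. colour names) much higher
```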
[03] Transformer Layer
1. What is the purpose of the residual layer design in Gemma?
- The residual layer design, where each layer adds its contribution to the existing representation rather than replacing it, keeps each layer's update incremental and, together with normalization, maintains a stable scale for the hidden representation across many layers.
2. How does the RMSNorm help stabilize the hidden representation?
- RMSNorm divides out the root-mean-squared value of the hidden representation, resetting it to 1. This prevents the hidden representation from growing uncontrollably and causing numerical issues or saturation of nonlinearities.
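The pattern described in both answers is sketched below in numpy: a pre-norm residual layer where each sublayer reads an RMS-normalized copy of the stream and adds its output back. The sublayer functions, shapes, and names are placeholders rather than Gemma's real attention and MLP, and Gemma's norm applies its learned scale slightly differently (folding in a +1); a plain learned scale is shown for simplicity.

```python
import numpy as np

def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Divide out the root-mean-square of each hidden vector, resetting its scale
    to ~1, then apply a learned per-channel scale. This is what stops the hidden
    representation from growing without bound across layers."""
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return (x / rms) * weight

def transformer_layer(hidden, attn_fn, mlp_fn, norm1_w, norm2_w):
    """Pre-norm residual layer: each sublayer sees a normalized copy of the stream
    and *adds* its contribution back, rather than replacing the representation."""
    hidden = hidden + attn_fn(rms_norm(hidden, norm1_w))
    hidden = hidden + mlp_fn(rms_norm(hidden, norm2_w))
    return hidden

# Toy usage with placeholder sublayers:
x = np.random.default_rng(0).normal(size=(4, 8)).astype(np.float32)
w = np.ones(8, dtype=np.float32)
out = transformer_layer(x, attn_fn=lambda h: 0.1 * h, mlp_fn=lambda h: 0.1 * h,
                        norm1_w=w, norm2_w=w)
print(out.shape)  # (4, 8): same shape in and out, so layers can be stacked freely
```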
[04] Attention
1. Explain the key steps in the self-attention mechanism used in Gemma.
- Gemma 2B uses multi-query attention: 8 query heads each perform their own attention operation but share a single key/value projection, so the heads query independently while reusing the same keys and values (see the sketch after this list).
- The attention mechanism projects the input into query, key, and value vectors, computes dot product-based attention weights between queries and keys, and then takes a weighted sum of the values based on those attention weights.
- Causal masking is used to ensure each token can only attend to itself and past tokens, which is essential for autoregressive language modeling.
2. How does the rotary positional encoding (RoPE) allow the attention mechanism to capture relative position information?
- RoPE transforms the query and key vectors by rotating them based on their position in the sequence. This allows the attention weights to depend on the relative offset between tokens rather than just their absolute positions.
- RoPE uses a range of frequencies, from high frequencies that can distinguish individual tokens to low frequencies that can ignore position entirely, giving the model flexibility in how it uses positional information.
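Putting the two answers together, here is a numpy sketch of multi-query attention with rotary encoding and causal masking. It is illustrative only: the shapes, the RoPE base, the channel-pairing convention, and all function names are assumptions, not Gemma's exact implementation.

```python
import numpy as np

def rope(x: np.ndarray, positions: np.ndarray, base: float = 10_000.0) -> np.ndarray:
    """Rotary positional encoding: rotate each (even, odd) channel pair by an angle
    proportional to the token's position, with a different frequency per pair.
    High-frequency pairs distinguish neighbouring positions; the lowest-frequency
    pairs barely rotate, so those channels can effectively ignore position."""
    head_dim = x.shape[-1]
    freqs = base ** (-np.arange(0, head_dim, 2) / head_dim)   # (head_dim/2,)
    angles = positions[:, None] * freqs[None, :]              # (seq, head_dim/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def multi_query_attention(hidden, w_q, w_k, w_v, w_o):
    """Multi-query self-attention: several query heads share one key/value head.
    Toy shapes: hidden (seq, d_model), w_q (n_heads, d_model, d_head),
    w_k / w_v (d_model, d_head), w_o (n_heads, d_head, d_model)."""
    seq = hidden.shape[0]
    n_heads, _, d_head = w_q.shape
    positions = np.arange(seq)

    q = np.einsum("sd,hdk->hsk", hidden, w_q)   # one query projection per head
    k = hidden @ w_k                            # a single shared key head
    v = hidden @ w_v                            # a single shared value head
    q = np.stack([rope(q[h], positions) for h in range(n_heads)])
    k = rope(k, positions)                      # rotate queries and keys, not values

    scores = np.einsum("hsk,tk->hst", q, k) / np.sqrt(d_head)
    causal = np.tril(np.ones((seq, seq), dtype=bool))   # token s may attend to t <= s only
    scores = np.where(causal, scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)      # softmax over attended positions

    mixed = np.einsum("hst,tk->hsk", weights, v)        # weighted sum of values
    return np.einsum("hsk,hkd->sd", mixed, w_o)         # combine heads back to d_model

# Toy usage (8 query heads sharing one key/value head, as in Gemma 2B's MQA):
rng = np.random.default_rng(0)
seq, d_model, n_heads, d_head = 5, 32, 8, 4
out = multi_query_attention(
    rng.normal(size=(seq, d_model)),
    w_q=rng.normal(size=(n_heads, d_model, d_head)) * 0.1,
    w_k=rng.normal(size=(d_model, d_head)) * 0.1,
    w_v=rng.normal(size=(d_model, d_head)) * 0.1,
    w_o=rng.normal(size=(n_heads, d_head, d_model)) * 0.1,
)
print(out.shape)  # (5, 32)
```

Because each rotation angle is linear in position, the dot product between a rotated query and a rotated key depends only on the positional offset between them, which is how RoPE turns absolute positions into relative ones.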
[05] Multi-Layer Perceptron (MLP)
1. Explain the purpose and intuition behind the GELU-gated linear unit (GeGLU) used in Gemma's MLP.
- The GeGLU MLP can be thought of as a piecewise-quadratic function with smooth boundaries between the pieces, allowing it to learn more complex transformations of the input representation than a simple linear projection.
- The GELU nonlinearity provides a smoother transition between active and inactive regions of the function compared to a ReLU, which may be beneficial for optimization and the overall function learned by the MLP.
2. How does the MLP component complement the attention mechanism in Gemma?
- While attention is responsible for fusing information across the context, the MLP transforms each token's representation independently; handling this per-token computation separately lets attention concentrate on its core task of modeling relationships between tokens.
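A numpy sketch of a GeGLU block under the intuitions in answer 1 above. Sizes, weights, and names are illustrative assumptions, far smaller than Gemma's real projections.

```python
import numpy as np

def gelu(x: np.ndarray) -> np.ndarray:
    """GELU nonlinearity (tanh approximation): a smooth gate between the 'off'
    region (output near 0) and the 'on' region (output near x), with no hard
    corner like a ReLU."""
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def geglu_mlp(hidden, w_gate, w_up, w_down):
    """GeGLU MLP: two parallel up-projections, one passed through GELU and used to
    gate the other elementwise, then a down-projection back to the model width.
    The product of two (near-)linear functions of the input is what makes the
    block roughly piecewise-quadratic, with smooth seams courtesy of the GELU.
    Each token position is transformed independently -- no mixing across the
    sequence happens here."""
    return (gelu(hidden @ w_gate) * (hidden @ w_up)) @ w_down

# Toy usage:
rng = np.random.default_rng(0)
d_model, d_ff = 32, 128
out = geglu_mlp(rng.normal(size=(4, d_model)),
                w_gate=rng.normal(size=(d_model, d_ff)) * 0.1,
                w_up=rng.normal(size=(d_model, d_ff)) * 0.1,
                w_down=rng.normal(size=(d_ff, d_model)) * 0.1)
print(out.shape)  # (4, 32): one transformed vector per token position
```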
[06] Final Norm and Output Projection
1. Why is the final RMSNorm layer important before the output projection?
- The final RMSNorm helps prevent the hidden representation from becoming too large in scale, which could lead to over-confident and spiky output predictions.
2. How does the output projection layer share parameters with the initial embedding lookup?
- Sharing the same embedding parameters between the input and output projections is an efficient design choice, as it reduces the total number of parameters in the model.
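A final numpy sketch of these two steps, reusing a toy tied embedding table as the output projection. Sizes, weights, and function names are illustrative assumptions, not Gemma's actual implementation.

```python
import numpy as np

def output_logits(hidden, final_norm_w, embed, eps=1e-6):
    """Final RMSNorm followed by the output projection. The projection reuses the
    same table as the input embedding lookup (weight tying), just transposed, so
    no separate hidden_size x vocab_size output matrix is needed."""
    rms = np.sqrt(np.mean(hidden ** 2, axis=-1, keepdims=True) + eps)
    normed = (hidden / rms) * final_norm_w   # bounded scale -> no over-confident, spiky logits
    return normed @ embed.T                  # (seq, vocab_size) next-token logits

def softmax(logits):
    z = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return z / z.sum(axis=-1, keepdims=True)

# Toy usage with a random tied embedding table:
rng = np.random.default_rng(0)
vocab_size, hidden_size = 1_000, 64
embed = rng.normal(size=(vocab_size, hidden_size))
hidden = rng.normal(size=(3, hidden_size))
probs = softmax(output_logits(hidden, np.ones(hidden_size), embed))
print(probs.shape, float(probs[0].sum()))  # (3, 1000) 1.0 -- a distribution per position
```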