Part 1 - Week 6

Ryan Cotterell
Published: Tuesday, March 18, 2025

Transformer-based Language Models

  • Transformers were introduced by Vaswani et al. (2017) as a sequence-modeling architecture, originally for machine translation, and are now widely used in language modeling (e.g., GPT models).
  • Key drawback of RNNs: They compress arbitrary-length sequences into a fixed-size hidden state vector, which can lead to information loss for long contexts.
  • Transformer solution: Retain contextualized representations for each token in the entire sequence, making the full history available without compression into a single vector.
  • Advantages:
    • Easily models long-range dependencies via attention mechanisms.
    • Allows parallel computation of representations, unlike sequential RNNs.
  • Abstract depiction: Transformers produce contextual embeddings for all symbols in a string, where each hidden state \(h_t\) attends to preceding symbols and the current one (see Figure 5.18 in source).
  • In the generative framework, symbols are generated sequentially, but representations are computed in parallel where possible (see Figure 5.19).

Formal Definition of Transformer Network

  • Definition: Transformer Network
    • A transformer network \(\mathcal{T}\) is a tuple \((\Sigma, D, \mathrm{enc}_{\mathcal{T}})\), where:
      • \(\Sigma\) is the alphabet (vocabulary) of input symbols.
      • \(D\) is the dimension of the embeddings.
      • \(\mathrm{enc}_{\mathcal{T}}\) is the transformer encoding function (detailed below).
  • Hidden State
    • The hidden state \(h_t \in \mathbb{R}^D\) after reading \(y_{\leq t}\) is \(h_t \overset{\mathrm{def}}{=} \mathrm{enc}_{\mathcal{T}}(y_{\leq t})\).
    • Unlike RNNs, \(h_t\) does not depend on previous hidden states directly.
  • Transformer Sequence Model (SM)
    • Given a transformer network \(\mathcal{T}\) and a symbol representation matrix \(E \in \mathbb{R}^{|\bar{\Sigma}| \times D}\), where \(\bar{\Sigma} = \Sigma \cup \{\mathrm{EOS}\}\), a \(D\)-dimensional transformer SM over \(\Sigma\) is the tuple \((\Sigma, D, \mathrm{enc}_{\mathcal{T}}, E)\), which defines: \[ p_{\mathrm{SM}}(\bar{y}_t \mid y_{<t}) \overset{\mathrm{def}}{=} \mathrm{softmax}(E h_{t-1})_{\bar{y}_t} = \mathrm{softmax}(E \mathrm{enc}_{\mathcal{T}}(y_{<t}))_{\bar{y}_t} \]
    • This fits into the representation-based locally normalized language modeling framework.
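A minimal Python sketch of the transformer SM's output step, assuming a toy alphabet \(\Sigma = \{a, b\}\), a hand-picked matrix \(E\), and a hypothetical stand-in `toy_enc` for \(\mathrm{enc}_{\mathcal{T}}\) (it merely averages fixed embeddings and is not a transformer). It only illustrates how \(p_{\mathrm{SM}}(\bar{y}_t \mid y_{<t}) = \mathrm{softmax}(E\, \mathrm{enc}_{\mathcal{T}}(y_{<t}))_{\bar{y}_t}\) turns an encoding into a distribution over \(\bar{\Sigma}\).

```python
import math

def softmax(z):
    """Numerically stable softmax over a list of floats."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical toy setup: Sigma = {a, b}, Sigma-bar = Sigma + {EOS}, D = 2.
SIGMA_BAR = ["a", "b", "EOS"]
EMBED = {"BOS": [0.0, 0.0], "a": [1.0, 0.0], "b": [0.0, 1.0]}
E = [
    [2.0, 0.0],  # row for "a"
    [0.0, 2.0],  # row for "b"
    [1.0, 1.0],  # row for "EOS"
]

def toy_enc(prefix):
    """Hypothetical stand-in for enc_T(y_<t): mean of BOS + symbol embeddings (NOT a transformer)."""
    vecs = [EMBED["BOS"]] + [EMBED[y] for y in prefix]
    return [sum(v[d] for v in vecs) / len(vecs) for d in range(2)]

def p_sm(prefix):
    """p_SM(. | y_<t) = softmax(E enc_T(y_<t)), returned as a dict over Sigma-bar."""
    h = toy_enc(prefix)
    logits = [sum(row[d] * h[d] for d in range(2)) for row in E]
    return dict(zip(SIGMA_BAR, softmax(logits)))

dist = p_sm(["a", "a"])
print({sym: round(p, 3) for sym, p in dist.items()})
```

Swapping `toy_enc` for a real transformer encoder leaves `p_sm` unchanged; only the hidden state \(h_{t-1}\) differs.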

Attention Mechanism

  • Transformers compute contextual encodings using attention, avoiding sequential dependencies.
  • Definition: Attention
    • Let \(f: \mathbb{R}^D \times \mathbb{R}^D \to \mathbb{R}\) be a scoring function and \(f_{\Delta^{D-1}}\) a projection function that maps a vector of scores onto the probability simplex.
    • For query \(q \in \mathbb{R}^D\), keys \(K_t = (k_1^\top, \dots, k_t^\top) \in \mathbb{R}^{t \times D}\), values \(V_t = (v_1^\top, \dots, v_t^\top) \in \mathbb{R}^{t \times D}\): \[ s_t = (s_1, \dots, s_t) \overset{\mathrm{def}}{=} f_{\Delta^{D-1}} \left( f(q, k_1), f(q, k_2), \dots, f(q, k_t) \right) \] \[ a_t = \mathrm{Att}(q, K_t, V_t) \overset{\mathrm{def}}{=} s_1 v_1 + s_2 v_2 + \cdots + s_t v_t \]
    • Query represents an individual symbol \(y_t\); \(K, V\) represent information from \(y_{<t}\).
    • Scoring function \(f\) measures relevance of keys to query; projection normalizes scores to sum to 1.
    • Result \(a\) is a convex combination of values, indexed by keys.
  • Common Scoring Function
    • Typically: \(f(q, k) = \frac{1}{\sqrt{D}} \langle q, k \rangle\) (scaled dot product).
  • Soft Attention
    • Uses \(f_{\Delta^{D-1}} = \mathrm{softmax}\).
  • Hard Attention (for Expressivity Analysis)
    • Averaging Hard Attention: \(f_{\Delta^{D-1}} = \mathrm{hardmax}_{\mathrm{avg}}\), where: \[ \mathrm{hardmax}_{\mathrm{avg}}(x)_d \overset{\mathrm{def}}{=} \begin{cases} \frac{1}{r} & \text{if } d \in \arg\max(x) \\ 0 & \text{otherwise} \end{cases} \] with \(r = |\arg\max(x)|\).
    • Unique Hard Attention: \(f_{\Delta^{D-1}} = \mathrm{hardmax}_{\mathrm{uni}}\), picks one argmax element (randomly or deterministically).
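    • Difference: averaging hard attention splits the weight evenly over all tied maxima, so it can aggregate (average) their values; unique hard attention returns only one of them, which limits expressivity (e.g., it cannot aggregate several equally relevant values).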
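The attention definitions above can be sketched in plain Python. This is an illustrative sketch, assuming the scaled dot-product score and breaking ties in `hardmax_uni` by taking the leftmost maximum (one of the deterministic options mentioned above); ties are detected with exact float equality. With two tied keys, averaging hard attention returns the mean of both values, while unique hard attention keeps only one.

```python
import math

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def softmax(x):
    """Soft attention projection."""
    m = max(x)
    e = [math.exp(v - m) for v in x]
    s = sum(e)
    return [v / s for v in e]

def hardmax_avg(x):
    """Averaging hard attention: weight 1/r on each of the r arg max positions (exact ties)."""
    m = max(x)
    idx = [i for i, v in enumerate(x) if v == m]
    return [1.0 / len(idx) if i in idx else 0.0 for i in range(len(x))]

def hardmax_uni(x):
    """Unique hard attention: all weight on one arg max position (here the leftmost)."""
    i_star = x.index(max(x))
    return [1.0 if i == i_star else 0.0 for i in range(len(x))]

def attention(q, keys, values, project=softmax):
    """Att(q, K_t, V_t) = s_1 v_1 + ... + s_t v_t with s = project(f(q, k_1), ..., f(q, k_t))."""
    D = len(q)
    scores = [dot(q, k) / math.sqrt(D) for k in keys]  # scaled dot-product scoring f
    s = project(scores)
    return [sum(w * v[d] for w, v in zip(s, values)) for d in range(len(values[0]))]

# Keys 1 and 2 tie for the highest score; key 3 is irrelevant to the query.
q = [1.0, 0.0]
keys = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
values = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]

print(attention(q, keys, values, softmax))      # convex mix of all three values
print(attention(q, keys, values, hardmax_avg))  # [0.5, 0.5]: averages the tied values
print(attention(q, keys, values, hardmax_uni))  # [1.0, 0.0]: keeps only the first
```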

Transformer Layers and Blocks

  • Transformer Layer
    • Let \(Q, K, V, O: \mathbb{R}^D \to \mathbb{R}^D\) be parametrized functions.
    • A transformer layer \(T: \mathbb{R}^{T \times D} \to \mathbb{R}^{T \times D}\) takes input \(X = (x_1^\top, \dots, x_T^\top)\) and outputs \(Z = (z_1^\top, \dots, z_T^\top)\): \[ a_t = \mathrm{Att}(Q(x_t), K(X_t), V(X_t)) + x_t \quad (\text{residual connection}) \] \[ z_t = O(a_t) + a_t \] for \(t = 1, \dots, T\), where \(X_t = (x_1^\top, \dots, x_t^\top)\) is the prefix of the input up to position \(t\) and \(K, V\) are applied row-wise; \(T(X) \overset{\mathrm{def}}{=} Z\).
    • Residual connections (He et al., 2016) aid gradient flow and expressivity.
    • Typically, \(Q, K, V\) are linear; \(O\) is an MLP.
  • Full Transformer (L-Layer)
    • Initial: \(X^1 = (e'(y_0), e'(y_1), \dots, e'(y_t))\), where \(y_0 = \mathrm{BOS}\) and \(e'\) is the position-augmented symbol embedding.
    • For layers \(\ell = 1\) to \(L\): \(Z^\ell = T^\ell(X^\ell)\), \(X^{\ell+1} = Z^\ell\).
    • Final: \(h_t = F(z_t^L)\), where \(F: \mathbb{R}^D \to \mathbb{R}^D\) is a transformation.
  • Attention Block (Matrix Form for Efficiency)
    • For input \(X \in \mathbb{R}^{T \times D}\), with \(f_{\Delta^{D-1}}\) applied to each row: \[ A(X) = f_{\Delta^{D-1}} \left( Q(X) K(X)^\top \right) V(X) \]
    • Attention matrix: \(U = Q(X) K(X)^\top \in \mathbb{R}^{T \times T}\).
    • Self-attention: Same \(X\) for queries, keys, values.
  • Masked Attention Block (for Autoregressive Modeling)
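    • Prevents looking ahead by adding a mask to the scores before the row-wise softmax: \[ A(X, M) = \mathrm{softmax}\left(Q(X) K(X)^\top + M\right) V(X) \] where the masking matrix is \(M_{i,j} = \begin{cases} 0 & \text{if } i \geq j \\ -\infty & \text{otherwise} \end{cases}\); since \(e^{-\infty} = 0\), future positions receive zero attention weight.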
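A minimal pure-Python sketch of a causally masked self-attention block and of stacking transformer layers. Assumptions: identity \(Q, K, V\), no \(1/\sqrt{D}\) scaling (as in the matrix formula for \(A(X)\)), the additive form of the causal mask (adding 0 or \(-\infty\) to each score before the softmax), and a hypothetical ReLU stand-in for the MLP \(O\). The loop checks that row \(t\) of the masked matrix form equals attention over the prefix \(X_t\), which is why all positions can be computed in parallel.

```python
import math

NEG_INF = float("-inf")

def softmax(row):
    m = max(row)
    e = [math.exp(v - m) for v in row]  # math.exp(-inf) == 0.0, so masked entries get weight 0
    s = sum(e)
    return [v / s for v in e]

def matmul(A, B):
    return [[sum(a * b for a, b in zip(r, c)) for c in zip(*B)] for r in A]

def masked_attention(Qx, Kx, Vx):
    """softmax(Q K^T + M) V, with M[i][j] = 0 if i >= j and -inf otherwise (causal mask)."""
    T = len(Qx)
    U = matmul(Qx, [list(c) for c in zip(*Kx)])  # T x T score matrix Q K^T
    masked = [[U[i][j] + (0.0 if i >= j else NEG_INF) for j in range(T)] for i in range(T)]
    return matmul([softmax(r) for r in masked], Vx)

def attend_prefix(q, keys, values):
    """Att(q, K_t, V_t) over the prefix only, as in the per-position definition."""
    w = softmax([sum(a * b for a, b in zip(q, k)) for k in keys])
    return [sum(wi * v[d] for wi, v in zip(w, values)) for d in range(len(values[0]))]

def relu_mlp(a):
    """Hypothetical stand-in for the MLP O: elementwise ReLU."""
    return [max(0.0, v) for v in a]

def transformer_layer(X, O=relu_mlp):
    """a_t = Att(x_t, X_t, X_t) + x_t and z_t = O(a_t) + a_t, with Q, K, V = identity."""
    att = masked_attention(X, X, X)
    Z = []
    for x, y in zip(X, att):
        a = [u + v for u, v in zip(y, x)]           # residual around attention
        Z.append([u + v for u, v in zip(O(a), a)])  # residual around O
    return Z

X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
A = masked_attention(X, X, X)
for t in range(len(X)):
    ref = attend_prefix(X[t], X[: t + 1], X[: t + 1])
    assert all(abs(u - v) < 1e-12 for u, v in zip(A[t], ref))

H = X  # X^1
for _ in range(2):  # L = 2 layers: X^{l+1} = Z^l = T^l(X^l)
    H = transformer_layer(H)
print(H)
```

Changing a later input row leaves earlier rows of `masked_attention` unchanged, which is exactly the "no looking ahead" property.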

Additional Components

  • Positional Encodings
    • Without them, (unmasked) self-attention is permutation-equivariant: permuting the input symbols permutes the outputs in the same way, so the model cannot see word order.
    • Definition: Positional encoding \(f_{\mathrm{pos}}: \mathbb{N} \to \mathbb{R}^D\).
    • Position-Augmented Representation: \(e'_{\mathrm{pos}}(y_t) = e'(y_t) + f_{\mathrm{pos}}(t)\) (or concatenation).
    • Essential for expressivity: without them, transformers cannot recognize even simple languages such as \(L = \{ ab^n \mid n \in \mathbb{N} \}\) (Pérez et al., 2021).
  • Multi-Head Attention
    • Computes multiple representations per symbol.
    • Definition: Multi-Head Attention Block
      • For \(H\) heads, functions \(Q_h, K_h, V_h\), and \(f_H: \mathbb{R}^{T \times H D} \to \mathbb{R}^{T \times D}\): \[ \mathrm{MH\text{-}A}(X) = f_H \left( \mathrm{concat}_{0 < h \leq H} \left( \mathrm{softmax}(Q_h(X) K_h(X)^\top) V_h(X) \right) \right) \]
    • Enriches the representation space and affects expressivity (e.g., it makes simulating n-gram models straightforward).
  • Layer Normalization
    • Empirical trick for stable training (Ba et al., 2016).
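    • Definition: For \(x, \gamma, \beta \in \mathbb{R}^D\) and \(\epsilon > 0\): \[ \mathrm{LN}(x; \gamma, \beta) = \frac{x - \bar{x}}{\sqrt{\sigma^2(x) + \epsilon}} \odot \gamma + \beta, \] where \(\bar{x} = \frac{1}{D}\sum_{d=1}^D x_d\) and \(\sigma^2(x) = \frac{1}{D}\sum_{d=1}^D (x_d - \bar{x})^2\) (often \(\gamma = \mathbf{1}\), \(\beta = \mathbf{0}\)).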
    • Applied to layer outputs: \(z_t = \mathrm{LN}(O(a_t) + a_t; \gamma, \beta)\).
    • Also affects expressivity: it helps overcome limitations identified by Hahn (2020), as shown by Chiang and Cholak (2022).
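The permutation-equivariance point about positional encodings can be checked numerically. The sketch below assumes unmasked self-attention with identity \(Q, K, V\) and a hypothetical two-dimensional encoding \(f_{\mathrm{pos}}(t) = (\sin t, \cos t)\) added to the embeddings (illustrative only, not the encoding of any particular model). Permuting the input merely permutes the outputs; once positions are added, the same reordering changes the outputs themselves.

```python
import math

def softmax(row):
    m = max(row)
    e = [math.exp(v - m) for v in row]
    s = sum(e)
    return [v / s for v in e]

def self_attention(X):
    """Unmasked self-attention with identity Q, K, V: row t is softmax(x_t X^T) X."""
    out = []
    for q in X:
        w = softmax([sum(a * b for a, b in zip(q, k)) for k in X])
        out.append([sum(wi * x[d] for wi, x in zip(w, X)) for d in range(len(q))])
    return out

def f_pos(t):
    """Hypothetical 2-d positional encoding, for illustration only."""
    return [math.sin(t), math.cos(t)]

def add_positions(X):
    """Position-augmented representation e'(y_t) + f_pos(t)."""
    return [[x + p for x, p in zip(row, f_pos(t))] for t, row in enumerate(X)]

def close(A, B, tol=1e-9):
    return all(abs(a - b) < tol for ra, rb in zip(A, B) for a, b in zip(ra, rb))

X = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # embeddings of a 3-symbol string
perm = [2, 0, 1]                           # a reordering of that string
Xp = [X[i] for i in perm]

Y, Yp = self_attention(X), self_attention(Xp)
print(close(Yp, [Y[i] for i in perm]))     # True: outputs are just permuted

Z, Zp = self_attention(add_positions(X)), self_attention(add_positions(Xp))
print(close(Zp, [Z[i] for i in perm]))     # False: positions make order visible
```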
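A sketch of the multi-head attention block with linear \(Q_h, K_h, V_h\) and a linear \(f_H\). All weights are hypothetical, chosen so that head 1 scores with the first coordinate and head 2 with the second; no masking or scaling is applied, matching the \(\mathrm{MH\text{-}A}\) definition. With a single head and identity maps the block reduces to ordinary self-attention.

```python
import math

def softmax(row):
    m = max(row)
    e = [math.exp(v - m) for v in row]
    s = sum(e)
    return [v / s for v in e]

def linear(X, W):
    """Multiply every row of X by the matrix W (row-vector convention, W is D_in x D_out)."""
    return [[sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))] for x in X]

def head(X, Wq, Wk, Wv):
    """One head: softmax(Q_h(X) K_h(X)^T) V_h(X) with linear Q_h, K_h, V_h."""
    Q, K, V = linear(X, Wq), linear(X, Wk), linear(X, Wv)
    out = []
    for q in Q:
        w = softmax([sum(a * b for a, b in zip(q, k)) for k in K])
        out.append([sum(wi * v[d] for wi, v in zip(w, V)) for d in range(len(V[0]))])
    return out

def multi_head(X, heads, W_out):
    """MH-A(X) = f_H(concat_h head_h(X)), with f_H a linear map from H*D to D dimensions."""
    per_head = [head(X, *params) for params in heads]
    concat = [sum(rows, []) for rows in zip(*per_head)]  # row t = head_1(t) ++ ... ++ head_H(t)
    return linear(concat, W_out)

I2 = [[1.0, 0.0], [0.0, 1.0]]
DIM0 = [[1.0, 0.0], [0.0, 0.0]]  # hypothetical: keep only coordinate 0
DIM1 = [[0.0, 0.0], [0.0, 1.0]]  # hypothetical: keep only coordinate 1
heads = [(DIM0, DIM0, I2), (DIM1, DIM1, I2)]
W_out = [[0.5, 0.0], [0.0, 0.5], [0.5, 0.0], [0.0, 0.5]]  # hypothetical f_H: averages the heads

X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
Y = multi_head(X, heads, W_out)
print(len(Y), len(Y[0]))  # 3 2
```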
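Layer normalization as defined above, in a short sketch assuming the population variance (dividing by \(D\)) and \(\epsilon = 10^{-5}\) (a common default, assumed here). The output has mean close to 0 and variance close to 1, and rescaling the input by a positive constant barely changes it.

```python
import math

def layer_norm(x, gamma=None, beta=None, eps=1e-5):
    """LN(x; gamma, beta) = (x - mean) / sqrt(var + eps) * gamma + beta, with population variance."""
    D = len(x)
    gamma = [1.0] * D if gamma is None else gamma
    beta = [0.0] * D if beta is None else beta
    mean = sum(x) / D
    var = sum((v - mean) ** 2 for v in x) / D
    denom = math.sqrt(var + eps)
    return [(v - mean) / denom * g + b for v, g, b in zip(x, gamma, beta)]

def layer_output(a, O):
    """z_t = LN(O(a_t) + a_t) with gamma = 1 and beta = 0."""
    return layer_norm([o + v for o, v in zip(O(a), a)])

x = [1.0, 2.0, 3.0, 10.0]
y = layer_norm(x)
print(sum(y) / len(y))                   # approximately 0
print(sum(v * v for v in y) / len(y))    # slightly below 1 because of eps
print(layer_norm([3.0 * v for v in x]))  # nearly identical to y: LN is scale-invariant
```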

Tightness of Transformer-based Language Models

  • Theorem: Transformer Language Models are Tight
    • Representation-based locally normalized LMs defined by any fixed-depth transformer with soft attention are tight.
  • Proof Outline:
    • Key Lemma (Compactness Theorem): If \(X\) is compact and \(f: X \to Y\) is continuous, then \(f(X)\) is compact.
    • Compactness Lemma: For transformer function \(f_{\mathrm{Att}}\) with continuous \(Q, K, V, O\), and compact input set \(K \subseteq \mathbb{R}^D\), there exists compact \(K' \subseteq \mathbb{R}^D\) such that \(f_{\mathrm{Att}}(K^t) \subseteq (K')^t\) for all \(t > 0\).
      • Proof: Inputs are compact (word + position embeddings bounded). Each layer (attention + feedforward + residuals + norms) is continuous, preserving compactness (by theorem). Induction over 2L blocks yields output compactness.
    • Main Proof:
      • Inputs to first layer are compact (\(K\): finite vocab + bounded positions).
      • By lemma, outputs \(h_t\) in compact \(K'\) regardless of length.
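      • Conditional probabilities come from a softmax, so the EOS probability \(g_{\mathrm{eos}}(h) = \mathrm{softmax}(E h)_{\mathrm{EOS}}\) is a continuous function \(g_{\mathrm{eos}}: K' \to (0,1)\), and its image is a compact set \(K'' \subseteq (0,1)\).
      • Since \(K''\) is compact and contained in \((0,1)\), \(\delta = \inf K'' > 0\); by Proposition 2.5.6, the model is tight.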
  • Context: Tightness ensures the model defines a valid probability distribution over infinite sequences (from Chapter 2). This uses soft attention; hard attention may differ.
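The final step of the proof can be spelled out as a short calculation (a sketch of why a uniform lower bound \(\delta > 0\) on the EOS probability suffices; the formal statement is Proposition 2.5.6). Let \(P_n\) be the probability that no EOS has been generated within the first \(n\) steps, with \(P_0 = 1\):

```latex
% Assumption from the proof: p_SM(EOS | y_{<t}) >= delta > 0 for every context y_{<t}.
\begin{aligned}
P_n &= \sum_{y \in \Sigma^{n}} \prod_{t=1}^{n} p_{\mathrm{SM}}(y_t \mid y_{<t}) \\
    &= \sum_{y \in \Sigma^{n-1}} \Big[ \prod_{t=1}^{n-1} p_{\mathrm{SM}}(y_t \mid y_{<t}) \Big]
       \big( 1 - p_{\mathrm{SM}}(\mathrm{EOS} \mid y) \big) \\
    &\le (1 - \delta)\, P_{n-1} \;\le\; (1 - \delta)^{n} \;\xrightarrow{\; n \to \infty \;}\; 0 .
\end{aligned}
```

The second line uses that \(p_{\mathrm{SM}}(\cdot \mid y)\) is a distribution over \(\bar{\Sigma} = \Sigma \cup \{\mathrm{EOS}\}\), so its mass on \(\Sigma\) is \(1 - p_{\mathrm{SM}}(\mathrm{EOS} \mid y)\). Since \(P_n \to 0\), no probability mass is left on infinite sequences, which is what tightness requires.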

Representational Capacity of Transformer-based Language Models

  • Examining the expressivity of transformers requires careful consideration of the precise model definition and assumptions (e.g., type of attention, precision, positional encodings).
  • Small changes (e.g., soft vs. hard attention, unique vs. averaging hard) can lead to large differences in expressivity: Under some assumptions, transformers are Turing complete (Pérez et al., 2021); under others, they cannot recognize simple regular languages.
  • Theoretical results are more limited than for RNNs, as transformers are a newer architecture.
  • Analysis is complex due to the lack of a sequential hidden state passed between time steps (unlike RNNs or classical automata).
    • No straightforward mapping to established models of computation (e.g., FSAs, PDAs, Turing machines).
    • Transformers can “encode” state into the generated sequence (using it like a memory structure or stack), but this is not a perfect analogy and can be seen as “cheating” by augmenting the alphabet, leading to homomorphism equivalence rather than direct model equivalence.

Equivalence vs. Homomorphism Equivalence

  • Model Equivalence:
    • Two computational models \(C_1\) and \(C_2\) are equivalent if they recognize/generate the same language: \(L(C_1) = L(C_2)\).
  • Homomorphism Equivalence:
    • \(C_1\) is homomorphically equivalent to \(C_2\) if there exists a homomorphism \(h\) (a symbol-wise map extended to strings) such that \(h(L(C_1)) = L(C_2)\).
    • The relation is directional: \(C_1\) and \(C_2\) are called homomorphically equivalent only if it holds in both directions.
    • Allows mapping between languages via a transformation (e.g., augmenting strings with additional information like machine states).
    • In formal language theory, this is a distinct notion from direct equivalence (Culik and Salomaa, 1978).
  • Relevance to Transformers:
    • Transformers often achieve results via homomorphism equivalence by embedding computational state (e.g., Turing machine configuration) into the alphabet/output string.
    • This exploits the model’s ability to “look back” at the entire history but requires modified alphabets, differing from RNNs’ direct simulation (e.g., Turing completeness without augmentation in §5.2).
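A toy illustration of the idea, with everything hypothetical: a two-state DFA tracks the parity of a's, the augmented language \(L_1\) pairs every symbol with the state reached after reading it, and \(h\) erases the state. Membership in \(L_1\) can be checked by looking back a single position (the state is written into the string), yet \(h(L_1)\) recovers the original language \(L_2\), verified here for strings of length at most 4.

```python
from itertools import product

SIGMA = ["a", "b"]
STATES = [0, 1]  # hypothetical DFA: the state is the parity of the number of a's read so far
DELTA = {(q, s): (1 - q if s == "a" else q) for q in STATES for s in SIGMA}
START, ACCEPT = 0, {0}

def in_L2(y):
    """Original language L2: strings over {a, b} with an even number of a's."""
    return y.count("a") % 2 == 0

def in_L1(z):
    """Augmented language L1 over Sigma x Q: each pair (s, q) must carry the state reached
    after reading s, and the final state must accept. Only one position of look-back is needed."""
    prev = START
    for s, q in z:
        if DELTA[(prev, s)] != q:
            return False
        prev = q
    return prev in ACCEPT

def h(z):
    """String homomorphism defined by h((s, q)) = s; it erases the state information."""
    return "".join(s for s, _ in z)

N = 4
AUG = [(s, q) for s in SIGMA for q in STATES]
L1 = {z for n in range(N + 1) for z in product(AUG, repeat=n) if in_L1(z)}
L2 = {"".join(y) for n in range(N + 1) for y in product(SIGMA, repeat=n) if in_L2("".join(y))}

print({h(z) for z in L1} == L2)  # True: h maps L1 onto L2 (for lengths up to N)
```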

Inability of Transformers to Recognize Simple Languages

  • Transformers model human languages, which are mildly context-sensitive (Huybregts et al., 1984; Shieber, 1985), very successfully; yet under certain assumptions they cannot recognize simple regular languages like First or Parity, or context-free languages like Dyck (Hahn, 2020; experimentally verified by Chiang and Cholak, 2022; Bhattamishra et al., 2020).
  • Example Languages: \[ \begin{aligned} \mathrm{First} &= \{ y \in \{0,1\}^* \mid y_1 = 1 \}, \\ \mathrm{Parity} &= \{ y \in \{0,1\}^* \mid y \text{ has an odd number of 1s} \}, \\ \mathrm{Dyck} &= \{ y \in \{\mathtt{(}, \mathtt{)}\}^* \mid y \text{ is correctly parenthesized} \}. \end{aligned} \]
  • With Unique Hard Attention (Hahn, 2020):
    • Cannot recognize these languages unless the number of parameters (layers or heads) grows with the input length.
    • Reason: Unique hard attention limits to selecting one value, preventing effective counting or tracking (e.g., parity of 1s or parenthesis balance).
  • With Soft Attention (Chiang and Cholak, 2022):
    • Theoretically possible, but confidence decreases with length (cross-entropy approaches 1 bit/symbol, worst-case).
    • Intuition: Membership depends on single symbols or counts, but soft attention averages over positions, diluting information as length grows.
  • Broader Issues:
    • Struggles with counting tasks (Bhattamishra et al., 2020).
    • Parallel nature limits sequential processing compared to RNNs (Merrill et al., 2022a; Merrill and Sabharwal, 2023).
    • Some of these limitations are mitigated by components such as layer normalization or by averaging hard attention.
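A minimal sketch of the three example languages as Python predicates, plus an intuition pump for the dilution argument: a hypothetical head whose scores are all equal outputs the mean of its values, so flipping the single symbol that decides membership moves the output by only \(1/n\). This illustrates the intuition above; it is not Hahn's proof.

```python
def in_first(y):
    """First: binary strings whose first symbol is 1."""
    return len(y) > 0 and y[0] == "1"

def in_parity(y):
    """Parity: binary strings with an odd number of 1s."""
    return y.count("1") % 2 == 1

def in_dyck(y):
    """Dyck (one bracket pair): no prefix closes more than it opens, and the totals match."""
    depth = 0
    for c in y:
        depth += 1 if c == "(" else -1
        if depth < 0:
            return False
    return depth == 0

def uniform_head(values):
    """Hypothetical head whose scores are all equal: softmax weights 1/n, output = mean value."""
    return sum(values) / len(values)

# Flipping the first symbol changes membership in First and Parity,
# but moves the uniformly attended output by only 1/n.
for n in [4, 64, 1024]:
    y, y_flip = "0" * n, "1" + "0" * (n - 1)
    shift = uniform_head([int(c) for c in y_flip]) - uniform_head([int(c) for c in y])
    print(n, in_parity(y), in_parity(y_flip), in_first(y_flip), shift)  # shift == 1/n
```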

Transformers Can Simulate n-gram Models

  • Switching to averaging hard attention increases expressivity by allowing summation over tied keys.
  • Theorem: Transformer language models can simulate n-gram language models.
    • For any n-gram locally normalized model \(p_{\mathrm{LN}}\), there exists a transformer \(\mathcal{T}\) with \(L(p_{\mathrm{LN}}) = L(\mathcal{T})\) (i.e., they define the same distribution).
    • Equivalently: Transformers recognize strictly local (subregular) languages (§4.1.5).
  • Proof Outline (Intuitive Approach):
    • Use a single-layer transformer with \(H = n-1\) heads; each head attends to one specific position in the previous \(n-1\) symbols using positional encodings and a scoring function maximized at that position (e.g., via differences in queries/keys).
    • Heads extract one-hot encodings of those symbols; concatenate and apply a linear transformation + non-linearity (e.g., thresholded sigmoid) to encode the full n-gram (solving an “AND” problem).
    • Use the encoded n-gram to look up precomputed conditional probabilities in a matrix \(E\) (like a table), followed by softmax to match the n-gram distribution.
    • Intuition: Heads parallelize history extraction; combination identifies the local pattern; lookup simulates the n-gram conditional.
  • Significance:
    • This is the only “concrete” lower bound here that is established via model equivalence rather than homomorphism equivalence.
    • Relies on multi-head attention and positional encodings; single-head without positions is weaker.
    • Construction is tedious but straightforward; extends to subregular languages.
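A simplified Python sketch of the construction's idea for a trigram model (\(n = 3\), so \(H = 2\) heads), under explicit assumptions: keys are hypothetical positional features \(k_j = (j, j^2)\) and the head targeting position \(j^\star\) uses the query \(q = (2 j^\star, -1)\), so \(\langle q, k_j\rangle = -(j - j^\star)^2 + (j^\star)^2\) is uniquely maximized at \(j^\star\); values are one-hot symbol encodings; and the linear layer with thresholded non-linearity is replaced by decoding the one-hots and indexing a hypothetical probability table (the rows of \(E\) hold log-probabilities). It is not the textbook's construction, only a check that the pieces fit together.

```python
import math
from itertools import product

SYMS = ["a", "b"]
BOS = "BOS"
VOCAB = [BOS] + SYMS  # symbols that can appear in the padded history
n = 3                 # trigram model, so H = n - 1 = 2 heads

# Hypothetical trigram table p(next | (y_{t-2}, y_{t-1})) over Sigma-bar = (a, b, EOS).
TABLE = {
    ctx: ([0.5, 0.3, 0.2] if ctx[-1] == "a" else [0.2, 0.5, 0.3])
    for ctx in product(VOCAB, repeat=n - 1)
}

def one_hot(sym):
    return [1.0 if s == sym else 0.0 for s in VOCAB]

def hard_avg_attention(q, keys, values):
    """Averaging hard attention with dot-product scores."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    best = max(scores)
    top = [j for j, s in enumerate(scores) if s == best]
    return [sum(values[j][d] for j in top) / len(top) for d in range(len(values[0]))]

def transformer_next(prefix):
    padded = [BOS] * (n - 1) + list(prefix)
    keys = [[float(j), float(j * j)] for j in range(len(padded))]  # hypothetical positional keys
    values = [one_hot(s) for s in padded]
    c = len(padded) - 1                      # position holding y_{t-1}
    history = []
    for h in range(n - 1, 0, -1):            # head h retrieves y_{t-h}
        target = c - (h - 1)
        q = [2.0 * target, -1.0]             # <q, k_j> = -(j - target)^2 + target^2
        retrieved = hard_avg_attention(q, keys, values)
        history.append(VOCAB[retrieved.index(max(retrieved))])  # simplifies the "AND" layer
    logits = [math.log(p) for p in TABLE[tuple(history)]]       # E row = log-probabilities
    m = max(logits)
    e = [math.exp(z - m) for z in logits]
    s = sum(e)
    return [v / s for v in e]

def ngram_next(prefix):
    """Reference n-gram conditional p(. | last n-1 symbols), with BOS padding."""
    padded = [BOS] * (n - 1) + list(prefix)
    return TABLE[tuple(padded[-(n - 1):])]

for length in range(5):
    for prefix in product(SYMS, repeat=length):
        got, want = transformer_next(prefix), ngram_next(prefix)
        assert all(abs(g - w) < 1e-12 for g, w in zip(got, want))
print("sketch matches the trigram table on all prefixes up to length 4")
```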

Additional Notes

  • Overall expressivity ladder: the textbook ascends the hierarchy (FSAs, PDAs, Turing machines) using unified constructions (building on Pérez et al., 2021), often via homomorphism equivalence.
  • Assumptions matter: these results assume infinite precision; with log precision, transformers are limited to constant-depth threshold circuits (Merrill and Sabharwal, 2023).

Tokenization: definitions and BPE (Byte Pair Encoding)

  • Definition

    • A tokenizer is a mapping \(\tau:\ \text{Unicode}^* \to V^*\) with an (approximate) inverse \(\tau^{-1}\). Goals: short sequences, coverage, robustness.
  • Why subwords

    • Balance word-level OOV vs. character-level sequence length.
  • BPE learning (frequency-merge)

    1. Base alphabet \(A_0\) (chars or bytes) and corpus \(C_0 \subset A_0^*\).
    2. At step \(t\), compute pair frequencies \(f_t(a,b)\) = count of adjacent \(ab\) in \(C_t\).
    3. Merge \((a^*, b^*) = \arg\max_{a,b} f_t(a,b)\) into new symbol \(c = a^* \circ b^*\); set \(A_{t+1} = A_t \cup \{c\}\) and replace all \(a^*b^*\) with \(c\) in \(C_t\) to obtain \(C_{t+1}\).
    4. Stop when \(|A_t| = |V|\) or after \(M\) merges.
    • Encoding applies the learned merges to new text in the order they were learned (merge priority). Byte-level BPE (base \(A_0\) = all 256 bytes) guarantees coverage and invertibility.
  • Contrasts

    • WordPiece: chooses merges to maximize corpus likelihood, not just frequency.
    • Unigram LM: learns token set \(U\) and picks a segmentation \(s^* = \arg\max_{s \in \mathrm{Seg}(x)} \sum_{u \in s} \log p(u)\).
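A minimal sketch of frequency-based BPE learning and encoding, assuming a character-level base alphabet, merges only within words (no cross-word pairs), and ties broken by first occurrence. Production tokenizers (e.g., byte-level BPE) add pre-tokenization and other details not shown here.

```python
from collections import Counter

def merge_pair(symbols, a, b):
    """Replace each adjacent occurrence of (a, b) by the merged symbol a + b, left to right."""
    out, i = [], 0
    while i < len(symbols):
        if i + 1 < len(symbols) and symbols[i] == a and symbols[i + 1] == b:
            out.append(a + b)
            i += 2
        else:
            out.append(symbols[i])
            i += 1
    return out

def learn_bpe(corpus, num_merges):
    """Frequency-based BPE: repeatedly merge the most frequent adjacent pair (within words)."""
    words = [list(w) for w in corpus]  # C_0: every word as a sequence of base symbols
    merges = []
    for _ in range(num_merges):
        pairs = Counter()
        for w in words:
            pairs.update(zip(w, w[1:]))  # f_t(a, b): adjacent pair counts
        if not pairs:
            break
        (a, b), _ = pairs.most_common(1)[0]  # ties go to the pair seen first
        merges.append((a, b))
        words = [merge_pair(w, a, b) for w in words]
    return merges

def encode(word, merges):
    """Apply the learned merges to new text in the order they were learned."""
    symbols = list(word)
    for a, b in merges:
        symbols = merge_pair(symbols, a, b)
    return symbols

merges = learn_bpe(["low", "lower", "lowest"], num_merges=3)
print(merges)                    # [('l', 'o'), ('lo', 'w'), ('low', 'e')]
print(encode("slower", merges))  # ['s', 'lowe', 'r']
```

Note that `encode` never loses characters: concatenating the output always reproduces the input word.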

Challenges with tokenization

  • Morphology and multilinguality: agglutinative languages fragment words into many tokens → longer \(T\), higher compute; multilingual vocabularies can under-represent non‑Latin scripts.
  • Numbers, code, rare strings: long numbers/URLs/identifiers expand to many tokens; harms numeric and symbolic tasks.
  • Unicode and normalization: inconsistent NFC/NFKC, combining marks, zero‑width chars, whitespace variants → brittle segmentations unless normalized.
  • Domain drift: merges learned on general text segment poorly in biomed/legal/code domains.
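  • Comparability: perplexity computed over tokens is not comparable across tokenizers. Prefer bits per byte (or per character), e.g., \(\mathrm{bpb} = -\frac{1}{N} \sum_{t=1}^{T} \log_2 p(y_t \mid y_{<t})\), where the sum runs over the \(T\) tokens of the text and \(N\) is its length in bytes.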
  • Safety/robustness: adversarial homoglyphs/zero‑widths can bypass rules even with byte coverage.
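A small sketch of why bits per byte is comparable where token-level perplexity is not, assuming two hypothetical tokenizers that assign the same total probability \(1/16\) to the text "café" (5 UTF-8 bytes) but split it into 2 and 4 tokens respectively. Perplexity differs (about 4 vs. about 2), while bpb is identical (0.8).

```python
import math

def bits_per_byte(token_logprobs, text):
    """bpb = (total negative log-likelihood in bits) / (number of UTF-8 bytes of text).
    token_logprobs holds natural-log values log p(y_t | y_<t) for the tokens of text."""
    total_bits = -sum(token_logprobs) / math.log(2)
    return total_bits / len(text.encode("utf-8"))

def token_perplexity(token_logprobs):
    """Per-token perplexity; it depends on how many tokens the tokenizer produced."""
    return math.exp(-sum(token_logprobs) / len(token_logprobs))

text = "caf\u00e9"               # "café": 4 characters, 5 UTF-8 bytes
tok_a = [math.log(0.25)] * 2     # hypothetical tokenizer A: 2 tokens
tok_b = [math.log(0.5)] * 4      # hypothetical tokenizer B: 4 tokens, same total probability 1/16

print(token_perplexity(tok_a), token_perplexity(tok_b))        # about 4 vs. about 2
print(bits_per_byte(tok_a, text), bits_per_byte(tok_b, text))  # both about 0.8
```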

Generation: definitions and sampling adapters

  • Autoregressive LM

    • Given tokens \(y_{1:T}\), the model factorizes as: \[ p_\theta(y_{1:T}) = \prod_{t=1}^T p_\theta\!\big(y_t \mid y_{<t}\big),\quad p_t = \mathrm{softmax}(z_t),\quad p_{t,i} = \frac{e^{z_{t,i}}}{\sum_j e^{z_{t,j}}}. \]
    • At step \(t\), logits \(z_t \in \mathbb{R}^{|V|}\) are turned into a token by search or sampling.
  • Sampling vs. search

    • Search: greedy/beam (deterministic; can be bland/repetitive).
    • Sampling: add controlled randomness with adapters (below).
  • Sampling adapters (formal)

    • A sampling adapter is a function applied before selection: \((z_t', s_{t+1}) = A(z_t, s_t, y_{<t})\); then sample from \(\mathrm{softmax}(z_t')\).
    • It can reshape probabilities, restrict the support, and keep state \(s_t\).
  • Core adapters (with math)

    • Temperature: \(z_t' = z_t / \tau\) with \(\tau > 0\).
    • Top‑k: let \(S_k\) be the indices of the \(k\) largest entries of \(p_t\); define \(p'_i \propto p_{t,i}\,\mathbf{1}[i \in S_k]\).
    • Nucleus (top‑p): choose the smallest \(S \subset V\) with \(\sum_{i \in S} p_{t,i} \ge p\); set \(p'_i \propto p_{t,i}\,\mathbf{1}[i \in S]\).
    • Typical sampling: rank tokens by how close their surprisal \(s_i = -\log p_{t,i}\) is to the entropy \(H(p_t) = -\sum_i p_{t,i}\log p_{t,i}\), keep the closest ones until their total mass reaches a threshold, then renormalize.
    • Repetition penalties: with counts \(c_i\), \(p'_i \propto p_{t,i}/r^{c_i}\) (frequency) or \(p'_i \propto p_{t,i}/r^{\mathbf{1}[c_i>0]}\) (presence), \(r \ge 1\).
    • No‑repeat \(n\)-gram: set \(p'_i = 0\) if choosing \(i\) would create a previously seen \(n\)-gram.
    • Grammar/regex constraints: restrict to tokens that keep a parser in a valid state; others get probability \(0\).
    • Mirostat (idea): adapt \(\tau\) online toward a target average surprisal \(\mu\), e.g., \(\tau \leftarrow \tau \exp(-\eta(\hat{\mu}_t - \mu))\) with \(\hat{\mu}_t = (1-\alpha)\hat{\mu}_{t-1} + \alpha(-\log p_t(y_t))\), so that \(\tau\) shrinks when recent surprisal exceeds the target.
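The core adapters can be sketched over a plain list of probabilities. This is an illustrative sketch, not any library's API; the logits are hypothetical, and typical sampling is made concrete (an assumption) by ranking tokens by \(|s_i - H(p_t)|\) and keeping them until their mass reaches a threshold.

```python
import math
import random

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def renormalize(p):
    s = sum(p)
    return [pi / s for pi in p]

def temperature(z, tau):
    """Temperature: z' = z / tau, tau > 0 (tau < 1 sharpens, tau > 1 flattens)."""
    return [v / tau for v in z]

def keep_until_mass(p, order, thresh):
    """Keep tokens in the given order until their total mass reaches thresh; renormalize."""
    keep, mass = set(), 0.0
    for i in order:
        keep.add(i)
        mass += p[i]
        if mass >= thresh:
            break
    return renormalize([pi if i in keep else 0.0 for i, pi in enumerate(p)])

def by_prob(p):
    """Token indices sorted from most to least probable."""
    return sorted(range(len(p)), key=lambda i: p[i], reverse=True)

def top_k(p, k):
    """Top-k: keep the k most probable tokens."""
    keep = set(by_prob(p)[:k])
    return renormalize([pi if i in keep else 0.0 for i, pi in enumerate(p)])

def top_p(p, thresh):
    """Nucleus: smallest set of most probable tokens whose mass is >= thresh."""
    return keep_until_mass(p, by_prob(p), thresh)

def typical(p, thresh):
    """Typical: rank tokens by |surprisal - entropy|, keep until mass >= thresh."""
    H = -sum(pi * math.log(pi) for pi in p if pi > 0)
    support = [i for i in range(len(p)) if p[i] > 0]
    order = sorted(support, key=lambda i: abs(-math.log(p[i]) - H))
    return keep_until_mass(p, order, thresh)

def sample(p, rng):
    """Draw one token index from an (adapted) distribution."""
    return rng.choices(range(len(p)), weights=p)[0]

z = [2.0, 1.0, 0.5, 0.0]  # hypothetical logits for a 4-token vocabulary
p = softmax(temperature(z, 0.7))
print(top_k(p, 2))
print(top_p(p, 0.9))
print(typical(p, 0.5))
print(sample(top_p(p, 0.9), random.Random(0)))
```

With \(p = (0.5, 0.25, 0.15, 0.1)\) and a small threshold such as 0.2, `typical` keeps only the second token and drops the mode; this is the intended behaviour of typical sampling rather than a bug.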

What is a sampling adapter? Why do we need one?

  • What

    • A modular transformation on \(z_t\) (and on the admissible token set) before sampling: it implements \(z_t \mapsto z_t'\) and optional state/constraints. Also called “logits processor/warper,” “logprob modifier,” or decoding constraint.
  • Why

    • Quality: mitigate repetition/degeneracy; balance coherence vs. diversity.
    • Control without retraining: impose style, safety, domain rules, and formats (e.g., valid JSON/SQL) at inference time.
    • Stability: maintain a target entropy/creativity across prompts/lengths.
    • Policy/safety: banlists, toxicity filters, task‑focused vocab biases.
    • Efficiency: fewer invalid outputs → less post‑processing.
  • Pragmatic defaults

    • Creative: top‑p \(\approx 0.9\), \(\tau \in [0.7, 1.0]\), repetition penalty \(r \in [1.05, 1.2]\).
    • Factual QA: lower \(\tau\) (e.g., \(0.2\text{–}0.7\)), top‑p \(0.8\text{–}0.95\), stop sequences and stronger repetition control.
    • Structured: grammar/regex constraints + low \(\tau\); optionally typical sampling for fluency.
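The adapter interface \((z_t', s_{t+1}) = A(z_t, s_t, y_{<t})\) composes naturally into a chain. The sketch below, with a hypothetical fixed-logit toy model, chains a frequency repetition penalty (\(z'_i = z_i - c_i \log r\), equivalent to \(p'_i \propto p_{t,i}/r^{c_i}\)) with a hypothetical mirostat-style temperature adapter whose state is \((\tau, \hat{\mu}, \text{last distribution})\); it lowers \(\tau\) when recent surprisal exceeds the target and raises it otherwise. It is a simplified illustration, not the Mirostat algorithm itself.

```python
import math
import random
from collections import Counter

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def repetition_penalty(r):
    """Frequency penalty p'_i proportional to p_i / r**c_i, written on logits as z'_i = z_i - c_i log r.
    Counts are recomputed from the history, so this adapter's state is unused."""
    def adapter(z, state, history):
        counts = Counter(history)
        return [zi - counts[i] * math.log(r) for i, zi in enumerate(z)], state
    return adapter

def adaptive_temperature(target, eta=0.5, alpha=0.5):
    """Hypothetical mirostat-style adapter with state (tau, mu_hat, last_probs).
    Place it last in the chain so last_probs is the distribution that was sampled."""
    def adapter(z, state, history):
        tau, mu_hat, last_probs = state
        if history and last_probs is not None:
            surprisal = -math.log(last_probs[history[-1]])
            mu_hat = (1 - alpha) * mu_hat + alpha * surprisal
            tau *= math.exp(-eta * (mu_hat - target))  # too surprising -> lower tau
        z_new = [zi / tau for zi in z]
        return z_new, (tau, mu_hat, softmax(z_new))
    return adapter

def apply_chain(adapters, states, z, history):
    """Apply (z', s_{t+1}) = A(z_t, s_t, y_<t) for each adapter in order."""
    new_states = []
    for adapter, state in zip(adapters, states):
        z, state = adapter(z, state, history)
        new_states.append(state)
    return z, new_states

def generate(logits_fn, adapters, states, steps, seed=0):
    """Sample `steps` tokens, passing the adapted logits through softmax each time."""
    rng = random.Random(seed)
    history = []
    for _ in range(steps):
        z, states = apply_chain(adapters, states, logits_fn(history), history)
        p = softmax(z)
        history.append(rng.choices(range(len(p)), weights=p)[0])
    return history, states

def toy_model(history):
    """Hypothetical model with fixed logits over a 4-token vocabulary."""
    return [2.0, 1.0, 0.5, 0.0]

adapters = [repetition_penalty(1.2), adaptive_temperature(target=1.0)]
states = [None, (1.0, 1.0, None)]  # initial tau = 1 and mu_hat = target
history, states = generate(toy_model, adapters, states, steps=20)
print(history, "final tau:", round(states[1][0], 3))
```

Keeping each adapter a small function of \((z_t, s_t, y_{<t})\) is what lets style, safety, and format controls be swapped in at inference time without retraining.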