Week 6

Published

Wednesday, March 26, 2025

Autoregressive Models

Autoregressive (AR) models factorize a high-dimensional joint distribution into an ordered product of conditionals, enabling exact likelihood computation and principled density estimation.

General Principles

  • Factorization
    \[p(\mathbf{x}) = \prod_{i=1}^{D} p\bigl(x_i \mid x_{<i}\bigr),\quad x_{<i}=(x_1,\dots,x_{i-1}).\]
  • Naive Parameterization
    • A full tabular representation of the conditionals \(p(x_i \mid x_{<i})\) requires \(O(K^D)\) parameters in total for discrete variables with alphabet size \(K\), since the table for the last conditional alone has \(K^{D-1}(K-1)\) entries.
    • Alternatively, one could learn \(D\) independent classifiers—one per position—each with its own parameters, giving \(O(D\cdot P)\) parameters for classifier size \(P\).
  • Tractable Likelihood
    • Exact computation of \(p(\mathbf{x})\) permits maximum likelihood training and direct model comparison.
  • Data Modalities
    • Sequences (e.g. text): natural ordering
      \[p(\text{token}_t \mid \text{tokens}_{<t}).\]
    • Images: impose scan order (raster, zigzag, diagonal); choice of ordering affects receptive field and modeling capacity.
  • Pros & Cons
    • Exact density estimation → anomaly detection, likelihood-based evaluation
    • Sequential generation requires \(D\) steps, which becomes slow when the dimensionality \(D\) is large (e.g. images, audio)
    • Ordering choice for images is non-trivial
  • Comparison to Independence
    • Independent model: \(p(\mathbf{x})=\prod_i p(x_i)\) ignores structure but is fast.
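
The factorization above can be made concrete with a toy example. The following is a minimal sketch (assuming NumPy and an arbitrary, made-up conditional `p_cond`, not anything from the lecture) that evaluates an exact log-likelihood by summing one conditional term per dimension and checks that the resulting distribution normalizes.

```python
# Toy autoregressive factorization for D = 3 binary variables (illustrative only).
import itertools
import numpy as np

def p_cond(i, x_prev):
    """Toy conditional p(x_i = 1 | x_{<i}); any function of the prefix would do."""
    return 1.0 / (1.0 + np.exp(-(0.3 * i - np.sum(x_prev))))

def log_prob(x):
    """log p(x) = sum_i log p(x_i | x_{<i}) -- exact, one term per dimension."""
    lp = 0.0
    for i, xi in enumerate(x):
        p1 = p_cond(i, x[:i])
        lp += np.log(p1 if xi == 1 else 1.0 - p1)
    return lp

# The chain-rule factorization defines a valid distribution: probabilities sum to 1.
total = sum(np.exp(log_prob(np.array(x))) for x in itertools.product([0, 1], repeat=3))
print(total)  # ~1.0
```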

Early AR Models for Binary Data

Fully Visible Sigmoid Belief Net (FVSBN)

  • Models \(\mathbf{x}\in\{0,1\}^D\) with fixed ordering.
  • Conditional logistic regression: \[ p(x_i=1 \mid x_{<i}) = \sigma\!\Bigl(\alpha_i + \sum_{j<i}W_{ij}\,x_j\Bigr). \]
  • Parameters: biases \(\{\alpha_i\}_{i=1}^D\) and weights \(W\in\mathbb{R}^{D\times D}\) (strictly lower-triangular, since only entries with \(j<i\) enter the sum).
  • Complexity: \(O(D^2)\) parameters, simple but parameter-inefficient.
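
As a concrete illustration, here is a minimal FVSBN sketch (an assumed PyTorch implementation, not the original model code): a strictly lower-triangular mask on \(W\) ensures that the logit for \(x_i\) only sees earlier dimensions.

```python
import torch
import torch.nn as nn

class FVSBN(nn.Module):
    def __init__(self, D):
        super().__init__()
        self.alpha = nn.Parameter(torch.zeros(D))             # per-dimension biases
        self.W = nn.Parameter(torch.zeros(D, D))              # full matrix, masked below
        self.register_buffer("mask", torch.tril(torch.ones(D, D), diagonal=-1))

    def log_prob(self, x):                                    # x: (batch, D) in {0,1}
        logits = self.alpha + x @ (self.W * self.mask).T      # logit_i = alpha_i + sum_{j<i} W_ij x_j
        return -nn.functional.binary_cross_entropy_with_logits(
            logits, x, reduction="none").sum(dim=1)           # sum_i log p(x_i | x_{<i})

model = FVSBN(D=784)
x = torch.randint(0, 2, (16, 784)).float()
print(model.log_prob(x).shape)  # torch.Size([16])
```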

Neural Autoregressive Density Estimator (NADE)

  • Shares hidden-layer parameters across all conditionals to reduce complexity from \(O(D^2)\) to \(O(Dd)\), where \(d\) is hidden size.
  • Model
    1. Hidden activations \[ a_1 = c,\quad a_{i+1} = a_i + W_{:,\,i}\,x_i,\quad h_i = \sigma(a_i), \] where \(W\in\mathbb{R}^{d\times D}\) and \(c\in\mathbb{R}^d\) are shared.
    2. Conditional probability \[ p(x_i=1 \mid x_{<i}) = \hat{x}_i = \sigma\bigl(b_i + V_{i,:}\,h_i\bigr), \] with \(V\in\mathbb{R}^{D\times d}\) and \(b\in\mathbb{R}^D\).
  • Properties
    • Parameter count: \(O(Dd)\) vs. \(O(D^2)\) in FVSBN
    • Computation: joint log-likelihood \(\sum_i\log p(x_i\mid x_{<i})\) in \(O(Dd)\) via one forward pass
  • Training
    • Maximize average log-likelihood using teacher forcing (use ground-truth \(x_{<i}\)).
    • Extensions: RNADE (real-valued inputs), DeepNADE (deep MLP), ConvNADE.
  • Inference
    • Sequential (ancestral) sampling of \(x_1,\dots,x_D\); order-agnostic variants allow random orderings of the dimensions.
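
The incremental pre-activation update is what makes the \(O(Dd)\) cost possible; the sketch below (assumed PyTorch code, not the lecture's reference implementation) computes all \(D\) conditionals in a single forward pass.

```python
import torch
import torch.nn as nn

class NADE(nn.Module):
    def __init__(self, D, d):
        super().__init__()
        self.W = nn.Parameter(0.01 * torch.randn(d, D))    # shared input->hidden weights
        self.c = nn.Parameter(torch.zeros(d))              # shared hidden bias
        self.V = nn.Parameter(0.01 * torch.randn(D, d))    # hidden->output weights
        self.b = nn.Parameter(torch.zeros(D))              # output biases

    def log_prob(self, x):                                 # x: (batch, D) in {0,1}
        batch, D = x.shape
        a = self.c.expand(batch, -1).clone()               # a_1 = c
        lp = 0.0
        for i in range(D):
            h = torch.sigmoid(a)                           # h_i = sigma(a_i), sees only x_{<i}
            logit = self.b[i] + h @ self.V[i]              # logit for p(x_i = 1 | x_{<i})
            lp = lp - nn.functional.binary_cross_entropy_with_logits(
                logit, x[:, i], reduction="none")
            a = a + self.W[:, i] * x[:, i:i + 1]           # a_{i+1} = a_i + W[:, i] * x_i
        return lp                                          # (batch,) log-likelihoods

model = NADE(D=784, d=128)
x = torch.randint(0, 2, (8, 784)).float()
print(model.log_prob(x).shape)  # torch.Size([8])
```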

MADE (Masked Autoencoder for Distribution Estimation)

  • Adapts a feed-forward autoencoder to AR modeling by masking its weights so that each output \(\hat{x}_i\) depends only on \(x_{<i}\).
  • Masking
    • Assign each input and output unit its ordering index \(m\in\{1,\dots,D\}\); each hidden unit receives a degree \(m\in\{1,\dots,D-1\}\).
    • A weight from unit \(u\) to unit \(v\) in a hidden layer is kept only if \(m(u)\le m(v)\); a weight into the output for \(x_i\) is kept only if \(m(u)<i\).
  • Training & Generation
    • Train by minimizing negative log-likelihood.
    • Sequential sampling: \(D\) forward passes, one per new \(x_i\).
  • Ordering Robustness
    • Randomizing masks/orders per epoch improves generalization.
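
The masking rule can be written down directly. Below is a sketch of MADE-style mask construction (assumed code following the rule above; the function name `made_masks` is illustrative): hidden connections require \(m(v)\ge m(u)\), output connections require \(m(u)<i\).

```python
import numpy as np

def made_masks(D, hidden_sizes, rng=np.random.default_rng(0)):
    degrees = [np.arange(1, D + 1)]                       # input degrees: the ordering 1..D
    for h in hidden_sizes:
        degrees.append(rng.integers(1, D, size=h))        # hidden degrees in {1, ..., D-1}
    masks = []
    for d_in, d_out in zip(degrees[:-1], degrees[1:]):    # input->hidden and hidden->hidden masks
        masks.append((d_out[:, None] >= d_in[None, :]).astype(np.float32))
    out_deg = np.arange(1, D + 1)                         # output unit i predicts x_i
    masks.append((out_deg[:, None] > degrees[-1][None, :]).astype(np.float32))
    return masks                                          # one (fan_out, fan_in) mask per layer

masks = made_masks(D=4, hidden_sizes=[8, 8])
print([m.shape for m in masks])   # [(8, 4), (8, 8), (4, 8)]
```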

AR Models for Images

PixelRNN

  • Uses RNNs (LSTM/GRU) over flattened pixels.
  • Dependency: Each pixel \(x_i\) conditioned on RNN hidden state summarizing \(x_{<i}\).
  • Ordering: Raster or diagonal scan; variants include Row LSTM, Diagonal BiLSTM.
  • Generation: \(O(HW)\) sequential steps for \(H\times W\) image.
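
A minimal sketch of this idea (assumed code; `TinyPixelRNN` is an illustrative toy, not the paper's Row LSTM or Diagonal BiLSTM) runs an LSTM over flattened pixels so that the hidden state at step \(t\) summarizes \(x_{<t}\).

```python
import torch
import torch.nn as nn

class TinyPixelRNN(nn.Module):
    def __init__(self, hidden=128, levels=256):
        super().__init__()
        self.embed = nn.Embedding(levels, hidden)
        self.rnn = nn.LSTM(hidden, hidden, batch_first=True)
        self.head = nn.Linear(hidden, levels)

    def forward(self, x):                         # x: (batch, H*W) integer pixel values
        # Shift right so position t only sees pixels < t (teacher forcing).
        inp = torch.cat([torch.zeros_like(x[:, :1]), x[:, :-1]], dim=1)
        h, _ = self.rnn(self.embed(inp))
        return self.head(h)                       # (batch, H*W, 256) logits over intensities

model = TinyPixelRNN()
x = torch.randint(0, 256, (2, 28 * 28))
logits = model(x)
loss = nn.functional.cross_entropy(logits.reshape(-1, 256), x.reshape(-1))
print(loss.item())
```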

PixelCNN

  • Uses masked convolutions to model \(p(x_{r,c}\mid x_{\prec(r,c)})\).
  • Mask Types
    1. Type A (first layer): blocks current pixel.
    2. Type B (subsequent): allows self-connection but no future pixels.
  • Training
    • Parallel over all pixels via masked convolution stacks.
  • Generation
    • Sequential sampling \(O(HW)\) steps.
  • Blind-Spot Issue
    • Stacked masked convolutions leave a “blind spot”: some previously generated pixels (above and to the right of the current pixel) never enter the receptive field. Gated PixelCNN’s vertical/horizontal stacks or two-pass designs address this.
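
The two mask types can be implemented by zeroing convolution weights before each forward pass. The following is a sketch (assumed PyTorch code, not the official implementation) of a masked 2D convolution supporting both types.

```python
import torch
import torch.nn as nn

class MaskedConv2d(nn.Conv2d):
    def __init__(self, mask_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert mask_type in ("A", "B")
        kH, kW = self.kernel_size
        mask = torch.ones(kH, kW)
        mask[kH // 2, kW // 2 + (mask_type == "B"):] = 0   # type A also blocks the centre tap
        mask[kH // 2 + 1:, :] = 0                          # block all rows below the centre
        self.register_buffer("mask", mask[None, None])

    def forward(self, x):
        return nn.functional.conv2d(x, self.weight * self.mask, self.bias,
                                    self.stride, self.padding, self.dilation, self.groups)

layer = MaskedConv2d("A", in_channels=1, out_channels=16, kernel_size=7, padding=3)
print(layer(torch.randn(2, 1, 28, 28)).shape)   # torch.Size([2, 16, 28, 28])
```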

TCNs (Temporal Convolutional Networks) & WaveNet

  • TCN: 1D causal convolutions ensure \(y_t\) depends only on \(x_{\le t}\).
  • Dilated Convolutions
    • Dilation factor \(d\): filter taps are spaced \(d\) samples apart, skipping \(d-1\) inputs between taps.
    • Stack dilations \(1,2,4,\dots\) for exponential receptive field growth.
  • WaveNet
    • Causal (masked), dilated convolutions applied directly to raw audio waveforms.
    • Models long-range audio dependencies efficiently.
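
A causal, dilated 1D convolution only needs left-padding by \((k-1)\cdot d\); the sketch below (assumed code, not the WaveNet implementation) stacks dilations \(1,2,4,\dots\) so the receptive field grows exponentially with depth.

```python
import torch
import torch.nn as nn

class CausalConv1d(nn.Module):
    def __init__(self, channels, kernel_size=2, dilation=1):
        super().__init__()
        self.pad = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(channels, channels, kernel_size, dilation=dilation)

    def forward(self, x):                                        # x: (batch, channels, T)
        return self.conv(nn.functional.pad(x, (self.pad, 0)))    # pad on the left only: y_t sees x_{<=t}

# Dilations 1, 2, 4, ..., 32 with kernel size 2 give a receptive field of 64 steps.
stack = nn.Sequential(*[CausalConv1d(16, dilation=2 ** i) for i in range(6)])
print(stack(torch.randn(1, 16, 1000)).shape)                     # torch.Size([1, 16, 1000])
```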

Variational RNNs (VRNNs) & C-VRNNs

  • VRNN
    • Embeds a VAE at each timestep \(t\).
    • Prior: \(p(z_t\mid h_t)=\mathcal{N}(\mu_{0,t},\mathrm{diag}(\sigma_{0,t}^2))\) where \((\mu_{0,t},\sigma_{0,t})=\phi_{\text{prior}}(h_t)\).
  • C-VRNN
    • Conditions on external variables (e.g. control signals) alongside latent \(z_t\).
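
As a small illustration of the per-timestep prior, the sketch below (assumed code; `VRNNPrior` is an illustrative name) maps the RNN state \(h_t\) to the parameters \((\mu_{0,t},\sigma_{0,t})\) of a diagonal Gaussian.

```python
import torch
import torch.nn as nn

class VRNNPrior(nn.Module):
    def __init__(self, h_dim, z_dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(h_dim, h_dim), nn.ReLU())
        self.mu = nn.Linear(h_dim, z_dim)
        self.log_sigma = nn.Linear(h_dim, z_dim)   # predict log sigma so sigma stays positive

    def forward(self, h_t):
        feat = self.net(h_t)
        return self.mu(feat), self.log_sigma(feat).exp()

prior = VRNNPrior(h_dim=64, z_dim=16)
mu0, sigma0 = prior(torch.randn(8, 64))
z_t = mu0 + sigma0 * torch.randn_like(sigma0)      # sample z_t ~ p(z_t | h_t)
print(z_t.shape)                                   # torch.Size([8, 16])
```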

Architecture Trade-offs for \(p(x_i\mid x_{<i})\)

  • RNNs
    • Compress arbitrary history.
    • Recency bias; vanishing distant signals.
  • CNNs (Masked/Dilated)
    • Parallelizable; scalable receptive field.
    • Fixed context; blind spots.
  • Transformers
    • Masked (causal) self-attention over all \(x_{<i}\); see the sketch after this list.
    • Highly parallel training.
    • Quadratic \(O(L^2)\) attention cost; need positional encoding.
  • Large Language Models (LLMs)
    • Transformer-based AR on token sequences.
    • Pretraining: next-token prediction.
    • Fine-Tuning: supervised tasks (SFT).
    • Tokenization: subword units (BPE, WordPiece).
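
The causal masking that makes a Transformer autoregressive, referenced in the trade-offs above, fits in a few lines. The sketch below is an assumed single-head implementation without learned modules (projection matrices are passed in explicitly); position \(i\) attends only to positions \(\le i\).

```python
import torch
import torch.nn.functional as F

def causal_self_attention(x, Wq, Wk, Wv):
    """x: (batch, L, d). Wq, Wk, Wv: (d, d) projection matrices."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.transpose(-2, -1) / (x.shape[-1] ** 0.5)       # (batch, L, L): source of the O(L^2) cost
    L = x.shape[1]
    mask = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))              # block attention to future positions
    return F.softmax(scores, dim=-1) @ v                          # (batch, L, d)

d = 32
x = torch.randn(2, 10, d)
out = causal_self_attention(x, torch.randn(d, d), torch.randn(d, d), torch.randn(d, d))
print(out.shape)   # torch.Size([2, 10, 32])
```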

AR for High-Resolution Images & Video

  • Challenge: Millions of pixels/voxels → \(O(D)\) sequential steps infeasible.
  • Approach: Compress data into discrete latent sequences via VQ-VAE.

Vector-Quantized VAE (VQ-VAE)

  • Objective: Learn a discrete latent representation for high-dimensional data, enabling compact encoding and efficient autoregressive modeling.

  • Model Components:

    • Encoder:
      • A convolutional neural network (CNN) mapping input
        \(\mathbf{x}\in\mathbb{R}^{H\times W\times C}\) to continuous latents
        \(\mathbf{z}_e(\mathbf{x})\in\mathbb{R}^{H'\times W'\times D}\).
      • Spatial down-sampling factor \(s\): \(H'=H/s,\;W'=W/s\) (e.g. \(s=8\)).
    • Codebook:
      • A collection \(\mathcal{E}=\{e_j\in\mathbb{R}^D\mid j=1,\dots,K\}\) of \(K\) trainable embedding vectors.
    • Quantization:
      • For each spatial cell \((p,q)\), take
        \(z_e(\mathbf{x})_{p,q}\in\mathbb{R}^D\) and find nearest codebook index \[ k_{p,q} = \arg\min_{j\in\{1..K\}} \bigl\|z_e(\mathbf{x})_{p,q} - e_j\bigr\|_2^2. \]
      • Replace \(z_e(\mathbf{x})_{p,q}\) with \(e_{k_{p,q}}\), yielding discrete latents
        \(\mathbf{z}_q(\mathbf{x})\in\mathbb{R}^{H'\times W'\times D}\).
    • Decoder:
      • A CNN mapping \(\mathbf{z}_q(\mathbf{x})\) back to a reconstruction
        \(\hat{\mathbf{x}}\in\mathbb{R}^{H\times W\times C}\).
  • Loss Terms:

    1. Reconstruction Loss: \[ L_{\mathrm{rec}} = \bigl\|\mathbf{x} - \hat{\mathbf{x}}\bigr\|_2^2. \]
    2. Codebook Loss: (fix encoder, update codebook) \[ L_{\mathrm{cb}} = \sum_{p,q} \bigl\|\mathrm{sg}\bigl[z_e(\mathbf{x})_{p,q}\bigr] - e_{k_{p,q}}\bigr\|_2^2. \]
    3. Commitment Loss: (fix codebook, update encoder) \[ L_{\mathrm{com}} = \beta \sum_{p,q} \bigl\|z_e(\mathbf{x})_{p,q} - \mathrm{sg}\bigl[e_{k_{p,q}}\bigr]\bigr\|_2^2, \] where \(\beta\) (e.g. \(0.25\)) balances encoder commitment, and \(\mathrm{sg}[\cdot]\) is stop‐gradient.
    • Total Loss: \[ L_{\mathrm{VQ\text{-}VAE}} = L_{\mathrm{rec}} + L_{\mathrm{cb}} + L_{\mathrm{com}}. \]
  • Optimization Details:

    • Straight-Through Estimator (STE):
      • Quantization is non-differentiable.
      • In backprop, copy gradients from \(\mathbf{z}_q\) directly to \(\mathbf{z}_e\) (i.e. treat quantization as the identity); see the sketch at the end of this section.
    • Codebook Updates via EMA (an alternative to \(L_{\mathrm{cb}}\)):
      • Maintain per‐embedding counts \(N_j\) and sums \(m_j\): \[ N_j \gets \gamma\,N_j + (1-\gamma)\sum_{p,q}\mathbf{1}[k_{p,q}=j], \] \[ m_j \gets \gamma\,m_j + (1-\gamma)\sum_{p,q:k_{p,q}=j} z_e(\mathbf{x})_{p,q}, \] then update \[ e_j \gets \frac{m_j}{N_j}, \] with decay \(\gamma\approx0.99\).
  • Practical Considerations:

    • Hyperparameters:
      • Codebook size \(K\) (e.g. 512–1024), embedding dimension \(D\) (e.g. 64–256).
      • Compression factor \(s\) reduces sequence length from \(HW\) to \(H'W'\).
    • Downstream Use:
      • After training, the discrete codes \(\{k_{p,q}\}\) form a sequence of length \(H'W'\).
      • An autoregressive model (e.g. a Transformer) can then be trained on these code sequences:
        \[p\bigl(k_{p,q}\mid k_{\prec(p,q)}\bigr),\]
        where \(\prec\) is the chosen scan order over latent positions.
    • Intuition:
      • Each \(e_j\) becomes a “prototype” latent vector.
      • The encoder emits a continuous vector, which is snapped to its nearest prototype, enforcing discreteness and preventing posterior collapse.
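
The quantization step, straight-through gradient, and the codebook/commitment losses described above fit in a short module. The sketch below is an assumed PyTorch implementation (not the original VQ-VAE code); it averages the losses over cells, whereas the formulas above use sums, which differs only by a constant scale.

```python
import torch
import torch.nn as nn

class VectorQuantizer(nn.Module):
    def __init__(self, K=512, D=64, beta=0.25):
        super().__init__()
        self.codebook = nn.Embedding(K, D)
        self.beta = beta

    def forward(self, z_e):                                    # z_e: (batch, H', W', D)
        flat = z_e.reshape(-1, z_e.shape[-1])                  # (N, D) continuous latents
        dist = torch.cdist(flat, self.codebook.weight)         # Euclidean distances (argmin matches squared)
        k = dist.argmin(dim=1)                                 # nearest codebook index per cell
        z_q = self.codebook(k).view_as(z_e)

        codebook_loss = ((z_e.detach() - z_q) ** 2).mean()            # L_cb: gradients reach codebook only
        commit_loss = self.beta * ((z_e - z_q.detach()) ** 2).mean()  # L_com: gradients reach encoder only
        z_q = z_e + (z_q - z_e).detach()                       # straight-through: identity gradient to z_e
        return z_q, k.view(z_e.shape[:-1]), codebook_loss + commit_loss

vq = VectorQuantizer()
z_e = torch.randn(2, 4, 4, 64, requires_grad=True)
z_q, codes, vq_loss = vq(z_e)
print(z_q.shape, codes.shape, vq_loss.item())   # reconstruction loss is added separately
```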