Foundations · 12 Stages

Deep Learning Systems.

From perceptron to production — backprop, optimizers, CNNs, Transformers, LoRA, quantization, and MLOps, built for engineers who ship models, not just train them.

PyTorch · JAX · HuggingFace · PEFT · bitsandbytes · ONNX · TensorRT · WandB · torchvision
Neural Net Foundations 01–02
01

Perceptron → MLP, Forward Pass & Activation Functions

You hire a new junior engineer and ask them to debug a ResNet that trains "but gives terrible accuracy." The first thing to check: are activations dead? Are gradients flowing? Understanding how a network actually computes — neuron by neuron, layer by layer — is the prerequisite for every debugging conversation you will ever have.

Perceptron → MLP & the Forward Pass
Neuron: z = wᵀx + b; MLP: stacked layers of linear + activation; forward pass left-to-right; Universal Approximation Theorem

A single neuron (perceptron) computes a weighted sum of its inputs plus a bias, then applies a nonlinearity:

  z = w₁x₁ + w₂x₂ + ... + wₙxₙ + b = wᵀx + b
  output = σ(z), where σ is an activation function

A Multi-Layer Perceptron (MLP) stacks layers of neurons:

  Layer 1: a⁽¹⁾ = σ(W⁽¹⁾x + b⁽¹⁾)
  Layer 2: a⁽²⁾ = σ(W⁽²⁾a⁽¹⁾ + b⁽²⁾)
  Output:  ŷ = W⁽³⁾a⁽²⁾ + b⁽³⁾

The forward pass is purely left-to-right: compute each layer's pre-activation z, apply σ, pass to the next layer. No feedback, no cycles (in a basic MLP).

Universal Approximation Theorem: A single hidden layer with enough neurons can approximate any continuous function on a compact subset of ℝⁿ to arbitrary precision. This gives MLPs their theoretical power, but says nothing about how to train them or how many neurons are "enough." In practice, depth is usually more parameter-efficient than width.

  Input         Hidden Layer 1        Hidden Layer 2       Output
  ─────         ──────────────        ──────────────       ──────
   x₁ ──────►  ┌────────────┐ ─────► ┌────────────┐ ────►
               │ σ(W⁽¹⁾x+b) │        │ σ(W⁽²⁾a+b) │
   x₂ ──────►  │  h₁  h₂   │ ─────► │  h₁  h₂   │ ────►  ŷ
               │  h₃  h₄   │        │  h₃  h₄   │
   x₃ ──────►  └────────────┘ ─────► └────────────┘ ────►

  z = Wx + b  →  a = σ(z)  →  z = Wa + b  →  a = σ(z)  →  ŷ
PyTorch — build and run an MLP forward pass, inspect layer shapes and activations
import torch
import torch.nn as nn
import torch.nn.functional as F

# ── Define a 3-layer MLP ──────────────────────────────────
class MLP(nn.Module):
    def __init__(self, in_features, hidden, out_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, out_features)

    def forward(self, x):
        x = F.relu(self.fc1(x))   # hidden layer 1
        x = F.relu(self.fc2(x))   # hidden layer 2
        return self.fc3(x)         # output logits (no activation)

model = MLP(in_features=10, hidden=64, out_features=1)
n_params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {n_params:,}")   # 10*64 + 64 + 64*64 + 64 + 64*1 + 1 = 4,929

# ── Run a forward pass ────────────────────────────────────
batch = torch.randn(32, 10)   # 32 samples, 10 features
output = model(batch)
print(f"Input:  {batch.shape}   Output: {output.shape}")   # [32, 1]

# ── Inspect intermediate activations ─────────────────────
x = torch.randn(4, 10)
z1 = model.fc1(x)                          # pre-activation
a1 = F.relu(z1)                            # post-ReLU
z2 = model.fc2(a1)
a2 = F.relu(z2)
out = model.fc3(a2)

print(f"\nLayer-by-layer shapes and stats:")
print(f"  z1 pre-act:  {z1.shape}  range=[{z1.min():.2f}, {z1.max():.2f}]")
print(f"  a1 post-ReLU:{a1.shape}  zeros={( a1==0).sum()}/{a1.numel()}")
print(f"  output:      {out.shape}")

# ── Verify determinism in eval mode ──────────────────────
model.eval()
with torch.no_grad():
    out1 = model(x)
    out2 = model(x)
print(f"\nDeterministic in eval mode: {torch.allclose(out1, out2)}")
Trap Forgetting model.eval() during inference, causing non-deterministic outputs

In training mode, dropout randomly zeros activations and BatchNorm normalises with the current batch's statistics rather than its running statistics. A model containing either layer can therefore return a different output on every call for the same input (dropout), or an output that depends on what else is in the batch (BatchNorm). Engineers often skip model.eval() during quick debugging or A/B comparisons, then wonder why outputs vary across runs.

Fix Standard inference pattern: model.eval(); with torch.no_grad(): output = model(x). Restore before training: model.train(). In production serving code, set eval mode once at startup — never toggle inside the serving loop. Use model.training (bool) to assert the mode is correct if needed.
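The difference between the two modes is easy to see on a model that actually contains dropout. The sketch below is a minimal illustration, assuming a hypothetical toy network `net`; the layer sizes and dropout rate are arbitrary choices, not values from this course.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model; the Dropout layer is what makes train mode stochastic
net = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(32, 1))
x = torch.randn(4, 10)

net.train()                      # dropout active: each call samples a new mask
train_a, train_b = net(x), net(x)

net.eval()                       # dropout disabled: output is a pure function of x
with torch.no_grad():
    eval_a, eval_b = net(x), net(x)

print("train-mode outputs equal:", torch.equal(train_a, train_b))
print("eval-mode outputs equal: ", torch.equal(eval_a, eval_b))
assert not net.training          # cheap guard for serving code
```

The final assertion is the kind of check that can sit at the top of a serving entry point, so a model left in training mode fails loudly instead of returning noisy predictions.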
Trap Treating raw nn.Linear output as probabilities

nn.Linear outputs raw logits — unbounded real numbers, often negative or > 1. Using them directly for thresholding (e.g., output > 0.5 for binary classification) or as probability distributions is incorrect. Logits do not sum to 1 in multi-class settings.

Fix Apply the correct output function at inference: torch.sigmoid(output) for binary; F.softmax(output, dim=-1) for multi-class. Never apply these before nn.BCEWithLogitsLoss or nn.CrossEntropyLoss — they do it internally and are more numerically stable.
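A short sketch of the split between training and inference, using made-up logits and targets: the loss receives raw logits, and the probability function is applied only when predictions are read out.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(6)                       # raw binary logits, unbounded
targets = torch.tensor([1., 0., 1., 1., 0., 0.])

# Training: pass logits straight to the "with logits" loss
loss_stable = F.binary_cross_entropy_with_logits(logits, targets)

# Same quantity computed the long way (sigmoid first, then plain BCE)
probs = torch.sigmoid(logits)
loss_manual = F.binary_cross_entropy(probs, targets)

# Inference: thresholding probabilities at 0.5 equals thresholding logits at 0
pred_from_probs = probs > 0.5
pred_from_logits = logits > 0.0

# Multi-class: softmax turns each row of logits into a distribution
mc_probs = F.softmax(torch.randn(3, 5), dim=-1)
print(loss_stable.item(), loss_manual.item(), mc_probs.sum(dim=-1))
```

The two loss values agree here because the logits are small; the fused version matters when logits are large enough for sigmoid to round to exactly 0 or 1.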

The UAT states that a single hidden layer MLP with enough neurons can approximate any continuous function on a compact subset of ℝⁿ to arbitrary precision. Practical limitations: (1) Non-constructive: it does not say how many neurons are needed or how to find the right weights via training. (2) It says nothing about generalisation: a network can perfectly approximate training data and fail on unseen inputs. (3) Depth usually beats width in practice: at a comparable parameter budget, a few moderately wide layers tend to train more easily and generalise better than one enormous hidden layer. The theorem motivates using MLPs but does not tell you how to design one.

Composing linear functions produces a linear function. W₂(W₁x + b₁) + b₂ = (W₂W₁)x + (W₂b₁ + b₂) = W'x + b'. No matter how many linear layers you stack, the result is always expressible as a single matrix multiply plus bias, equivalent to one linear layer. The nonlinear activation breaks this: σ(W₂·σ(W₁x+b₁)+b₂) cannot be collapsed to a single linear operation. Without activations, a 100-layer MLP has exactly the same representational power as a single linear layer (that is, linear regression, or logistic regression if a sigmoid is applied to the output). This is why accidentally removing activations in a custom model causes catastrophically poor performance while training appears to proceed normally (loss still decreases, just more slowly and toward a worse solution).
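The collapse can be checked numerically. The following sketch builds two `nn.Linear` layers with arbitrary sizes (the sizes are an assumption for illustration), merges them into one weight matrix and bias using the identity above, and compares the outputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# float64 so the comparison is not blurred by rounding
fc1 = nn.Linear(5, 8).double()
fc2 = nn.Linear(8, 3).double()
x = torch.randn(16, 5, dtype=torch.float64)

stacked = fc2(fc1(x))                              # two layers, no activation

# Collapse: W' = W2 W1, b' = W2 b1 + b2
W_merged = fc2.weight @ fc1.weight                 # shape (3, 5)
b_merged = fc2.weight @ fc1.bias + fc2.bias        # shape (3,)
collapsed = x @ W_merged.T + b_merged

with_relu = fc2(torch.relu(fc1(x)))                # nonlinearity breaks the identity
print(torch.allclose(stacked, collapsed), torch.allclose(with_relu, collapsed))
```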

Activation Functions — ReLU, GELU, Sigmoid, Tanh
f(x)=max(0,x) vs smooth GELU; saturation and vanishing gradient; dead ReLU problem; when to use each

Activation functions introduce nonlinearity. The choice affects gradient flow, output range, sparsity, and training speed.

ReLU (Rectified Linear Unit): f(x) = max(0, x)
  Gradient: 1 if x > 0, 0 if x < 0 (non-differentiable at x = 0; PyTorch uses 0 there)
  Fast, sparse activations, no saturation for x > 0
  Risk: dead ReLU, where a neuron outputs 0 for all inputs and stops learning

GELU (Gaussian Error Linear Unit): f(x) = x · Φ(x) ≈ 0.5x(1 + tanh(√(2/π)(x + 0.044715x³)))
  Smooth soft gate: attenuates rather than hard-zeros negative inputs
  Default in many Transformers: BERT, GPT-2/3, ViT (LLaMA-style models use SwiGLU, a gated SiLU variant, instead)

Sigmoid: f(x) = 1/(1 + e^{-x})
  Range: (0, 1)
  Gradient: f(x)(1−f(x)) ≤ 0.25, so it saturates and causes vanishing gradients in hidden layers
  Typical use: binary output layer (and gating units); avoid as a general hidden-layer activation

Tanh: f(x) = (e^x − e^{-x})/(e^x + e^{-x})
  Range: (−1, 1), zero-centred
  Gradient: 1 − f(x)² ≤ 1, so it still saturates, but is better than sigmoid for hidden layers

Dead ReLU problem:
  A neuron is dead if z < 0 for ALL inputs, so its gradient is 0 and its weights never update
  Causes: large negative bias; large LR pushing weights very negative
  Fix: Leaky ReLU f(x) = max(0.01x, x); ELU; careful He init plus a smaller LR

  Activation functions compared:

  f(x) ↑
   1.0 ├ · · · sigmoid · · · · · · (saturates at 1.0)
       │      · ·
   0.5 ├    · ·
       │  · ·                tanh (zero-centred, saturates ±1)
   0.0 ├──────────────────────────────────────→ x
       │                              ← GELU (slight negative, smooth)
  -0.5 ├

  ReLU:
  f(x) │          ╱  (gradient=1 for x>0)
       │         ╱
   0.0 ├────────╱────────────────────→ x
       │  gradient=0 here → dead ReLU risk
PyTorch — compare activations, check gradients at saturation, detect dead ReLU neurons
import torch
import torch.nn as nn
import torch.nn.functional as F

# ── Compare activation outputs ────────────────────────────
x = torch.linspace(-3, 3, 7)
print(f"x: {x.tolist()}")
print(f"ReLU:    {F.relu(x).tolist()}")
print(f"GELU:    {[round(v,3) for v in F.gelu(x).tolist()]}")
print(f"Sigmoid: {[round(v,3) for v in torch.sigmoid(x).tolist()]}")
print(f"Tanh:    {[round(v,3) for v in torch.tanh(x).tolist()]}")

# ── Gradient comparison — saturation ─────────────────────
x_sat = torch.tensor([-3.0, -1.0, 0.0, 1.0, 3.0], requires_grad=True)
torch.sigmoid(x_sat).sum().backward()
print(f"\nSigmoid grads at [-3,-1,0,1,3]: {[round(g,4) for g in x_sat.grad.tolist()]}")
# Near 0 at extremes → vanishing gradient in deep nets

x_relu = torch.tensor([-3.0, -1.0, 0.0, 1.0, 3.0], requires_grad=True)
F.relu(x_relu).sum().backward()
print(f"ReLU grads at [-3,-1,0,1,3]:    {x_relu.grad.tolist()}")
# Binary: 0 for x<0, 1 for x>0 — no saturation above 0

# ── Dead ReLU detection ───────────────────────────────────
class DeepReLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 64)
        self.fc2 = nn.Linear(64, 1)
    def forward(self, x):
        self.h1 = F.relu(self.fc1(x))
        return self.fc2(self.h1)

model = DeepReLU()
data  = torch.randn(200, 10)
_     = model(data)
dead  = (model.h1 == 0).all(dim=0).sum().item()
print(f"\nDead neurons (always 0 across 200 samples): {dead}/64")

# ── Transformer FFN block with GELU ─────────────────────
class TransformerFFN(nn.Module):
    def __init__(self, d_model=512, d_ff=2048):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x)))   # GELU standard in Transformers

ffn = TransformerFFN()
x   = torch.randn(8, 64, 512)   # (batch, seq_len, d_model)
out = ffn(x)
print(f"\nTransformer FFN: {x.shape} → {out.shape}")
Trap Using sigmoid or tanh in hidden layers of a deep network

Sigmoid gradient max is 0.25 (at x=0). In a 10-layer sigmoid network, the chain-rule product of gradients is at most 0.25^10 ≈ 10^{-6}. Early layers receive near-zero gradients and effectively stop learning. The loss may still decrease (final layers learn) while the bulk of the network capacity goes unused. The bug is invisible without gradient-norm logging.

Fix Replace sigmoid/tanh in hidden layers with ReLU or GELU. Reserve sigmoid for binary output layers. If you need a bounded range in hidden layers, use tanh — it is at least zero-centred and has gradients up to 1.0. Log per-layer gradient norms to WandB to verify early layers are receiving meaningful signal.
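Per-layer gradient norms make the effect visible without any logging service. This is a rough sketch, assuming a hypothetical helper `make_mlp` and PyTorch's default `nn.Linear` initialisation; exact numbers depend on the seed, so none are quoted here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def make_mlp(act_cls, depth=10, width=32):
    """Hypothetical helper: `depth` hidden Linear+activation pairs, then a 1-unit head."""
    layers, d_in = [], 16
    for _ in range(depth):
        layers += [nn.Linear(d_in, width), act_cls()]
        d_in = width
    layers.append(nn.Linear(d_in, 1))
    return nn.Sequential(*layers)

def weight_grad_norms(model):
    """Gradient norm of every Linear weight, ordered from input to output."""
    return [m.weight.grad.norm().item() for m in model if isinstance(m, nn.Linear)]

x, y = torch.randn(64, 16), torch.randn(64, 1)
norms = {}
for name, act in [("sigmoid", nn.Sigmoid), ("relu", nn.ReLU)]:
    model = make_mlp(act)
    nn.functional.mse_loss(model(x), y).backward()
    norms[name] = weight_grad_norms(model)
    print(f"{name:8s} first={norms[name][0]:.2e}  last={norms[name][-1]:.2e}")
```

In the sigmoid network the first-layer norm should come out orders of magnitude below the last-layer norm. The same list can be sent to WandB once per logging interval.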
Trap Not monitoring dead ReLU neurons during training

Dead neurons contribute nothing, waste capacity, and accumulate silently. A model with 40% dead neurons in layer 1 trains as if its first layer had 60% of its width. Symptoms: loss plateauing early, activation statistics going to zero, gradient norms collapsing for early layers — all subtle and easy to miss without explicit logging.

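Fix After each forward pass (or every N steps), compute (activation == 0).float().mean(dim=0) per layer; this gives each neuron's fraction of zero outputs over the batch, and a neuron is dead on that batch when the value is 1.0. If more than about 20% of a layer's neurons are dead, try: lowering the learning rate, switching to He init, using Leaky ReLU or ELU, or adding gradient clipping. WandB histograms of per-layer activations are the fastest way to spot this.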
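One way to collect these statistics without editing the model's forward method is a forward hook on each ReLU module. The sketch below is illustrative only: the model is a toy, and neuron 0 is killed deliberately (by an assumed bias of -1e3) so the monitor has something to report.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 1))

# Kill neuron 0 on purpose so the monitor has something to find
with torch.no_grad():
    model[0].bias[0] = -1e3

zero_fraction = {}   # layer name -> per-neuron fraction of zero outputs

def make_hook(name):
    def hook(module, inputs, output):
        # mean over the batch dimension: 1.0 means "zero for every sample"
        zero_fraction[name] = (output == 0).float().mean(dim=0)
    return hook

for name, module in model.named_modules():
    if isinstance(module, nn.ReLU):
        module.register_forward_hook(make_hook(name))

model(torch.randn(256, 10))
dead = {name: (frac == 1.0).nonzero().flatten().tolist()
        for name, frac in zero_fraction.items()}
print(dead)
```

A single batch can only show that a neuron is dead on that batch; tracking the same fraction over many steps is what separates a truly dead neuron from one that is merely sparse.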

GELU is smooth and differentiable everywhere. ReLU has a hard zero for all x < 0 (dead zone) and a kink at x = 0. GELU computes x·Φ(x) where Φ is the standard normal CDF; this gives a soft gate that attenuates negative inputs by a smooth, input-dependent factor rather than zeroing them hard. Practical benefits: (1) No dead neuron problem: the GELU gradient is nonzero for negative inputs (though small for very negative ones), so a neuron is not cut off permanently. (2) Smooth gradients, with no kink at zero. (3) Empirically strong on Transformer benchmarks: BERT, GPT-2, GPT-3, and ViT all default to GELU. ReLU remains preferred for CNNs where its sparsity is a useful inductive bias and its lower computation cost matters at scale.

A ReLU neuron is dead when its pre-activation z < 0 for every input in the training set. The ReLU gradient is exactly 0 for z < 0, so weights connected to this neuron receive zero gradient and never update. The neuron is stuck permanently. Causes: (1) Large initial learning rate pushing weights into a large negative region in the first few steps; (2) Very negative initial biases. Fix options without architecture change: (a) Switch to Leaky ReLU: f(x) = max(αx, x) with α=0.01, which gives a small gradient even for x<0 and prevents permanent death; (b) Use He initialisation (kaiming_normal_ with mode='fan_in', nonlinearity='relu'), which scales initial weights so activation variance stays roughly constant from layer to layer; (c) Reduce initial learning rate or add warmup. Architectural fix: skip connections (ResNet style) provide a gradient path around dead ReLU blocks.
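Two of these fixes can be checked in a few lines. The sketch below compares the gradient that ReLU and Leaky ReLU pass for a negative pre-activation, then applies He initialisation to a layer. The `fan_in` mode is an assumption made here because it is the mode that targets forward activation variance; the layer size is arbitrary.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# A negative pre-activation: ReLU passes no gradient, Leaky ReLU passes alpha
z_relu = torch.tensor([-2.0], requires_grad=True)
z_leaky = torch.tensor([-2.0], requires_grad=True)
F.relu(z_relu).sum().backward()
F.leaky_relu(z_leaky, negative_slope=0.01).sum().backward()
print(z_relu.grad.item(), z_leaky.grad.item())

# He (Kaiming) init in fan_in mode: std should be close to sqrt(2 / fan_in)
layer = nn.Linear(512, 512)
nn.init.kaiming_normal_(layer.weight, mode="fan_in", nonlinearity="relu")
nn.init.zeros_(layer.bias)
expected_std = math.sqrt(2.0 / 512)
print(layer.weight.std().item(), expected_std)
```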

An MLP without nonlinear activation is just a matrix multiply — no matter how many layers you stack, it collapses to a single linear transformation. The activation function is what gives deep networks their power.
02

Backpropagation, Chain Rule & torch.autograd

A production model suddenly produces NaN losses on step 500. The gradient debugging checklist starts here: which layer produced inf? Did backward() complete? Was retain_graph set incorrectly? Understanding backprop at the mechanics level — not just "optimizer.step()" — is what separates engineers who can debug training from those who restart and hope.

Backpropagation & the Chain Rule
Forward pass builds the computation graph; backward applies chain rule right-to-left; ∂L/∂w = ∂L/∂ŷ · ∂ŷ/∂z · ∂z/∂w

Backpropagation is the algorithm for computing ∂L/∂w for every parameter w. It applies the chain rule of calculus backwards through the computational graph.

Forward pass (build the graph):

  x → [fc1] → z₁ → [ReLU] → a₁ → [fc2] → z₂ → [loss] → L

Backward pass (chain rule, right to left):

  ∂L/∂L  = 1                      (start at the loss)
  ∂L/∂z₂ = ∂L/∂ŷ                  (derivative of the loss w.r.t. the output ŷ = z₂)
  ∂L/∂W₂ = ∂L/∂z₂ · a₁ᵀ           (gradient w.r.t. W₂)
  ∂L/∂a₁ = W₂ᵀ · ∂L/∂z₂           (propagate through fc2)
  ∂L/∂z₁ = ∂L/∂a₁ ⊙ 𝟙[z₁ > 0]     (through ReLU, elementwise)
  ∂L/∂W₁ = ∂L/∂z₁ · xᵀ            (gradient w.r.t. W₁)

Full chain rule in vector form:

  ∂L/∂W₁ = ∂L/∂ŷ · ∂ŷ/∂a₁ · ∂a₁/∂z₁ · ∂z₁/∂W₁

Key property: gradients multiply. If the factors are consistently small (vanishing) or consistently large (exploding), the product shrinks or grows exponentially with depth. This is why activation choice and skip connections matter.
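The backward equations for this two-layer network can be written out by hand and compared with autograd. This is a sketch under simplifying assumptions: biases are omitted, the input is a single column vector, and the loss is assumed to be L = ½‖z₂ − y‖², so that ∂L/∂z₂ = z₂ − y.

```python
import torch

torch.manual_seed(0)
dt = torch.float64
x = torch.randn(4, 1, dtype=dt)                        # input column vector
y = torch.randn(2, 1, dtype=dt)                        # target
W1 = torch.randn(3, 4, dtype=dt, requires_grad=True)
W2 = torch.randn(2, 3, dtype=dt, requires_grad=True)

# Forward (biases omitted to keep the sketch short)
z1 = W1 @ x
a1 = torch.relu(z1)
z2 = W2 @ a1
L = 0.5 * ((z2 - y) ** 2).sum()
L.backward()                                           # autograd reference

# Manual backward, mirroring the chain rule step by step
with torch.no_grad():
    dz2 = z2 - y                                       # dL/dz2 for this squared-error loss
    dW2 = dz2 @ a1.T                                   # dL/dW2 = dL/dz2 · a1ᵀ
    da1 = W2.T @ dz2                                   # dL/da1 = W2ᵀ · dL/dz2
    dz1 = da1 * (z1 > 0)                               # dL/dz1 = dL/da1 ⊙ 1[z1 > 0]
    dW1 = dz1 @ x.T                                    # dL/dW1 = dL/dz1 · xᵀ

print(torch.allclose(dW1, W1.grad), torch.allclose(dW2, W2.grad))
```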

  Computation graph: L = (z−y)²,  z = x·w + b

  x ──┐
      ├──[×]──→ xw ──┐
  w ──┘               ├──[+]──→ z ──┐
                  b ──┘              ├──[−]──→ (z−y) ──[²]──→ L
                                y ──┘

  Backward (chain rule, arrows reversed):
  ∂L/∂(z−y) = 2(z−y)
  ∂L/∂z     = 2(z−y) · 1        = 2(z−y)
  ∂L/∂w     = 2(z−y) · x
  ∂L/∂b     = 2(z−y) · 1        = 2(z−y)
  ∂L/∂x     = 2(z−y) · w        (not needed; x is input, not a parameter)
PyTorch — manual backprop, training loop anatomy, gradient norm logging
import torch
import torch.nn as nn

# ── Manual backprop on scalar computation ─────────────────
x = torch.tensor(2.0, requires_grad=True)
w = torch.tensor(3.0, requires_grad=True)
b = torch.tensor(0.5, requires_grad=True)
y = torch.tensor(10.0)

z = x * w + b              # forward: z = 6.5
L = (z - y) ** 2          # loss:    L = (6.5-10)^2 = 12.25
print(f"z={z.item():.2f}  L={L.item():.2f}")

L.backward()               # fill .grad for all requires_grad tensors

# Expected: ∂L/∂w = 2(z-y)·x = 2(-3.5)(2) = -14
print(f"\n∂L/∂w = {w.grad.item():.4f}  (expected: {2*(z.item()-y.item())*x.item():.4f})")
print(f"∂L/∂b = {b.grad.item():.4f}  (expected: {2*(z.item()-y.item()):.4f})")
print(f"∂L/∂x = {x.grad.item():.4f}  (expected: {2*(z.item()-y.item())*w.item():.4f})")

# ── Standard 5-step training loop ────────────────────────
model   = nn.Linear(4, 1)
optim   = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()
X = torch.randn(64, 4)
y = torch.randn(64, 1)

for step in range(5):
    optim.zero_grad()        # 1. clear accumulated gradients
    pred = model(X)          # 2. forward pass
    loss = loss_fn(pred, y)  # 3. compute scalar loss
    loss.backward()          # 4. backprop — fill .grad
    optim.step()             # 5. update weights
    print(f"Step {step}: loss={loss.item():.4f}")

# ── Gradient norm logging (production) ───────────────────
total_norm = 0.0
for name, p in model.named_parameters():
    if p.grad is not None:
        param_norm = p.grad.norm().item()
        total_norm += param_norm ** 2
        print(f"  {name}: grad_norm={param_norm:.6f}")
total_norm = total_norm ** 0.5
print(f"Total grad norm: {total_norm:.6f}")
Trap Forgetting optim.zero_grad(), causing gradient accumulation across batches

PyTorch accumulates gradients: every .backward() call ADDS to .grad rather than replacing it. Without zero_grad(), each update uses the running sum of every gradient computed so far, so the step is dominated by stale gradients and its magnitude keeps growing with the number of backward passes. Loss might initially appear to decrease faster (larger effective gradient) but then diverge. No error is raised.

Fix Call optim.zero_grad() at the START of every training step, before the forward pass. Since PyTorch 2.0, set_to_none=True is the default: it sets .grad to None instead of filling the tensor with zeros, which is slightly more memory-efficient; on older versions pass it explicitly. The only exception is intentional gradient accumulation, where you call zero_grad() every N steps, not every step.
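The accumulation behaviour is easy to confirm directly. In this sketch (toy model and random data, both assumptions), the same batch is backpropagated twice without clearing, and the stored gradient doubles.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
X, y = torch.randn(8, 4), torch.randn(8, 1)

def backward_once():
    nn.functional.mse_loss(model(X), y).backward()

backward_once()
g1 = model.weight.grad.clone()        # gradient from one backward pass

backward_once()                        # no zero_grad(): .grad now holds g1 + g1
g2 = model.weight.grad.clone()

model.zero_grad()                      # reset before the next pass
backward_once()
g3 = model.weight.grad.clone()         # back to a single-pass gradient

print(torch.allclose(g2, 2 * g1), torch.allclose(g3, g1))
```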
Trap Calling .backward() on a non-scalar tensor

loss.backward() works when loss is a 0-dim scalar tensor. If you compute per-sample losses (shape [B,]) and call losses.backward(), PyTorch raises RuntimeError: "grad can be implicitly created only for scalar outputs." This is common when implementing custom per-sample loss weighting or focal loss.

Fix Always reduce to a scalar before backward: loss = losses.mean(); loss.backward(). For weighted per-sample loss: loss = (weights * losses).sum(); loss.backward(). Or pass an explicit gradient vector: losses.backward(torch.ones_like(losses)) — equivalent to summing.
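The three cases can be compared side by side. This sketch uses a hand-rolled linear model and hypothetical per-sample weights; it shows the error, the reduce-then-backward fix, and the explicit gradient vector, which here is passed as the weights themselves to reproduce the weighted sum.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
w = torch.randn(3, requires_grad=True)
X, y = torch.randn(5, 3), torch.randn(5)
sample_weights = torch.tensor([1.0, 1.0, 2.0, 0.5, 1.0])   # hypothetical per-sample weights

def per_sample_losses():
    return F.mse_loss(X @ w, y, reduction="none")           # shape [5], not a scalar

try:
    per_sample_losses().backward()                           # raises: output is not a scalar
    raised = False
except RuntimeError:
    raised = True

(sample_weights * per_sample_losses()).sum().backward()      # reduce, then backward
grad_reduced = w.grad.clone()

w.grad = None
per_sample_losses().backward(sample_weights)                 # explicit vector: same weighted sum
grad_vector = w.grad.clone()

print(raised, torch.allclose(grad_reduced, grad_vector))
```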

The computational graph records every differentiable operation performed on tensors that require gradients. Each graph node stores: (1) the tensors from the forward pass that its gradient formula needs (inputs, outputs, or both), (2) a backward function (how to compute the gradient of the node's inputs given the gradient of its output). All intermediate activations needed for gradient computation are kept alive in the graph. PyTorch frees these saved tensors after backward() by default because keeping them uses significant memory: roughly proportional to batch size times the total size of the saved activations, which for deep models often exceeds the memory of the parameters themselves. If you need multiple backward passes over the same graph (e.g., two separate loss terms, meta-learning), use loss.backward(retain_graph=True), which keeps the graph alive. Cost: the saved activations are not released until the graph goes out of scope, so peak memory stays high; the graph itself is not duplicated.

A ResNet block computes h = F(x) + x — two paths contribute to the output. During the backward pass, gradients accumulate at the branching point. The gradient of the loss w.r.t. x is: ∂L/∂x = ∂L/∂h · (∂F(x)/∂x + I). The identity term I (from the skip connection) provides a lower bound: even if ∂F(x)/∂x → 0 (vanishing gradient through the residual branch), the gradient still flows back through the identity path with magnitude ∂L/∂h. This is why ResNets can be trained at 100+ layers — the skip connection is a guaranteed gradient highway. PyTorch handles this automatically: both paths are traced in the computational graph and gradients are summed at the merge point.
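The identity path can be isolated with an extreme case. In the sketch below, the residual branch F is a linear layer whose weights are set to zero, which is an artificial stand-in for a branch whose Jacobian has vanished; the loss is simply the sum of the outputs.

```python
import torch
import torch.nn as nn

branch = nn.Linear(4, 4)
with torch.no_grad():                 # simulate a residual branch whose Jacobian has vanished
    branch.weight.zero_()
    branch.bias.zero_()

x_plain = torch.randn(2, 4, requires_grad=True)
branch(x_plain).sum().backward()                  # h = F(x): gradient dies in the branch

x_skip = torch.randn(2, 4, requires_grad=True)
(branch(x_skip) + x_skip).sum().backward()        # h = F(x) + x: identity path survives

print(x_plain.grad)   # all zeros
print(x_skip.grad)    # all ones, i.e. dL/dh passed through unchanged
```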

torch.autograd — detach(), retain_graph & Gradient Accumulation
Stopping gradient flow (detach vs no_grad); retain_graph for multi-backward; gradient accumulation for large effective batch

PyTorch's autograd engine builds and traverses the computational graph. Key patterns you will use in production code:

1. torch.no_grad() (context manager)
   Disables gradient tracking for all ops within the block. No graph is built, so it is faster and uses less memory. Use for ALL inference and validation.

2. tensor.detach() (per-tensor stop)
   Returns a tensor that shares data with the original but is excluded from the graph. Use when ONE part of the forward pass should not receive gradients:
   • Target networks in RL (DQN): the target should not update via backprop
   • GAN discriminator step: pass fake.detach() so the discriminator loss does not backpropagate into the generator
   • Contrastive learning momentum encoder (BYOL, MoCo)

3. retain_graph=True
   Prevents the graph from being freed after backward(). Use for multiple backward passes over the same graph; higher-order gradients additionally need create_graph=True. Cost: graph memory is not released, which can OOM on large models.

4. Gradient accumulation
   Simulate a larger batch by splitting it into N micro-batches:

     for each micro-batch:
         loss = compute(micro_batch) / N      ← scale to keep magnitude correct
         loss.backward()                      ← accumulate, don't step
     optimizer.step(); optimizer.zero_grad()  ← single update on the full "batch"

PyTorch — detach() for target network, retain_graph, gradient accumulation
import torch
import torch.nn as nn

# ── detach() — target network pattern (DQN) ──────────────
online_net = nn.Linear(8, 4)
target_net = nn.Linear(8, 4)
target_net.load_state_dict(online_net.state_dict())

state = torch.randn(32, 8)
online_q  = online_net(state)                       # gradients flow here

with torch.no_grad():
    target_q = target_net(state)                    # no graph built at all

# td_target must not receive gradients — detach makes it explicit
td_target = 0.99 * target_q.max(dim=1).values.unsqueeze(1).detach()
loss = nn.MSELoss()(online_q, td_target.expand_as(online_q))
loss.backward()

print(f"online_net grad norm: {online_net.weight.grad.norm():.4f}")
print(f"target_net grad:      {target_net.weight.grad}")   # None — correct

# ── retain_graph — higher-order gradient ─────────────────
x = torch.randn(4, requires_grad=True)
y = (x ** 2).sum()

# First backward with graph retained (create_graph=True to allow 2nd deriv)
grads = torch.autograd.grad(y, x, create_graph=True)   # ∂y/∂x = 2x
g = grads[0]
# Second derivative (Hessian diagonal)
hess = torch.autograd.grad(g.sum(), x)[0]              # ∂²y/∂x² = 2
print(f"\n∂y/∂x   = {grads[0].tolist()}")   # [2x₀, 2x₁, 2x₂, 2x₃]
print(f"∂²y/∂x² = {hess.tolist()}")           # [2, 2, 2, 2]

# ── Gradient accumulation ─────────────────────────────────
model   = nn.Linear(16, 1)
optim   = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

accum_steps = 4   # effective batch = 4 × 8 = 32
optim.zero_grad()

for i in range(accum_steps):
    X_mini = torch.randn(8, 16)
    y_mini = torch.randn(8, 1)
    loss = loss_fn(model(X_mini), y_mini) / accum_steps   # ← scale!
    loss.backward()   # accumulate; no optim step yet

# Measure the accumulated gradient norm BEFORE zero_grad() clears it
grad_norm = sum(p.grad.norm()**2 for p in model.parameters()
               if p.grad is not None) ** 0.5
print(f"\nGradient norm after accumulation: {grad_norm:.4f}")

# Single parameter update using accumulated gradients
optim.step()
optim.zero_grad()
Trap Using tensor.data instead of tensor.detach() to stop gradient flow

tensor.data returns the underlying storage without a gradient function — similar to detach() in effect. However, .data bypasses autograd's in-place operation safety checks. If you modify a .data tensor in-place after a forward pass that tracked gradients, autograd may silently produce incorrect gradients for earlier operations — the error manifests as subtly wrong training behaviour, not an exception.

Fix Always use tensor.detach() for stopping gradient flow — it is autograd-safe and clearly communicates intent. Use tensor.detach().clone() if you need an independent copy. Reserve .data only for very specific low-level memory operations where you understand the implications and have verified correctness.
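The difference shows up as soon as the shared storage is modified in place. The sketch below follows the well-known sigmoid example: sigmoid's backward pass reuses its own output, so zeroing that output corrupts the gradient. Through detach() autograd notices; through .data it does not.

```python
import torch

# Case 1: in-place edit through detach() is detected at backward time
a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
out = a.sigmoid()                 # backward needs `out` itself
out.detach().zero_()              # shares storage with `out`, bumps its version counter
try:
    out.sum().backward()
    detach_raised = False
except RuntimeError:
    detach_raised = True

# Case 2: the same edit through .data goes unnoticed
b = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
out = b.sigmoid()
out.data.zero_()                  # no version bump, so autograd cannot tell
out.sum().backward()              # runs, but uses the zeroed values

true_grad = torch.sigmoid(b.detach()) * (1 - torch.sigmoid(b.detach()))
print(detach_raised, b.grad, true_grad)
```

The second case finishes without complaint and leaves a gradient of zeros where the true gradient is nonzero, which is exactly the silent failure described above.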
Trap Forgetting to divide loss by accumulation steps in gradient accumulation

Accumulating gradients over 4 mini-batches without scaling the loss by 1/4 gives a total gradient that is 4× larger than the true gradient for that effective batch size. With SGD this acts like a learning rate 4× too large, causing instability or divergence; with Adam the per-parameter normalisation hides most of the scale change, but gradient-clipping thresholds and logged gradient norms are still off by 4×. The bug is subtle because the training loop runs without errors and the loss typically decreases at first.

Fix Scale: loss = compute_loss(mini_batch) / accumulation_steps before every backward() call. Equivalently, scale gradients after accumulation but before the step. Verify: compare the weight update magnitude with accumulation vs without — they should be the same for equivalent batch sizes.
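The verification can be done on gradients directly, before any optimizer is involved. This sketch assumes a mean-reduced loss and equal-sized micro-batches (both are needed for the equality to be exact), and compares a full-batch gradient with scaled and unscaled accumulation.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(16, 1).double()
loss_fn = nn.MSELoss()
X = torch.randn(32, 16, dtype=torch.float64)
y = torch.randn(32, 1, dtype=torch.float64)
accum_steps = 4

# Reference: one backward pass over the full batch of 32
model.zero_grad()
loss_fn(model(X), y).backward()
full_grad = model.weight.grad.clone()

# Accumulation: 4 micro-batches of 8, each loss divided by accum_steps
model.zero_grad()
for X_mb, y_mb in zip(X.chunk(accum_steps), y.chunk(accum_steps)):
    (loss_fn(model(X_mb), y_mb) / accum_steps).backward()
scaled_grad = model.weight.grad.clone()

# Same loop without the division
model.zero_grad()
for X_mb, y_mb in zip(X.chunk(accum_steps), y.chunk(accum_steps)):
    loss_fn(model(X_mb), y_mb).backward()
unscaled_grad = model.weight.grad.clone()

print(torch.allclose(scaled_grad, full_grad),
      torch.allclose(unscaled_grad, accum_steps * full_grad))
```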

Use retain_graph=True when you need to call backward() more than once on the same computational graph, for example: (1) computing gradients of two separate loss terms independently (though usually you should just sum and backward once); (2) model-agnostic meta-learning (MAML), where the inner-loop gradient update must itself be differentiated (this needs create_graph=True, which retains the graph as well); (3) truncated BPTT variants in RNNs that backpropagate through overlapping windows; (4) computing higher-order gradients. Memory cost: the graph stores all intermediate activations needed for gradient computation, similar in size to the model's forward pass memory. With retain_graph=True, this memory is not freed. For large Transformers this can be hundreds of MB to GBs. Prefer: sum losses then backward once (no retain needed); use torch.autograd.grad(create_graph=True) for higher-order gradients.

.detach() operates on a single tensor — returns a new tensor sharing the same data but excluded from the computation graph. Selective: other tensors in the same computation can still track gradients. Use for: target networks in RL, contrastive momentum encoders. torch.no_grad() is a context manager that disables gradient tracking for ALL operations within the block — no graph is built, saving both compute and memory. Use for: validation loops, inference. @torch.no_grad() is the decorator form — apply to functions that should always run without gradients (inference methods, evaluation functions). Key difference: detach() is per-tensor and allows selective gradient stopping mid-graph; no_grad() is global within its scope and prevents any graph construction. You can use detach() inside a no_grad() block — it has no additional effect since no graph is being built anyway.

PyTorch's computational graph is built dynamically during the forward pass and destroyed after backward() by default. Every unusual training pattern — higher-order gradients, target networks, multi-loss training — requires understanding this lifecycle.
Training Mechanics 03–04
03

Loss Functions & Gradient Pathologies

Your object detection dataset is 99% background pixels and 1% object pixels, and test mAP is terrible. Switching from cross-entropy to focal loss doubles mAP in two days. Loss function choice is not a detail: it encodes your statistical assumption about the output distribution and determines what "wrong" means for your model.

Loss Functions — MSE, Cross-Entropy, Focal, Triplet
Regression (MSE/MAE), classification (CE/BCE), imbalance (focal loss), metric learning (triplet); each as an MLE assumption

Loss function = how you define "wrong." Each corresponds to a probabilistic model.

Regression:
  MSE:   L = (1/n)Σ(yᵢ − ŷᵢ)²
         MLE under Gaussian noise N(ŷ, σ²). Squares large errors, so it is outlier-sensitive.
  MAE:   L = (1/n)Σ|yᵢ − ŷᵢ|
         MLE under Laplace noise. More robust; non-differentiable at 0.
  Huber: quadratic (MSE-like) when |y − ŷ| < δ, linear (MAE-like) otherwise; a compromise between the two.

Classification:
  Cross-Entropy: L = −Σᵢ yᵢ log(ŷᵢ)   (multi-class; ŷ = softmax(logits))
         MLE under a categorical distribution. nn.CrossEntropyLoss expects raw logits and applies log-softmax internally.
  BCE:   L = −[y log p + (1−y) log(1−p)]   (binary)
         Use nn.BCEWithLogitsLoss: it applies sigmoid internally and is numerically stable.

Imbalanced classification:
  Focal loss: FL(p_t) = −(1 − p_t)^γ · log(p_t)   (γ = 2 typical; p_t is the predicted probability of the true class)
         Downweights easy examples (high p_t), which in detection are mostly easy negatives; focuses training on hard examples.
         Introduced with RetinaNet for dense object detection; also used in some YOLO variants.

Metric learning:
  Contrastive: L = y·d² + (1−y)·max(0, m − d)²,  d = ||emb_a − emb_b||  (y = 1 for a matching pair)
  Triplet:     L = max(0, d(a,p) − d(a,n) + margin)
         Pulls same-class embeddings together, pushes different-class apart.
         Used in: face recognition (FaceNet), sentence similarity (SBERT), RecSys.

PyTorch — MSE vs MAE with outliers, cross-entropy pitfalls, focal loss, triplet loss
import torch
import torch.nn as nn
import torch.nn.functional as F

# ── MSE vs MAE — outlier sensitivity ─────────────────────
pred   = torch.tensor([1.0, 2.0, 3.0, 100.0])   # outlier at end
target = torch.tensor([1.0, 2.0, 3.0,   4.0])

print(f"MSE: {F.mse_loss(pred, target).item():.2f}")   # dominated by outlier
print(f"MAE: {F.l1_loss(pred, target).item():.2f}")    # robust

# ── Cross-entropy — correct usage with raw logits ─────────
logits = torch.randn(4, 10)           # batch=4, 10 classes
labels = torch.randint(0, 10, (4,))  # true class indices

ce = nn.CrossEntropyLoss()
loss_ce = ce(logits, labels)          # applies log-softmax internally
print(f"\nCE Loss: {loss_ce.item():.4f}")

# Verify: manual implementation
log_probs = F.log_softmax(logits, dim=-1)
manual_ce = -log_probs[range(4), labels].mean()
print(f"Manual CE: {manual_ce.item():.4f}  (should match)")

# ── BCE with pos_weight for class imbalance ───────────────
pos_weight = torch.tensor([9.0])       # 1 positive per 9 negatives
bce_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
logits_b = torch.randn(32)
labels_b = torch.zeros(32); labels_b[:3] = 1.0   # 3 pos of 32
loss_bce = bce_fn(logits_b, labels_b)
print(f"\nWeighted BCE: {loss_bce.item():.4f}")

# ── Focal loss (binary) ──────────────────────────────────
def focal_loss(logits, targets, gamma=2.0, alpha=0.25):
    bce  = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    p    = torch.sigmoid(logits)
    p_t  = p * targets + (1-p) * (1-targets)      # p if y=1, else 1-p
    loss = bce * (1 - p_t) ** gamma
    a_t  = alpha * targets + (1-alpha) * (1-targets)
    return (a_t * loss).mean()

loss_fl = focal_loss(logits_b, labels_b)
print(f"Focal loss (gamma=2): {loss_fl.item():.4f}")

# ── Triplet loss for embeddings ───────────────────────────
triplet_fn = nn.TripletMarginLoss(margin=1.0)
anchor   = torch.randn(8, 128)
positive = anchor + 0.1 * torch.randn(8, 128)   # same class (slight noise)
negative = torch.randn(8, 128)                   # different class
loss_tri = triplet_fn(anchor, positive, negative)
print(f"\nTriplet loss: {loss_tri.item():.4f}")
# Should be near 0 — anchor closer to positive than negative
Trap Applying softmax before nn.CrossEntropyLoss — double-applying the activation

nn.CrossEntropyLoss and nn.BCEWithLogitsLoss apply their respective activations (log-softmax and sigmoid) internally. They expect raw logits. If you apply softmax first, you compute -log(softmax(softmax(logits))). Softmax outputs lie in [0, 1], so the second softmax is close to uniform: for C classes the loss can never fall below log(1 + (C−1)/e), and gradients are heavily compressed. The model still trains, just slowly and toward the wrong objective. The bug is silent.

Fix Pass raw logits to nn.CrossEntropyLoss and to any loss with "WithLogits" in its name. Apply F.softmax(logits, dim=-1) only at inference to get probabilities, NEVER before loss computation. If in doubt, inspect the tensor you pass to the loss: it should contain raw unbounded values, not values in [0, 1] that sum to 1 per row.
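The effect of the double softmax is clearest on a prediction that is already confidently correct. In this sketch the logits are constructed by hand (10 classes, a logit of 20 on the true class, both arbitrary assumptions), and the buggy loss is compared against the floor log(1 + (C−1)/e) that follows from softmax outputs being confined to [0, 1].

```python
import math
import torch
import torch.nn.functional as F

num_classes = 10
labels = torch.tensor([3])
# A confidently correct prediction: large logit on the true class
logits = torch.zeros(1, num_classes)
logits[0, 3] = 20.0

loss_correct = F.cross_entropy(logits, labels)                       # raw logits in
loss_double = F.cross_entropy(F.softmax(logits, dim=-1), labels)     # the bug

# Softmax outputs lie in [0, 1], so the buggy loss cannot go below this floor
floor = math.log(1 + (num_classes - 1) / math.e)
print(f"correct={loss_correct.item():.4f}  double={loss_double.item():.4f}  floor={floor:.4f}")
```

A training loss that stalls near this floor no matter how long training runs is a useful symptom to look for.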
Trap Using MSE for classification targets

MSE treats class indices as ordered reals: misclassifying class 0 as class 9 is penalised 81× more than class 0 as class 1, but class ordering has no meaning. Even on one-hot targets, MSE over sigmoid or softmax outputs gives vanishing gradients for confidently wrong predictions, because the output nonlinearity saturates, and it is not the likelihood of a categorical model. Models trained with MSE on classification tend to converge slowly and to worse accuracy.

Fix Use CrossEntropyLoss for multi-class, BCEWithLogitsLoss for binary, always. MSE is for continuous scalar outputs only. For ordinal targets (star ratings 1-5) where ordering matters, consider ordinal regression or coral loss, which encode the ordering constraint without the arbitrary scaling of MSE.

For a categorical distribution, the likelihood of observing label y given predicted class probabilities p is P(y|p) = p_y, the probability assigned to the true class. The log-likelihood is log P(y|p) = log p_y. Negating and averaging over the N examples gives the cross-entropy: L = -(1/N)Σᵢ log p_{yᵢ}. Minimising cross-entropy is therefore exactly maximising the log-likelihood of the observed labels under the model's predicted distribution: MLE for a categorical distribution IS cross-entropy minimisation. This justifies cross-entropy as the principled choice for classification, since it directly rewards putting probability mass on the observed class. MSE for classification corresponds to MLE under Gaussian noise, which is the wrong assumption for discrete labels.
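The equivalence can be checked numerically. The sketch below uses made-up probabilities and labels (4 examples, 3 classes) and shows that the mean negative log-probability of the true class equals -(1/N) times the log of the dataset likelihood.

```python
import math

# Hypothetical predicted class probabilities for 4 examples (3 classes) and their labels.
probs = [[0.7, 0.2, 0.1],
         [0.1, 0.8, 0.1],
         [0.3, 0.3, 0.4],
         [0.6, 0.3, 0.1]]
labels = [0, 1, 2, 1]

# Likelihood of the observed labels: product of the true-class probabilities.
likelihood = math.prod(p[y] for p, y in zip(probs, labels))

# Cross-entropy: mean negative log-probability of the true class.
cross_entropy = -sum(math.log(p[y]) for p, y in zip(probs, labels)) / len(labels)

print(f"likelihood            = {likelihood:.6f}")
print(f"-(1/N) log-likelihood = {-math.log(likelihood) / len(labels):.6f}")
print(f"cross-entropy         = {cross_entropy:.6f}")
```

Any change to the probabilities that raises the likelihood lowers the cross-entropy by the corresponding amount, which is the sense in which the two objectives are the same.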

Weighted BCE (pos_weight) scales the contribution of positive examples uniformly — all positive examples get the same weight boost regardless of how easy or hard they are. Focal loss scales by (1-p)^γ: easy examples (correctly classified with high confidence) contribute almost nothing to the loss, while hard examples dominate training. Choose focal loss when: (1) imbalance is severe (1:100+) and pos_weight alone is not sufficient; (2) you have many easy negatives — standard in object detection where most image regions are clearly background; (3) you want the model to improve on hard examples, not just maintain recall on easy positives. Choose weighted BCE when: imbalance is moderate (1:10 or less), or the problem is simpler (binary tabular classification). Start with pos_weight — it is simpler and often sufficient.
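To make the difference concrete, the sketch below compares the two weightings on three hypothetical examples. Here p_t denotes the predicted probability of the true class (the p in (1-p)^γ), and the helper names and pos_weight=10 are illustrative choices, not values from a specific library.

```python
import math

def bce(p_true):
    """Binary cross-entropy for one example, given the probability of its true class."""
    return -math.log(p_true)

def focal(p_true, gamma=2.0):
    """Focal loss: BCE scaled by the modulating factor (1 - p_true)^gamma."""
    return (1.0 - p_true) ** gamma * bce(p_true)

def weighted_bce(p_true, is_positive, pos_weight=10.0):
    """Weighted BCE: positives get a constant factor, however easy or hard they are."""
    return (pos_weight if is_positive else 1.0) * bce(p_true)

for name, p_true in [("easy", 0.95), ("medium", 0.6), ("hard", 0.1)]:
    print(f"{name:6s} p_t={p_true:.2f}  BCE={bce(p_true):.4f}  "
          f"weighted={weighted_bce(p_true, True):.4f}  "
          f"focal={focal(p_true):.4f}  focal/BCE={focal(p_true) / bce(p_true):.4f}")
```

The weighted loss is always 10× the plain BCE for a positive, while the focal factor ranges from 0.0025 on the easy example to 0.81 on the hard one.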

Vanishing & Exploding Gradients — Diagnosis and Fixes
Gradient product through depth; vanishing (sigmoid deep nets) vs exploding (NaN loss); skip connections, He init, gradient clipping

In a deep network, gradients are products of Jacobians across all layers. For L layers:

  ∂L/∂W₁ ≈ ∂L/∂a_L · ∏ᵢ (Wᵢ · σ'(zᵢ))

Vanishing gradients (product < 1):
  Sigmoid max gradient = 0.25; for 10 layers: 0.25^10 ≈ 10^{-6}
  Early layers receive near-zero gradient → stop learning
  Symptom: loss decreases slowly; early-layer weights barely change; grad norms → 0

Exploding gradients (product > 1):
  Product grows exponentially → gradient norms spike to Inf or NaN
  Symptom: loss goes to NaN; weight norms diverge; training stops after a few steps

Three structural fixes:
  1. Better activations and skip connections:
     ReLU gradient = 1 for x > 0 (no compression in positive region)
     ResNet skip: ∂(F(x)+x)/∂x = ∂F(x)/∂x + 1 (identity gradient path that survives even when ∂F(x)/∂x → 0)
  2. Proper initialisation:
     He (Kaiming): Var(W) = 2/fan_in, designed for ReLU
     Xavier/Glorot: Var(W) = 2/(fan_in+fan_out), designed for sigmoid/tanh
     Keeps activation variance ≈ 1.0 throughout the forward pass
  3. Gradient clipping (for exploding):
     if ||g|| > max_norm: g ← g · (max_norm / ||g||)
     clip_grad_norm_(model.parameters(), max_norm=1.0), standard for Transformers

  Gradient magnitude by layer (vanishing example, sigmoid deep net):

  Layer  │ Gradient Norm
  ───────┼───────────────────────────────────────
  Output │ ████████████████████  1.000   ← loss gradient
   L-1   │ █████████             0.490
   L-2   │ ████                  0.245
   L-3   │ ██                    0.122
   L-4   │ █                     0.061
   L-5   │ ·                     0.008   ← effectively vanished
   L-6   │                       0.001   ← dead
PyTorch — vanishing gradient demo, He init check, gradient clipping
import torch
import torch.nn as nn
import torch.nn.functional as F

# ── Vanishing gradient: sigmoid vs ReLU at depth 8 ───────
def make_net(activation, depth=8):
    layers = []
    for _ in range(depth):
        layers.extend([nn.Linear(32, 32), activation()])
    return nn.Sequential(*layers, nn.Linear(32, 1))

x = torch.randn(16, 32)

for act_name, act_cls in [('Sigmoid', nn.Sigmoid), ('ReLU', nn.ReLU)]:
    net = make_net(act_cls)
    loss = net(x).mean()
    loss.backward()
    # Collect gradient norms for weight matrices only
    gnorms = [p.grad.norm().item()
              for p in net.parameters() if len(p.shape) == 2]
    print(f"{act_name} gradient norms (early → late):")
    print(f"  {[f'{g:.6f}' for g in gnorms]}")

# ── Verify He init keeps activation variance ≈ 1 ─────────
model = nn.Sequential(
    nn.Linear(512, 512), nn.ReLU(),
    nn.Linear(512, 512), nn.ReLU(),
    nn.Linear(512, 512), nn.ReLU(),
)
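# PyTorch's default Linear init is NOT He init (it is kaiming_uniform_ with
# a=sqrt(5), i.e. Var(W) = 1/(3·fan_in)), so apply He init explicitly.
for layer in model:
    if isinstance(layer, nn.Linear):
        # mode='fan_in' gives Var(W) = 2/fan_in, which preserves forward-pass scale
        nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
        nn.init.zeros_(layer.bias)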

x_in = torch.randn(256, 512)
activations = []
h = x_in
for layer in model:
    h = layer(h)
    if isinstance(layer, nn.ReLU):
        activations.append(h.var().item())

print(f"\nActivation variance per ReLU layer (should stay O(1), not shrink):")
for i, v in enumerate(activations):
    print(f"  layer {i+1}: {v:.4f}")

# ── Gradient clipping ─────────────────────────────────────
model_c = nn.Linear(32, 1)
optim   = torch.optim.SGD(model_c.parameters(), lr=0.1)
X_big   = torch.randn(4, 32) * 100.0    # large inputs → large gradients
y_big   = torch.randn(4, 1)

optim.zero_grad()
nn.MSELoss()(model_c(X_big), y_big).backward()

# clip_grad_norm_ returns the total norm measured BEFORE clipping
norm_before = nn.utils.clip_grad_norm_(model_c.parameters(), max_norm=1.0)
# Recompute the norm from the (now clipped) .grad tensors
norm_after = torch.norm(torch.stack([p.grad.norm() for p in model_c.parameters()]))
print(f"\nGrad norm BEFORE clipping: {norm_before:.2f}")
print(f"Grad norm AFTER  clipping (max=1.0): {norm_after:.4f}")
optim.step()   # step AFTER clipping
Trap Calling clip_grad_norm_ after optimizer.step() instead of before

clip_grad_norm_ must be called AFTER loss.backward() (which fills .grad) and BEFORE optim.step() (which uses .grad to update weights). Calling it after step() clips gradients that were already used for the update — it has no effect on the current step. This is the most common mistake when adding gradient clipping to an existing training loop.

Fix The mandatory order: (1) optim.zero_grad(), (2) forward + loss, (3) loss.backward(), (4) torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm), (5) optim.step(). Clip between backward and step. The clip_grad_norm_ function returns the total gradient norm before clipping — log this value.
Trap Using default random initialisation (torch.randn * 0.01) for deep custom networks

torch.randn(fan_out, fan_in) * 0.01 initialises weights with variance 0.0001. In a ReLU network each layer multiplies the activation scale by roughly fan_in · Var(W) / 2; for fan_in = 256 that is 256 × 0.0001 / 2 ≈ 0.013 per layer, so after 10 layers the signal has shrunk by about 0.013^10 ≈ 10^{-19}. Gradients vanish immediately even with ReLU. This is common in code that initialises embedding projections or custom linear layers manually.

Fix Use nn.init.kaiming_normal_(w, mode='fan_in', nonlinearity='relu') for ReLU networks (this gives Var(W) = 2/fan_in); nn.init.xavier_uniform_(w) for tanh/sigmoid. Verify: after the first forward pass with random data, check activation variance at each layer. It should stay roughly constant and of order 1 throughout the network.
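The arithmetic behind this trap can be sketched in a few lines. The assumption is the standard one behind He init: inputs of unit scale, zero-mean independent weights, and ReLU zeroing half of each pre-activation, so every layer multiplies the second moment of the activations by fan_in · Var(W) / 2. The width and depth below are hypothetical.

```python
def scale_after(depth, fan_in, weight_var):
    """Second moment of ReLU activations after `depth` layers, for unit-scale input.

    Each layer multiplies E[h^2] by fan_in * Var(W) / 2 (the 1/2 comes from ReLU
    zeroing half of a zero-mean pre-activation).
    """
    per_layer = fan_in * weight_var / 2.0
    return per_layer ** depth

fan_in, depth = 256, 10                           # hypothetical layer width and depth
naive = scale_after(depth, fan_in, 0.01 ** 2)     # weights drawn as randn * 0.01
he = scale_after(depth, fan_in, 2.0 / fan_in)     # He init: Var(W) = 2/fan_in

print(f"std=0.01 init: activation scale after {depth} layers = {naive:.2e}")
print(f"He init:       activation scale after {depth} layers = {he:.2e}")
```

He init makes the per-layer factor exactly 1, which is why the signal neither shrinks nor grows with depth under these assumptions.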

A standard layer computes h = F(x). The gradient through it is ∂L/∂x = ∂L/∂h · ∂F(x)/∂x. If ∂F(x)/∂x is small (vanishing), ∂L/∂x ≈ 0 and early layers stop updating. A ResNet block computes h = F(x) + x. The gradient is: ∂L/∂x = ∂L/∂h · (∂F(x)/∂x + I). The identity term I gives gradients a direct path: even if ∂F(x)/∂x → 0, they still flow through the +I branch with magnitude ∂L/∂h. This is why ResNet-152 trains reliably while plain networks of similar depth are far harder to optimise: each block provides a gradient highway regardless of how the residual branch behaves.
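A scalar toy version of this argument, as a sketch: each "layer" is tanh(w·x) with a hypothetical weight w = 0.5, and the input gradient is accumulated by the chain rule with and without the skip connection. It illustrates the mechanism only and is not a model of a real ResNet.

```python
import math

def input_gradient(depth, residual, w=0.5, x0=1.0):
    """d(output)/d(input) for a chain of scalar layers, accumulated by the chain rule.

    Plain layer:    x <- tanh(w * x)        local derivative: w * (1 - tanh^2)
    Residual layer: x <- x + tanh(w * x)    local derivative: 1 + w * (1 - tanh^2)
    """
    x, grad = x0, 1.0
    for _ in range(depth):
        t = math.tanh(w * x)
        local = w * (1.0 - t * t)      # dF/dx of the branch
        if residual:
            grad *= 1.0 + local        # the identity path adds 1
            x = x + t
        else:
            grad *= local
            x = t
    return grad

for depth in (5, 20, 50):
    print(f"depth {depth:2d}: plain={input_gradient(depth, False):.3e}  "
          f"residual={input_gradient(depth, True):.3e}")
```

In the plain chain every factor is at most 0.5, so the gradient decays geometrically with depth; in the residual chain every factor is at least 1, so the gradient never falls below its value at the output.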

Exploding gradients: (1) loss spikes to NaN or Inf after a few steps — often step 1-50; (2) gradient norms jump to hundreds or thousands in a single step; (3) torch.autograd.detect_anomaly() shows the exact operation that produced NaN; (4) weight norms grow rapidly. Fix: add gradient clipping immediately (clip_grad_norm_ with max_norm=1.0). Vanishing gradients: (1) loss decreases very slowly or plateaus after moderate progress; (2) gradient norms for early layers are near 0 while later layers are normal — must log per-layer norms to see this; (3) early layer weights barely change across many epochs (compare checkpoints); (4) activation values in early layers cluster near 0. Fix: architecture changes (skip connections, normalisation layers, better activations) — clipping does not help vanishing gradients.

Every loss function corresponds to a statistical assumption about the relationship between inputs and outputs. Cross-entropy is MLE for categorical distributions; MSE is MLE for Gaussian noise. Choosing the wrong loss means optimising the wrong assumption — your model will converge, just not to what you want.
04

Optimizers & Learning Rate Schedules

You fine-tune a pre-trained BERT model with vanilla Adam at lr=0.01. Loss diverges immediately. You try lr=1e-5 — it trains, but takes 10× longer than expected. Understanding why AdamW with warmup + cosine decay is the standard Transformer recipe — not arbitrary — is what makes you able to design training runs rather than copy-paste configs.

Optimizers — SGD+Momentum, Adam, AdamW, Lion
Momentum smooths oscillations; Adam adapts per-parameter LR; AdamW decouples weight decay; Lion uses sign-based updates

SGD with Momentum:
  v_t = β·v_{t-1} + g_t            (velocity accumulates gradient)
  w_t = w_{t-1} - η·v_t
  β = 0.9 typical. Momentum smooths oscillations, accelerates convergence.
  Requires careful LR tuning; good final performance when tuned.

Adam (Adaptive Moment Estimation):
  m_t = β₁·m_{t-1} + (1-β₁)·g_t    (1st moment: mean of gradient)
  v_t = β₂·v_{t-1} + (1-β₂)·g_t²   (2nd moment: uncentred variance)
  m̂_t = m_t / (1-β₁ᵗ)              (bias correction, critical early in training)
  v̂_t = v_t / (1-β₂ᵗ)
  w_t = w_{t-1} - η · m̂_t / (√v̂_t + ε)
  Defaults: β₁=0.9, β₂=0.999, ε=1e-8, η=1e-3

AdamW (Decoupled Weight Decay, Loshchilov & Hutter 2019):
  Same as Adam, but weight decay applied directly to parameters:
  w_t = w_{t-1} - η · m̂_t/(√v̂_t+ε) - η·λ·w_{t-1}
  In Adam, weight decay is entangled with adaptive scaling → inconsistent regularisation.
  AdamW decouples it → every parameter receives the same relative weight decay.
  Standard for fine-tuning Transformers (BERT, GPT, ViT).

Lion (EvoLved Sign Momentum, Google Brain 2023):
  update = sign(β₁·m + (1-β₁)·g)   (sign-based: uniform magnitude)
  w = w - η·(update + λ·w)
  m = β₂·m + (1-β₂)·g              (momentum uses a separate β₂; paper defaults β₁=0.9, β₂=0.99)
  More memory-efficient than Adam (only 1 momentum state vs 2).
  Competitive at large scale; needs ~3-10× smaller LR than Adam.
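The update rules can be written out as scalar functions to see how differently they react to gradient magnitude. This is a minimal sketch for illustration: the function names are hypothetical, and the Lion step follows the published rule, in which the momentum update uses its own coefficient β₂.

```python
import math

def sgd_momentum_step(w, g, v, lr=0.01, beta=0.9):
    """SGD with momentum: the velocity accumulates the gradient."""
    v = beta * v + g
    return w - lr * v, v

def adam_step(w, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, wd=0.0):
    """Adam with bias correction; wd > 0 adds AdamW-style decoupled decay."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    return w - lr * m_hat / (math.sqrt(v_hat) + eps) - lr * wd * w, m, v

def lion_step(w, g, m, lr=1e-4, b1=0.9, b2=0.99, wd=0.0):
    """Lion: step along the sign of an interpolated momentum, then update the momentum."""
    update = math.copysign(1.0, b1 * m + (1 - b1) * g)
    return w - lr * (update + wd * w), b2 * m + (1 - b2) * g

# First step from zero optimizer state, for a small and a large gradient.
for g in (0.02, 50.0):
    w_sgd, _ = sgd_momentum_step(1.0, g, v=0.0)
    w_adam, _, _ = adam_step(1.0, g, m=0.0, v=0.0, t=1)
    w_lion, _ = lion_step(1.0, g, m=0.0)
    print(f"g={g:6.2f}  |Δw| SGD={abs(1 - w_sgd):.2e}  "
          f"Adam={abs(1 - w_adam):.2e}  Lion={abs(1 - w_lion):.2e}")
```

The SGD step scales with the gradient (2500× larger for the larger gradient), while the Adam and Lion steps are about lr in both cases. That is why learning rates do not transfer between optimizer families.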

  Training loss — SGD+momentum vs Adam vs AdamW (same task, same steps):

  Loss ↑
  1.00 │╲  SGD+momentum (lr=0.01, β=0.9)
       │ ╲
       │  ╲___
  0.50 │      ╲________
       │               ╲__________________  ≈0.35 at 1000 steps
       │╲  Adam (lr=1e-3, β₁=0.9, β₂=0.999)
       │ ╲__
  0.20 │     ╲_____
       │           ╲_____________________  ≈0.12 at 1000 steps
       │╲  AdamW (lr=1e-3, wd=0.01)
       │ ╲__
  0.10 │     ╲_____
       │           ╲_____________________  ≈0.09 (WD reduces overfitting)
       └──────────────────────────────────────────→ Steps
         0       250       500       750      1000

  SGD: slower convergence — needs careful LR tuning and a schedule.
  Adam: fast convergence — adaptive per-parameter LR.
  AdamW: same speed as Adam, lower final loss — decoupled weight decay.
PyTorch — compare optimizers, AdamW parameter groups for Transformers, inspect Adam state
import torch
import torch.nn as nn
import torch.optim as optim

# ── Compare optimizers on simple regression ───────────────
torch.manual_seed(42)
X = torch.randn(256, 32)
y = X @ torch.randn(32, 1) + 0.1 * torch.randn(256, 1)

def run(opt_name, steps=200):
    model = nn.Linear(32, 1)
    if opt_name == 'SGD':
        opt = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    elif opt_name == 'Adam':
        opt = optim.Adam(model.parameters(), lr=1e-3)
    elif opt_name == 'AdamW':
        opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
    for _ in range(steps):
        opt.zero_grad()
        nn.MSELoss()(model(X), y).backward()
        opt.step()
    return nn.MSELoss()(model(X), y).item()

for opt in ['SGD', 'Adam', 'AdamW']:
    print(f"{opt:6s}: final loss = {run(opt):.6f}")

# ── AdamW parameter groups — exclude bias & LayerNorm ─────
encoder = nn.TransformerEncoderLayer(d_model=256, nhead=8, batch_first=True)
# Select by shape rather than by name: nn.TransformerEncoderLayer calls its
# LayerNorms norm1/norm2, so a 'LayerNorm.weight' name filter would miss them.
# All biases and LayerNorm gains/shifts are 1-D; weight matrices are 2-D.
decay, no_decay = [], []
for name, p in encoder.named_parameters():
    (no_decay if p.ndim < 2 else decay).append(p)

param_groups = [
    {'params': decay,    'weight_decay': 0.01},
    {'params': no_decay, 'weight_decay': 0.0},   # bias & LayerNorm: no weight decay
]
optimizer = optim.AdamW(param_groups, lr=2e-5)
print(f"\nAdamW groups:")
print(f"  With WD (0.01): {len(optimizer.param_groups[0]['params'])} params")
print(f"  No WD:          {len(optimizer.param_groups[1]['params'])} params")

# ── Inspect Adam internal state after one step ────────────
model2 = nn.Linear(4, 2)
opt2   = optim.Adam(model2.parameters(), lr=1e-3)
opt2.zero_grad()
nn.MSELoss()(model2(torch.randn(8, 4)), torch.randn(8, 2)).backward()
opt2.step()

state = opt2.state[model2.weight]
print(f"\nAdam state keys: {list(state.keys())}")
print(f"  step:              {state['step']}")
# The state stores the raw moments m and v; bias correction is applied on the fly.
print(f"  exp_avg norm:      {state['exp_avg'].norm():.6f}   (m: 1st moment)")
print(f"  exp_avg_sq norm:   {state['exp_avg_sq'].norm():.6f}  (v: 2nd moment)")
Trap Applying weight decay to bias and LayerNorm parameters

L2 weight decay penalises large parameter values. For bias terms, this adds an unnecessary shrinkage force — biases control output shifts, not representation scale. For LayerNorm gain/bias, weight decay interferes with the normalisation mechanism. Applying uniform weight decay degrades fine-tuning performance and requires manually lowering lambda to compensate.

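Fix Use two parameter groups: weights with weight_decay=0.01; biases and LayerNorm parameters with weight_decay=0.0. The HuggingFace example training scripts filter by name with any(nd in name for nd in ['bias', 'LayerNorm.weight']) when building optimizer_grouped_parameters. That filter only matches models whose norm modules are named LayerNorm (as in HuggingFace BERT); for other naming schemes (torch.nn.TransformerEncoderLayer uses norm1/norm2), filter on p.ndim < 2 instead.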
Trap Using SGD learning rates with Adam or Adam learning rates with SGD

Adam normalises gradients by their running RMS estimate, so the effective per-parameter step is approximately η regardless of gradient magnitude. Adam works well at lr=1e-3 to 1e-4. SGD has no such normalisation and needs lr=0.01 to 0.1 with momentum to converge in similar time. Using SGD's lr (0.01 to 0.1) with Adam makes every parameter move by about that much per step, which is typically unstable or divergent, especially when fine-tuning. Using Adam's lr (1e-3) with SGD usually gives very slow convergence, because the step is η times a gradient that is often much smaller than 1.

Fix Default starting LRs: Adam/AdamW = 1e-3 (from scratch), 1e-5 to 5e-5 (fine-tuning LLMs), 3e-4 (RL). SGD+momentum = 0.01 to 0.1. Always sweep LR when changing optimizer. Run an LR range test (fastai-style) to find the maximum stable LR.

In Adam, weight decay is implemented as L2 regularisation by adding λ·w to the gradient before the adaptive update: g_effective = g + λ·w. This weight decay term is then scaled by 1/(√v̂+ε), which varies per parameter. Parameters with large gradient magnitudes receive smaller effective weight decay, so regularisation is inconsistent: roughly inversely proportional to the RMS gradient √v̂. AdamW decouples weight decay from the adaptive step: w_t = (1-η·λ)·w_{t-1} - η·m̂/(√v̂+ε). Every parameter receives the same relative weight decay regardless of its gradient history. Loshchilov & Hutter report that decoupled decay generalises better than L2-regularised Adam and makes the best λ less dependent on the learning rate. BERT's reference optimizer already applied its weight decay (0.01) in this decoupled form. For fine-tuning, the difference is especially noticeable because weight decay controls how much pre-trained weights are pulled toward zero, so consistent application across all layers matters.
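A rough numerical sketch of the inconsistency, under a loudly simplifying assumption: v̂ is taken to have settled at the square of a given RMS gradient, and only the part of the update caused by weight decay is computed. `decay_pull` is a hypothetical helper for this illustration.

```python
def decay_pull(w, g_rms, lam=0.01, lr=1e-3, eps=1e-8, decoupled=False):
    """Per-step shrinkage of w caused by weight decay alone.

    Assumes v_hat has settled at g_rms**2 (an illustrative simplification).
    """
    if decoupled:
        # AdamW: lr * lam * w, independent of v_hat
        return lr * lam * w
    # Adam + L2: lam * w enters the gradient, then is divided by sqrt(v_hat)
    return lr * lam * w / (g_rms + eps)

w = 1.0
for g_rms in (0.01, 1.0, 100.0):   # hypothetical per-parameter gradient scales
    l2 = decay_pull(w, g_rms, decoupled=False)
    adw = decay_pull(w, g_rms, decoupled=True)
    print(f"RMS grad {g_rms:6.2f}:  Adam+L2 pull = {l2:.2e}   AdamW pull = {adw:.2e}")
```

With L2-in-the-gradient, two weights of identical size are decayed at rates that differ by four orders of magnitude purely because of their gradient history; with decoupled decay the pull is identical.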

Adam initialises m₀ = 0 (1st moment) and v₀ = 0 (2nd moment). In the first step with gradient g₁: m₁ = (1-β₁)·g₁ = 0.1·g₁, much smaller than the true gradient mean, and v₁ = (1-β₂)·g₁² = 0.001·g₁², much smaller than the true second moment. Bias correction rescales both: m̂_t = m_t/(1-β₁ᵗ) and v̂_t = v_t/(1-β₂ᵗ). At t=1: m̂₁ = 0.1·g₁ / (1-0.9) = g₁ and v̂₁ = g₁², so the first update is η·g₁/|g₁|, a step of size η. Without correction the two biases do not cancel: m₁/√v₁ = 0.1/√0.001 ≈ 3.2, so early steps would be about 3× too large, and more so for β₂ closer to 1. At t=100: (1-0.9^100) ≈ 1, so the first-moment correction is negligible after the first 10-50 steps; the second-moment correction fades more slowly (1-0.999^1000 ≈ 0.63). The correction matters most when fine-tuning from a pre-trained checkpoint, where oversized first updates can damage the pre-trained weights.
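The size of the bias can be traced step by step. The sketch below feeds a hypothetical constant gradient into Adam's two accumulators and compares the update direction m/√v with and without correction; with a constant gradient the corrected ratio should be exactly 1.

```python
import math

beta1, beta2 = 0.9, 0.999      # Adam defaults
g = 0.5                        # hypothetical constant gradient

m = v = 0.0
for t in range(1, 1001):
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    if t in (1, 10, 100, 1000):
        raw = m / math.sqrt(v)              # update direction without correction
        corr = m_hat / math.sqrt(v_hat)     # with correction
        print(f"t={t:4d}  m/g={m / g:.3f}  v/g²={v / g**2:.3f}  "
              f"uncorrected={raw:.3f}  corrected={corr:.3f}")
```

The first moment reaches its true value within a few dozen steps, while the second moment is still visibly underestimated at step 1000, which keeps the uncorrected step inflated for a long time.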

LR Schedules — Warmup, Cosine Decay & 1-Cycle Policy
Linear warmup prevents early instability; cosine annealing for smooth decay; 1-cycle policy; weight decay with AdamW

The learning rate schedule is nearly as important as the optimizer.

Linear Warmup:
  η increases from 0 → η_max over warmup_steps
  Critical for Transformers: moment estimates (m̂, v̂) are unstable in first steps.
  Without warmup: early large updates take model to bad local region.
  Standard: 1-6% of total training steps as warmup.

Cosine Annealing:
  η_t = η_min + 0.5(η_max − η_min)(1 + cos(π·t/T))
  Smoothly decays from η_max to near 0.
  Avoids sharp minima: models converge to flatter, more generalisable regions.
  Industry default for most modern training runs.

Warmup + Cosine (standard for LLMs):
  Steps 0 → warmup: linear ramp 0 → η_max
  Steps warmup → T: cosine decay η_max → η_min
  LLM training recipe: warmup 2000 steps, cosine to 10% of max LR.

1-Cycle Policy (Leslie Smith 2018):
  Phase 1 (30-45%): LR ramp η_min → η_max
  Phase 2 (45-55%): LR decay η_max → η_min
  Phase 3 (10%): very low LR (η_min / 100)
  Momentum cycled inversely: high momentum when LR is low.
  Super-convergence: reaches good solutions in 5-10× fewer epochs.
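The warmup + cosine recipe is simple enough to write as a closed-form function of the step. This is a sketch with a hypothetical helper name `lr_at` and made-up run lengths; it follows the cosine formula above with t measured from the end of warmup.

```python
import math

def lr_at(step, max_lr, warmup_steps, total_steps, min_lr=0.0):
    """Learning rate at `step` for linear warmup followed by cosine decay."""
    if step < warmup_steps:
        return max_lr * step / warmup_steps                 # linear ramp 0 -> max_lr
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

# Hypothetical run: 1000 steps, 100 warmup steps, decay to 10% of the peak LR.
max_lr, warmup, total = 2e-5, 100, 1000
for step in (0, 50, 100, 550, 1000):
    lr = lr_at(step, max_lr, warmup, total, min_lr=0.1 * max_lr)
    print(f"step {step:4d}: lr = {lr:.3e}")
```

A function like this is handy for plotting a schedule before a run, or for checking that a framework scheduler produces the values you expect at a few chosen steps.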

PyTorch — warmup+cosine schedule, OneCycleLR, verifying LR at each step
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import (
    CosineAnnealingLR, OneCycleLR, LinearLR, SequentialLR
)

model = nn.Linear(32, 1)

# ── Warmup + Cosine decay ─────────────────────────────────
optimizer = optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
warmup_steps = 100
total_steps  = 1000

warmup = LinearLR(
    optimizer,
    start_factor=1e-7 / 2e-5,   # start near 0
    end_factor=1.0,
    total_iters=warmup_steps
)
cosine = CosineAnnealingLR(
    optimizer,
    T_max=total_steps - warmup_steps,
    eta_min=0.0
)
scheduler = SequentialLR(optimizer,
    schedulers=[warmup, cosine],
    milestones=[warmup_steps]
)

lrs = []
for step in range(total_steps):
    lrs.append(optimizer.param_groups[0]['lr'])
    optimizer.step()     # no gradients here; called so the scheduler sees the expected order
    scheduler.step()

print("Warmup + Cosine LR:")
for step in [0, 50, 100, 300, 500, 999]:
    print(f"  step {step:4d}: lr={lrs[step]:.8f}")

# ── HuggingFace-style get_scheduler ──────────────────────
try:
    from transformers import get_cosine_schedule_with_warmup
    optimizer2 = optim.AdamW(model.parameters(), lr=2e-5)
    sched2 = get_cosine_schedule_with_warmup(
        optimizer2,
        num_warmup_steps=100,
        num_training_steps=1000
    )
    print("\nHuggingFace cosine schedule loaded successfully")
except ImportError:
    print("\n(transformers not installed; using manual scheduler above)")

# ── 1-Cycle LR with SGD ───────────────────────────────────
model3 = nn.Linear(32, 1)
opt3   = optim.SGD(model3.parameters(), lr=0.1, momentum=0.9)
sched3 = OneCycleLR(
    opt3,
    max_lr=0.1,
    steps_per_epoch=100,
    epochs=5,
    pct_start=0.3,       # 30% warmup phase
    anneal_strategy='cos'
)

lrs_1c = []
for _ in range(5 * 100):   # the full schedule: epochs × steps_per_epoch
    lrs_1c.append(opt3.param_groups[0]['lr'])
    opt3.zero_grad()
    nn.MSELoss()(model3(torch.randn(4,32)), torch.randn(4,1)).backward()
    opt3.step()
    sched3.step()        # ← call after every optimizer.step() for OneCycleLR

print(f"\n1-Cycle LR (full run, 500 steps):")
print(f"  Start: {lrs_1c[0]:.6f}")
print(f"  Peak:  {max(lrs_1c):.6f}")
print(f"  End:   {lrs_1c[-1]:.6f}")
Trap Calling scheduler.step() at the wrong frequency — per epoch for step-based schedulers

OneCycleLR and warmup schedulers expect scheduler.step() after every optimizer.step() (inside the batch loop). Calling them once per epoch makes a 100-step warmup take 100 epochs, so the model trains at near-zero LR for almost the entire run. The inverse: calling epoch-based schedulers (StepLR) per step decays the LR steps_per_epoch times faster than intended.

Fix Always check the scheduler's documentation. OneCycleLR, LinearLR, CosineAnnealingLR (when T_max = total steps): call after every optimizer.step(). StepLR, MultiStepLR: call after every epoch (outside the batch loop). HuggingFace's get_scheduler always expects per-step calling.
Trap Setting weight_decay=0 when using AdamW, making it identical to Adam

AdamW with weight_decay=0 is mathematically identical to Adam: the decoupling mechanism has nothing to decouple. torch.optim.AdamW defaults to weight_decay=0.01, but configs often override it, and HuggingFace's TrainingArguments defaults weight_decay to 0.0. Engineers use AdamW (correctly) but end up with no regularisation at all. The model overfits on small fine-tuning datasets with no indication that regularisation is missing, since loss and accuracy may look fine on training data.

Fix Set weight_decay=0.01 for Transformer fine-tuning (standard from BERT paper). For LLM pre-training: 0.01 to 0.1. For CNNs: 1e-4. Remember to use parameter groups — apply weight decay only to non-bias, non-LayerNorm parameters.

At the start of training, Adam's 2nd moment accumulator v is initialised to 0, and for the first hundreds of steps the bias-corrected estimate v̂_t rests on few gradients and is noisy. The adaptive step size η/(√v̂+ε) is unreliable: some parameters receive enormous updates (where v̂ is accidentally small) and others receive tiny updates. Additionally, random initialisation means gradients are large and noisy early. Training with the full learning rate immediately can push the model into a bad region it cannot escape, or cause NaN losses from large weight updates. Warmup linearly increases LR from ~0 to η_max over the first few hundred steps, allowing moment estimates to stabilise while making cautious initial updates. Empirically, removing warmup from Transformer training tends to cost accuracy and cause instability, especially with larger learning rates.

The 1-cycle policy (Leslie Smith, 2018) cycles the LR through one full ramp-up → ramp-down cycle per training run. The key insight from the LR range test: training tolerates a peak LR far above what would be stable if held constant for the whole run, as long as the LR ramps up to that peak and back down. Cycling up to this high LR acts as a form of regularisation: it pushes the model out of sharp narrow minima (which generalise poorly) into flatter wider regions (which generalise better). The inverse momentum cycle (high momentum when LR is low) provides implicit annealing. The final very-low LR phase allows fine-grained convergence within the flat minimum. Super-convergence result: models can reach the same or better accuracy in 5-10× fewer epochs. Most impactful for CNNs and image classifiers; cosine schedules have largely replaced it for Transformers where stability during warmup is more critical.

The learning rate is the most important hyperparameter. But the scheduler is nearly as important — warmup prevents early instability, cosine decay avoids sharp minima, and decoupled weight decay (AdamW) provides consistent regularisation across all parameters.
Regularization & Init 05–06
05

Normalisation Layers, Dropout & Weight Initialisation

Your CNN trains fine at batch_size=64 on an A100, but you port it to real-time edge inference at batch_size=1 and accuracy collapses. BatchNorm's running statistics diverge from single-sample statistics. This is the classic BN failure mode — understanding when to use BN vs LayerNorm vs GroupNorm is the difference between a model that works in the lab and one that ships.

Normalisation Layers — BatchNorm, LayerNorm, GroupNorm
BN: normalise across batch per channel (needs large batch); LN: across features per sample (Transformer standard); GN: channel groups (segmentation/video)

Normalisation stabilises training by standardising layer inputs, preventing internal covariate shift.

BatchNorm (Ioffe & Szegedy, 2015):
  For each channel C, normalise across the batch N:
  μ_B  = (1/m)Σᵢ xᵢ               (batch mean)
  σ²_B = (1/m)Σᵢ(xᵢ − μ_B)²       (batch variance)
  x̂ᵢ   = (xᵢ − μ_B) / √(σ²_B + ε)
  yᵢ   = γ·x̂ᵢ + β                 (learnable scale γ and shift β)
  Training: batch stats. Inference: running mean/var (EMA).
  Works best with batch_size ≥ 16. Breaks in training mode at batch_size=1:
  batch variance = 0, so every output collapses to β (PyTorch raises an error for non-spatial inputs).

LayerNorm (Ba et al, 2016):
  Normalises across ALL features for each sample independently:
  μ = (1/H)Σⱼ xⱼ;  σ² = (1/H)Σⱼ(xⱼ − μ)²   (over feature dim H)
  Works at any batch size including batch_size=1.
  Standard in the original Transformer family: BERT, GPT, ViT (LLaMA uses the closely related RMSNorm).

GroupNorm (Wu & He, 2018):
  Divides C channels into G groups; normalises within each group.
  G=1 → equivalent to LayerNorm. G=C → InstanceNorm.
  Independent of batch size. Used in segmentation, video, RL.

  Tensor shape: [N=Batch, C=Channels] (spatial dims omitted)

  BatchNorm — normalises across N for each C:
  ┌─────┬─────┬─────┬─────┐
  │ N=0 │ N=1 │ N=2 │ N=3 │  C=0  ← μ,σ² computed over these 4
  ├─────┼─────┼─────┼─────┤
  │ N=0 │ N=1 │ N=2 │ N=3 │  C=1  ← μ,σ² computed over these 4
  └─────┴─────┴─────┴─────┘
    ↕ normalised along batch axis  — fails at N=1

  LayerNorm — normalises across C for each N:
  ┌─────┬─────┬─────┬─────┐
  │ C=0 │ C=1 │ C=2 │ C=3 │  N=0  ← μ,σ² computed over these 4
  ├─────┼─────┼─────┼─────┤
  │ C=0 │ C=1 │ C=2 │ C=3 │  N=1  ← μ,σ² computed over these 4
  └─────┴─────┴─────┴─────┘
    ↔ normalised along feature axis — works at any batch size
PyTorch — BatchNorm train vs eval, LayerNorm verification, GroupNorm at batch=1
import torch
import torch.nn as nn

# ── BatchNorm: training vs eval behaviour ─────────────────
bn = nn.BatchNorm1d(4)
bn.train()
x = torch.tensor([[1., 2., 3., 4.],
                   [5., 6., 7., 8.],
                   [2., 4., 6., 8.]])     # batch=3, features=4
out = bn(x)
print("BN train (each column normalised across batch):")
print(out.detach().round(decimals=3))
print(f"Running mean: {bn.running_mean}")

bn.eval()
x_one = torch.tensor([[3., 4., 5., 6.]])  # batch_size=1 at inference
out_eval = bn(x_one)
print(f"\nBN eval (uses running stats): {out_eval.detach()}")

# ── LayerNorm: verify against manual formula ──────────────
ln = nn.LayerNorm(4)
x  = torch.tensor([[1., 2., 3., 4.]])
out_ln = ln(x)
mean = x.mean(dim=-1, keepdim=True)
var  = x.var(dim=-1, keepdim=True, unbiased=False)
manual = (x - mean) / torch.sqrt(var + 1e-5)   # ε goes inside the square root
print(f"\nLayerNorm output: {out_ln.detach()}")
print(f"Manual LN:        {manual.detach()}")

# ── GroupNorm works at batch_size=1 ──────────────────────
gn   = nn.GroupNorm(num_groups=2, num_channels=8)
x_gn = torch.randn(1, 8, 16, 16)       # batch=1 — works fine
out_gn = gn(x_gn)
print(f"\nGroupNorm: {x_gn.shape} → {out_gn.shape}  (batch=1 ✓)")

# ── LayerNorm in Transformer input shape ─────────────────
norm = nn.LayerNorm(64)
x_t  = torch.randn(4, 32, 64)           # (batch, seq_len, d_model)
out_t = norm(x_t)
print(f"\nLN in Transformer: {x_t.shape} → {out_t.shape}")
print(f"  Mean across d_model ≈ 0: {out_t.mean(dim=-1).abs().mean():.6f}")
print(f"  Std  across d_model ≈ 1: {out_t.std(dim=-1).mean():.4f}")
Trap Using BatchNorm with batch_size=1 at inference

In eval mode, BN uses running_mean and running_var accumulated during training. If training batch statistics differ significantly from the test distribution (e.g., different domain, or the model was evaluated mid-training), running stats diverge from actual activation stats. With batch_size=1, the normalisation is based on noisy running estimates and accuracy drops. This is common when porting a trained model to a streaming inference pipeline.

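Fix Replace BN with GN (GroupNorm) or LN for models that will be deployed at batch_size=1 or in streaming settings. If BN is required: recalibrate its statistics on the target distribution. Reset the running stats, put the BN layers in train mode, forward representative data under torch.no_grad() (no optimizer step), then switch back to eval mode. Running stats are only updated in train mode, so a pass in eval mode changes nothing.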
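One way to re-estimate BatchNorm statistics on target data is sketched below. `recalibrate_bn` is a hypothetical helper, the model and the shifted "target-domain" batches are made up for illustration, and the approach assumes the weights should stay frozen. Note that model.train() also re-enables any Dropout layers during the pass, and that the helper leaves momentum set to None.

```python
import torch
import torch.nn as nn

def recalibrate_bn(model, batches):
    """Re-estimate BatchNorm running stats from `batches` without touching the weights."""
    bn_layers = [m for m in model.modules()
                 if isinstance(m, nn.modules.batchnorm._BatchNorm)]
    for bn in bn_layers:
        bn.reset_running_stats()   # running_mean -> 0, running_var -> 1
        bn.momentum = None         # None = cumulative average over all batches seen
    model.train()                  # running stats are only updated in train mode
    with torch.no_grad():          # no gradients, no optimizer: weights stay fixed
        for x in batches:
            model(x)
    model.eval()
    return model

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))
# Hypothetical target-domain data: shifted and scaled relative to N(0, 1).
target_batches = [torch.randn(64, 8) * 3.0 + 5.0 for _ in range(20)]

recalibrate_bn(model, target_batches)
bn = model[1]
print("recalibrated running_mean:", bn.running_mean[:4])
print("batches seen:", int(bn.num_batches_tracked))
```

After the pass the model is back in eval mode and can be served at batch_size=1 with statistics that reflect the deployment data.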
Trap Placing LayerNorm after the residual addition (post-LN) instead of before (pre-LN)

The original Transformer used post-LN: x = LN(x + sublayer(x)). This causes training instability for very deep models (12+ layers) because gradients must pass through the LN at each layer before reaching earlier layers, and LN can scale them down. GPT-2 and most modern architectures switched to pre-LN: x = x + sublayer(LN(x)).

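Fix Use pre-LN for any Transformer with more than 6 layers: apply LN to the input of each sublayer (attention and FFN), not to the output. HuggingFace models follow each architecture's original design, so BERT and RoBERTa are post-LN while GPT-2 and most recent decoder-only LLMs are pre-norm; check the model code rather than assuming. If training a custom Transformer and seeing instability, this is the first thing to switch.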
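The two placements differ by a single line in the forward pass. The block below is a minimal sketch using only a feed-forward sublayer; `FFNBlock` is a hypothetical class for illustration, not a library module.

```python
import torch
import torch.nn as nn

class FFNBlock(nn.Module):
    """One residual feed-forward sublayer with either LayerNorm placement."""
    def __init__(self, d_model=64, pre_ln=True):
        super().__init__()
        self.pre_ln = pre_ln
        self.norm = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )

    def forward(self, x):
        if self.pre_ln:
            return x + self.ffn(self.norm(x))   # pre-LN:  x + sublayer(LN(x))
        return self.norm(x + self.ffn(x))       # post-LN: LN(x + sublayer(x))

torch.manual_seed(0)
x = torch.randn(2, 10, 64)                      # (batch, seq_len, d_model)
pre = FFNBlock(pre_ln=True)(x)
post = FFNBlock(pre_ln=False)(x)

# Post-LN renormalises the residual stream itself; pre-LN leaves it untouched.
print("pre-LN  |mean| over d_model:", pre.mean(dim=-1).abs().mean().item())
print("post-LN |mean| over d_model:", post.mean(dim=-1).abs().mean().item())
```

In the pre-LN form the input x reaches the output through a pure addition, so the identity gradient path is never rescaled by a LayerNorm. PyTorch's nn.TransformerEncoderLayer exposes the same choice through its norm_first argument.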

BatchNorm normalises across the batch dimension. At batch_size=1 in training mode there is only one sample per channel (for non-spatial inputs): the biased batch variance is 0 and the unbiased estimate is undefined. Every activation equals its own batch mean, so x − μ_B = 0 and the layer outputs just β for every input; all information in the activation is erased (PyTorch raises an error in this case rather than continue). In eval mode BN uses running stats and does run at batch_size=1, but it is only as good as the match between those running stats and the deployment distribution. Alternatives that work at batch_size=1: (1) LayerNorm, which normalises over features per sample independently; (2) GroupNorm, which normalises within channel groups per sample; (3) InstanceNorm, which normalises per channel per sample (used in style transfer). For production deployment at batch_size=1: switch BN to GN or LN during architecture design, not after training.
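The batch_size=1 behaviour is easy to reproduce without a framework. The sketch below implements both normalisations on nested lists (γ=1, β=0, biased variance); the function names are illustrative only.

```python
import math

def normalise(values, eps=1e-5):
    """Standardise a list of floats using the biased variance."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(var + eps) for v in values]

def batch_norm(batch):
    """Normalise each feature (column) across the batch; gamma=1, beta=0."""
    columns = [normalise(list(col)) for col in zip(*batch)]
    return [list(row) for row in zip(*columns)]

def layer_norm(batch):
    """Normalise each sample (row) across its features; gamma=1, beta=0."""
    return [normalise(row) for row in batch]

single = [[1.0, 2.0, 3.0, 4.0]]   # batch_size = 1, four features
print("BatchNorm, batch=1:", batch_norm(single))
print("LayerNorm, batch=1:", [[round(v, 3) for v in row] for row in layer_norm(single)])
```

With one sample, each feature is its own batch mean, so batch statistics erase the input entirely, while per-sample statistics still produce a meaningful standardised vector.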

Internal covariate shift (Ioffe & Szegedy, 2015) is the change in the distribution of each layer's inputs during training as the parameters of preceding layers change. Early layers update their weights, shifting the distribution seen by later layers. Later layers must constantly re-adapt, slowing convergence. BatchNorm addresses this by forcing each layer's input to have approximately zero mean and unit variance (before the learnable γ, β), making the distribution consistent regardless of upstream parameter changes. LayerNorm applies the same principle but per-sample — the distribution of features for each individual sample is standardised. This is why normalisation allows much higher learning rates: the gradient signal is more consistent and layers are less sensitive to the scale of upstream parameters.

Dropout & Weight Initialisation — He, Xavier, Orthogonal
Inverted dropout (train scale, eval pass-through); He init for ReLU; Xavier for tanh/sigmoid; orthogonal for RNNs; verify init with activation variance

Weight Initialisation:
  Goal: activation variance ≈ 1.0 throughout the network after the first forward pass.
  Too small → signals shrink to 0 (vanishing activations)
  Too large → signals grow to ∞ (exploding activations / NaN)

Xavier/Glorot (for sigmoid/tanh):
  Var(W) = 2/(fan_in + fan_out)
  Symmetric: balances forward and backward signal variance.
  nn.init.xavier_uniform_(w) or xavier_normal_(w)

He/Kaiming (for ReLU):
  Var(W) = 2/fan_in
  Factor of 2: ReLU zeros half the inputs on average → halves effective fan_in.
  Not the PyTorch default: nn.Linear and nn.Conv2d use kaiming_uniform_ with a=√5 (Var(W) = 1/(3·fan_in)), so set it explicitly.
  nn.init.kaiming_normal_(w, mode='fan_in', nonlinearity='relu')

Orthogonal (for RNNs / state-space models):
  W initialised as a random orthogonal matrix: WᵀW = I
  Preserves vector norms through the linear recurrence, so W alone neither shrinks nor grows gradients.
  nn.init.orthogonal_(w)

Inverted Dropout (Srivastava et al, 2014):
  Training: zero activation with probability p; scale survivors by 1/(1−p)
  Inference: pass all activations unchanged (no scaling needed)
  Inverted scaling keeps expected activation value constant across train/eval.
  Typical p: 0.5 for MLP on small datasets; 0.1 for Transformers.

PyTorch — He init keeps variance near 1.0; inverted dropout proof; Xavier comparison
import torch
import torch.nn as nn

# ── He init: activation variance stays near 1.0 ──────────
layers = nn.ModuleList([nn.Linear(256, 256) for _ in range(8)])
for layer in layers:
    nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')  # fan_in: preserve forward-pass variance
    nn.init.zeros_(layer.bias)

x = torch.randn(128, 256)
print("Activation variance per layer (He init, ReLU):")
for i, layer in enumerate(layers):
    x = torch.relu(layer(x))
    ok = '✓' if 0.3 < x.var() < 3.0 else '✗'
    print(f"  layer {i+1}: var={x.var():.4f}  {ok}")

# ── Naive tiny init: activations collapse ────────────────
layers_bad = nn.ModuleList([nn.Linear(256, 256) for _ in range(8)])
for layer in layers_bad:
    nn.init.normal_(layer.weight, 0, 0.01)
    nn.init.zeros_(layer.bias)

x_bad = torch.randn(128, 256)
print("\nActivation variance per layer (std=0.01, ReLU):")
for i, layer in enumerate(layers_bad):
    x_bad = torch.relu(layer(x_bad))
    print(f"  layer {i+1}: var={x_bad.var():.2e}  {'✗ vanished' if x_bad.var() < 1e-4 else ''}")

# ── Inverted dropout: training vs eval ────────────────────
torch.manual_seed(0)
x_d = torch.ones(1000)
p   = 0.5
dropout = nn.Dropout(p=p)

dropout.train()
out_tr = dropout(x_d)
print(f"\nDropout train: zeroed={( out_tr==0).float().mean():.2f}, "
      f"survivor mean={out_tr[out_tr>0].mean():.2f}  (scaled by 1/(1-p)={1/(1-p):.1f})")

dropout.eval()
out_ev = dropout(x_d)
print(f"Dropout eval:  all passed, mean={out_ev.mean():.2f}")

# ── Xavier vs He: weight std comparison ──────────────────
w_he  = torch.empty(256, 256); nn.init.kaiming_uniform_(w_he, nonlinearity='relu')
w_xav = torch.empty(256, 256); nn.init.xavier_uniform_(w_xav)
print(f"\nHe    init std: {w_he.std():.4f}  (for ReLU)")
print(f"Xavier init std: {w_xav.std():.4f}  (for tanh/sigmoid)")
Trap Leaving dropout enabled at inference (forgetting model.eval())

nn.Dropout behaves differently in training and eval modes, controlled by the module's training flag. Engineers who forget to call model.eval() before inference, or who call model.train() for a later training phase and never switch back, get dropout applied at inference time. Because PyTorch uses inverted dropout, the expected output is unchanged, but every forward pass zeroes a random subset of units and scales the survivors by 1/(1-p), so outputs are noisy and non-deterministic. This causes mysteriously variable and degraded inference performance.

Fix Always call model.eval() before any inference or validation loop. This single call switches ALL dropout and BatchNorm layers to eval mode simultaneously. model.eval() does not disable gradient tracking, so also wrap inference in torch.no_grad(). Verify: model.training should be False, and inside the no_grad block torch.is_grad_enabled() should be False.
Trap Using the same weight initialisation for all activation functions

He init is derived assuming ReLU zeros half the activations, so it produces variance 2/fan_in. Applying He init to a tanh or sigmoid network gives variance about 2× too large, pushing units toward saturation in the first forward pass. Conversely, Xavier init in a ReLU network gives variance 2/(fan_in + fan_out), which is 1/fan_in for square layers and 2× too small, leading to vanishing activations. Both fail silently; training runs but converges poorly.

Fix Match init to activation: ReLU/Leaky ReLU → kaiming_normal_ with nonlinearity='relu' or 'leaky_relu'; tanh/sigmoid → xavier_uniform_; GELU → kaiming_normal_(nonlinearity='relu') works well as an approximation. nn.init.calculate_gain covers a fixed set of named nonlinearities (for example, calculate_gain('tanh') returns 5/3); for custom activations, derive the gain factor analytically or measure it empirically.
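To make the "match init to activation" rule concrete, the sketch below prints the gains PyTorch associates with a few named nonlinearities and checks that kaiming_normal_ produces std = gain/√fan_in. The layer shape (512×256) and the leaky-ReLU slope of 0.2 are arbitrary choices for illustration.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

# Gains PyTorch uses for some named nonlinearities
gains = {
    "linear": nn.init.calculate_gain("linear"),
    "tanh": nn.init.calculate_gain("tanh"),
    "relu": nn.init.calculate_gain("relu"),
    "leaky_relu(0.2)": nn.init.calculate_gain("leaky_relu", 0.2),
}
for name, g in gains.items():
    print(f"{name:16s} gain = {g:.4f}")

# He init: std = gain / sqrt(fan_in). Weight shape is (out_features, in_features).
fan_in = 256
w = torch.empty(512, fan_in)
nn.init.kaiming_normal_(w, mode="fan_in", nonlinearity="relu")
expected_std = gains["relu"] / math.sqrt(fan_in)
print(f"\nHe init: empirical std = {w.std().item():.4f}, gain/sqrt(fan_in) = {expected_std:.4f}")
```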

Xavier and He init both derive from the requirement that activation variance stays constant across layers. Xavier assumes a linear activation (or a symmetric activation with gradient ≈ 1 near zero): to keep Var(output) = Var(input), you need Var(W) = 1/fan_in. Preserving gradient variance on the backward pass needs Var(W) = 1/fan_out, and Xavier compromises between the two with 2/(fan_in + fan_out). He init assumes ReLU: ReLU zeros approximately half the activations on average (the negative half), effectively halving the fan_in. To compensate: Var(W) = 2/fan_in. If you use Xavier init with ReLU in square layers, the signal's second moment halves at each layer, so activation magnitudes shrink by √2 per layer. For 10 layers: (1/√2)^10 = 0.5^5 ≈ 0.03, so activations are about 30× smaller than the inputs, causing vanishing signal.
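The halving argument can be checked empirically. The sketch below is a minimal simulation with arbitrary sizes (width 512, depth 10, batch 1024): it pushes unit-variance input through a stack of bias-free Linear + ReLU layers and reports the mean squared activation at the end for each init. The helper name second_moment_after is hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
width, depth = 512, 10

def second_moment_after(init_fn) -> float:
    """Pass N(0,1) input through `depth` Linear+ReLU layers; return E[a^2] at the output."""
    x = torch.randn(1024, width)
    with torch.no_grad():
        for _ in range(depth):
            layer = nn.Linear(width, width, bias=False)
            init_fn(layer.weight)
            x = torch.relu(layer(x))
    return x.pow(2).mean().item()

he = second_moment_after(
    lambda w: nn.init.kaiming_normal_(w, mode="fan_in", nonlinearity="relu")
)
xavier = second_moment_after(nn.init.xavier_normal_)

print(f"He init     + ReLU: E[a^2] after {depth} layers = {he:.4f}")
print(f"Xavier init + ReLU: E[a^2] after {depth} layers = {xavier:.2e}")
print(f"Predicted for Xavier: 0.5^{depth} = {0.5 ** depth:.2e}")
```

The second moment is tracked rather than the variance because that is the quantity the He derivation preserves; the variance of a post-ReLU activation is smaller because its mean is not zero.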

Standard dropout: zero each activation with probability p during training; at inference, scale all activations by (1-p) to match expected values. Problem: inference code must remember to apply the scaling factor. Inverted dropout flips this: during training, zero activations with probability p AND scale survivors by 1/(1-p). Expected value is preserved: E[output] = (1-p) * x/(1-p) = x. At inference: use all activations without any scaling, so inference code is simpler and cannot forget the scaling. PyTorch's nn.Dropout uses inverted dropout. The practical impact: if you implement dropout manually (e.g., in a custom layer or NumPy prototype) and forget the 1/(1-p) scaling, downstream layers see activations at inference that are on average 1/(1-p) times larger than during training, which degrades accuracy with no error message.
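A manual implementation makes the expected-value argument easy to verify. This is a minimal sketch, not PyTorch's internal implementation; the function name inverted_dropout and the choice p=0.3 are illustrative.

```python
import torch

def inverted_dropout(x: torch.Tensor, p: float, training: bool) -> torch.Tensor:
    """Zero each element with probability p and rescale survivors by 1/(1-p) in training."""
    if not training or p == 0.0:
        return x                                   # inference: identity, no scaling
    keep = (torch.rand_like(x) >= p).to(x.dtype)   # 1 with probability 1-p, else 0
    return x * keep / (1.0 - p)

torch.manual_seed(0)
x = torch.ones(100_000)
p = 0.3

train_out = inverted_dropout(x, p, training=True)
eval_out = inverted_dropout(x, p, training=False)

print(f"train: fraction zeroed = {(train_out == 0).float().mean().item():.3f}")
print(f"train: mean            = {train_out.mean().item():.3f}")
print(f"eval:  unchanged       = {torch.equal(eval_out, x)}")
```

Dropping the division by (1.0 - p) in the training branch reproduces the bug described above: the training-time mean falls to roughly 1-p while inference still sees the full-scale input.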

Normalisation layer choice determines not just training stability but where your model can run. BatchNorm requires batch statistics — it breaks at batch_size=1, with variable-length sequences, and in online learning. LayerNorm and GroupNorm trade some training speed for universal applicability.
06

Data Augmentation, Mixup, CutMix & Training Regularisation

You train an image classifier to 95% train accuracy but only 72% validation accuracy. Adding random crops and colour jitter takes validation accuracy to 84% — without a single additional sample. Data augmentation is the most compute-efficient regulariser in computer vision and the first thing to add when you see a large train/val gap.

Augmentation — Geometric, Colour, Mixup, CutMix & Label Smoothing
torchvision.transforms.v2 pipeline; Mixup: x̃=λxᵢ+(1-λ)xⱼ, ỹ=λyᵢ+(1-λ)yⱼ; CutMix patches; label_smoothing=0.1 in CrossEntropyLoss

Data augmentation artificially increases training set diversity without collecting new data.

Geometric (spatial):
  RandomResizedCrop(224, scale=(0.08,1.0)): crop random fraction then resize
  RandomHorizontalFlip(p=0.5): mirror left-right
  RandomRotation(degrees=15): rotate ±15°; only for rotation-invariant tasks

Colour:
  ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
  RandomGrayscale(p=0.2): converts to grayscale with probability p
  GaussianBlur(kernel_size=23): random blur strength

Mixing augmentations:
  Mixup (Zhang et al, 2018):
    x̃ = λxᵢ + (1−λ)xⱼ; ỹ = λyᵢ + (1−λ)yⱼ
    λ ~ Beta(α,α), α=0.2
    Soft labels force smooth prediction surface. +1-2% on ImageNet.
  CutMix (Yun et al, 2019):
    Paste a random rectangular region from xⱼ onto xᵢ.
    ỹ = λyᵢ + (1−λ)yⱼ where λ = fraction of pixels from xᵢ.
    Preserves local structure better than Mixup. Standard in ViT training.

Label Smoothing:
  ỹ = (1−ε)·y_onehot + ε/K, with ε=0.1, K=num_classes
  Prevents overconfident logits. Built in: nn.CrossEntropyLoss(label_smoothing=0.1)

PyTorch — transforms.v2 pipeline, Mixup, CutMix, label smoothing effect on loss
import torch
import torch.nn as nn
import torchvision.transforms.v2 as T

# ── ImageNet training pipeline (transforms.v2) ───────────
train_tfm = T.Compose([
    T.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(0.75, 1.33)),
    T.RandomHorizontalFlip(p=0.5),
    T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    T.RandomGrayscale(p=0.2),
    T.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
    T.ToImage(),
    T.ToDtype(torch.float32, scale=True),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
val_tfm = T.Compose([
    T.Resize(256), T.CenterCrop(224),
    T.ToImage(), T.ToDtype(torch.float32, scale=True),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
print("Transforms defined (torchvision.transforms.v2 API)")

# ── Mixup: manual implementation ─────────────────────────
def mixup(x, y, alpha=0.2, num_classes=10):
    lam = float(torch.distributions.Beta(alpha, alpha).sample())
    idx = torch.randperm(x.size(0))
    x_mix = lam * x + (1 - lam) * x[idx]
    ya_oh = torch.zeros(x.size(0), num_classes).scatter_(1, y.unsqueeze(1), 1.)
    y_mix = lam * ya_oh + (1 - lam) * ya_oh[idx]
    return x_mix, y_mix

x_b = torch.randn(8, 3, 224, 224)
y_b = torch.randint(0, 10, (8,))
x_m, y_m = mixup(x_b, y_b, alpha=0.2)
print(f"\nMixup: {x_b.shape} → {x_m.shape},  soft label sums: {y_m.sum(dim=1).tolist()}")

# ── CutMix via torchvision.transforms.v2 ─────────────────
cutmix = T.CutMix(num_classes=10, alpha=1.0)
x_c  = torch.rand(4, 3, 224, 224)
y_c  = torch.randint(0, 10, (4,))
x_cm, y_cm = cutmix(x_c, y_c)
print(f"CutMix: x {x_cm.shape}, soft_y {y_cm.shape}")

# ── Label smoothing: penalises overconfident logits ──────
logits = torch.tensor([[0., 0., 10.]])    # very confident prediction
label  = torch.tensor([2])
loss_hard = nn.CrossEntropyLoss(label_smoothing=0.0)(logits, label)
loss_soft = nn.CrossEntropyLoss(label_smoothing=0.1)(logits, label)
print(f"\nCE (no smooth): {loss_hard.item():.4f}")
print(f"CE (smooth=0.1): {loss_soft.item():.4f}  (higher — penalises overconfidence)")
Trap Applying training augmentation to the validation set

Using the same T.Compose pipeline for both train and val transforms causes the validation loss to be stochastic — different random crops and colour jitters each epoch. This makes early stopping unreliable (you are comparing noisy loss estimates), inflates apparent variance, and can make overfitting look like noise. Validation must be deterministic.

Fix Always define two separate transform pipelines: train_transform with random ops; val_transform with only deterministic ops (Resize + CenterCrop + Normalize). Pass them to separate Dataset instances. Verify by running the same val batch twice — results should be identical.
Trap Using Mixup or CutMix with hard-label cross-entropy loss (no soft-label support)

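Mixup and CutMix produce soft labels (float distributions over classes). Since PyTorch 1.10, nn.CrossEntropyLoss accepts either integer class indices of shape (N,) or class probabilities of shape (N, C); older versions, nn.NLLLoss, and most hand-written loss or accuracy code accept only integer indices. Passing a soft-label tensor where integer labels are expected raises a dtype or shape error or, worse, silently computes the wrong loss (for example, when the soft labels are collapsed with argmax to make the error go away).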

Fix When using Mixup/CutMix, use nn.CrossEntropyLoss with the soft label directly: loss = criterion(logits, y_soft) where y_soft is the float distribution tensor. nn.CrossEntropyLoss accepts both integer labels and float soft-label distributions. torchvision.transforms.v2.MixUp and CutMix return soft labels by default.

Mixup trains the model on convex combinations of training examples and their labels. The loss is: L = λ·CE(f(x̃), yᵢ) + (1-λ)·CE(f(x̃), yⱼ). This forces the model to predict linearly interpolated probability distributions for interpolated inputs, making the learned function smoother between training examples. The key regularisation effect: without Mixup, the model can learn very sharp decision boundaries that overfit to training noise. Mixup penalises sharp boundaries — the model cannot confidently predict yᵢ for inputs slightly mixed with xⱼ without also being partially confident about yⱼ. This flattens the prediction surface, improving calibration and generalisation. Empirically: +1-2% top-1 accuracy on ImageNet, stronger gains on smaller datasets.
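The two-term Mixup loss and the single cross-entropy against the mixed soft label are the same quantity, because cross-entropy is linear in the target. A minimal check, with random logits standing in for f(x̃) and made-up labels:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, lam = 5, 0.7
logits = torch.randn(4, num_classes)    # stand-in for the model output on mixed inputs
y_i = torch.tensor([0, 1, 2, 3])        # labels of the original samples
y_j = torch.tensor([4, 3, 2, 1])        # labels of the permuted partner samples

# Form 1: weighted sum of two hard-label cross-entropies
loss_two_term = lam * F.cross_entropy(logits, y_i) + (1 - lam) * F.cross_entropy(logits, y_j)

# Form 2: one cross-entropy against the mixed soft label
y_soft = lam * F.one_hot(y_i, num_classes).float() + (1 - lam) * F.one_hot(y_j, num_classes).float()
loss_soft = F.cross_entropy(logits, y_soft)   # probability targets require PyTorch >= 1.10

print(f"two-term loss:   {loss_two_term.item():.6f}")
print(f"soft-label loss: {loss_soft.item():.6f}")
```

Form 1 is the safer choice when the loss function in use only accepts integer labels; Form 2 is what torchvision's MixUp and CutMix outputs plug into directly.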

Label smoothing converts one-hot labels y_hard into soft distributions: ỹ = (1-ε)·y_hard + ε/K. Instead of the model targeting probability=1.0 for the true class, it targets (1-ε+ε/K) ≈ 0.9 for ε=0.1. This prevents the model from becoming overconfident: logits cannot grow arbitrarily large without increasing the loss. It improves calibration (predicted probabilities better match empirical accuracy) and reduces overfitting on noisy labels. When NOT to use it: (1) knowledge distillation, where the soft targets from the teacher already encode class relationships and adding label smoothing distorts the teacher signal; (2) when downstream code depends on the raw confidence scores rather than on top-1 argmax accuracy, since smoothing rarely hurts accuracy but does change the confidence values; (3) regression tasks, because label smoothing is for classification.
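The smoothing formula can be applied by hand and compared with the built-in option. A small sketch with K=4 classes and random logits (both arbitrary):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
K, eps = 4, 0.1
logits = torch.randn(6, K)
y = torch.randint(0, K, (6,))

# Manual smoothing: (1 - eps) * one_hot + eps / K
y_smooth = (1 - eps) * F.one_hot(y, K).float() + eps / K
manual = -(y_smooth * F.log_softmax(logits, dim=1)).sum(dim=1).mean()

# Built-in equivalent
builtin = nn.CrossEntropyLoss(label_smoothing=eps)(logits, y)

print(f"target for true class: {y_smooth.max().item():.3f}  (1 - eps + eps/K)")
print(f"manual loss:  {manual.item():.6f}")
print(f"builtin loss: {builtin.item():.6f}")
```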

Early Stopping, Learning Curves & Overfitting Diagnosis
Train↓ val↑ = overfit; patience-based early stopping with best-checkpoint restore; WandB logging pattern

Learning curve patterns:
  Train ↓, val ↓ (parallel): model is learning, no overfitting yet
  Train ↓, val flat: plateau; near optimal capacity or LR issue
  Train ↓, val ↑: overfitting; stop here, restore best checkpoint
  Both flat from epoch 1: stuck; LR too low, dead neurons, bad data

Early Stopping:
  Monitor val_loss every epoch.
  Stop when val_loss has not improved by min_delta for patience epochs.
  Restore best model weights from before divergence.
  patience: 5-10 for fine-tuning; 20-50 for training from scratch.
  min_delta: 1e-4 typical (ignore noise).

LR range test (Smith, 2017):
  Increase LR exponentially from 1e-7 to 10 over one epoch. Plot loss vs LR.
  Best LR = just before loss starts rising sharply.
  Training LR = that value / 10. Finds optimal LR in minutes.

WandB logging pattern:
  wandb.log({ "train/loss": train_loss, "val/loss": val_loss, "lr": sched.get_last_lr()[0], "grad_norm": total_norm })
  Log per-step for train, per-epoch for val. Tag with hyperparameters.

PyTorch — early stopping class with best-checkpoint restore, WandB logging skeleton
import torch
import torch.nn as nn

class EarlyStopping:
    def __init__(self, patience=7, min_delta=1e-4):
        self.patience   = patience
        self.min_delta  = min_delta
        self.best_loss  = float('inf')
        self.counter    = 0
        self.best_state = None

    def step(self, val_loss, model):
        """Returns True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss  = val_loss
            self.counter    = 0
            self.best_state = {k: v.clone() for k, v in model.state_dict().items()}
        else:
            self.counter += 1
        return self.counter >= self.patience

    def restore(self, model):
        if self.best_state:
            model.load_state_dict(self.best_state)
            print(f"Restored best model (val_loss={self.best_loss:.4f})")

# ── Training loop with early stopping ─────────────────────
torch.manual_seed(42)
model   = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 1))
optim   = torch.optim.Adam(model.parameters(), lr=5e-3)
stopper = EarlyStopping(patience=5, min_delta=1e-3)
X_tr, y_tr   = torch.randn(200, 10), torch.randn(200, 1)
X_val, y_val = torch.randn(50, 10),  torch.randn(50, 1)

for epoch in range(200):
    model.train()
    optim.zero_grad()
    loss = nn.MSELoss()(model(X_tr), y_tr)
    loss.backward(); optim.step()

    model.eval()
    with torch.no_grad():
        val_loss = nn.MSELoss()(model(X_val), y_val).item()

    if epoch % 20 == 0:
        print(f"Epoch {epoch:3d}: train={loss.item():.4f}  val={val_loss:.4f}  "
              f"patience={stopper.counter}/{stopper.patience}")
    if stopper.step(val_loss, model):
        print(f"\nEarly stop at epoch {epoch}.")
        stopper.restore(model)
        break

# ── WandB logging skeleton ────────────────────────────────
print("""
WandB integration:
  import wandb
  wandb.init(project='dl-run', config={'lr': 1e-3, 'wd': 0.01})
  global_step = 0
  for epoch in range(epochs):
      for batch in loader:
          ...train...
          wandb.log({'train/loss': loss, 'lr': sched.get_last_lr()[0]}, step=global_step)
          global_step += 1
      ...validate...
      # wandb drops logs whose step goes backwards, so reuse global_step (not epoch)
      wandb.log({'val/loss': val_loss, 'val/acc': acc, 'epoch': epoch}, step=global_step)
""")
Trap Using training loss for early stopping decisions

Training loss almost always keeps decreasing with more epochs (the model memorises training data). Using it for early stopping means the model never stops, because it always "improves" on the training set. The result is a maximally overfitted model. This is a common mistake when validation set creation is forgotten or when validation is expensive.

Fix Always early stop on validation loss (or a validation metric like AUC or F1). If a separate validation set is expensive, use k-fold cross-validation with early stopping on the held-out fold. Never use training metrics for stopping decisions.
Trap Not saving the best checkpoint before early stopping restores

Training continues past the optimal point while patience counts down. Without saving the best weights, restoring "best model" is impossible — the model state has been overwritten by subsequent (worse) epochs. Engineers who only save the final checkpoint lose the best model entirely.

Fix Save a checkpoint every time validation loss improves: torch.save(model.state_dict(), 'best.pt'). Restore at the end: model.load_state_dict(torch.load('best.pt')). The EarlyStopping class above does this in-memory with .clone(). For large models: save to disk to avoid doubling memory usage.

This pattern — train decreasing, val flat — indicates the model has saturated the validation set's information content. Possible causes: (1) model capacity is right but the problem is genuinely hard and a plateau is near-optimal; (2) learning rate is too high — the model is oscillating around a good minimum without settling; (3) insufficient regularisation — model is beginning to overfit but validation noise masks the signal; (4) the validation set is too small and val loss is noisy. Distinguish by: reduce LR by 10× and watch if val loss improves (cause: LR issue); add dropout or weight decay (cause: overfitting starting); increase patience (cause: noise). If val loss stays flat for 20+ epochs with good regularisation, you may have reached the model's performance ceiling on your data.

The LR range test (Leslie Smith, 2017): run one epoch with the learning rate increasing exponentially from a very small value (1e-7) to a large value (10), logging the loss at each step. Plot loss vs LR. At low LR: loss decreases slowly (under-stepping). At some threshold: loss decreases rapidly (optimal region). Beyond a cliff: loss spikes or diverges. The test reveals: (1) the minimum LR where training makes meaningful progress; (2) the maximum stable LR (the cliff). The practical training LR is the cliff value / 10. This takes minutes and replaces hours of grid search. In the 1-cycle policy, η_max is typically set somewhat below the cliff LR, since training at the cliff itself is unstable. In fastai: learn.lr_find() runs this automatically and plots it.
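The procedure can be sketched in a few lines. Everything specific here is an assumption for illustration: the toy regression data, the small MLP, plain SGD, 100 steps, and the rule of stopping once the loss exceeds 4× the best loss seen so far.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy regression problem standing in for a real dataset
X = torch.randn(512, 10)
y = X @ torch.randn(10, 1) + 0.1 * torch.randn(512, 1)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))

lr_min, lr_max, num_steps = 1e-7, 10.0, 100
gamma = (lr_max / lr_min) ** (1 / (num_steps - 1))   # per-step LR multiplier
opt = torch.optim.SGD(model.parameters(), lr=lr_min)

lrs, losses = [], []
for step in range(num_steps):
    lr = lr_min * gamma ** step
    for group in opt.param_groups:
        group["lr"] = lr                              # exponential LR sweep
    idx = torch.randint(0, X.size(0), (64,))          # random mini-batch
    loss = nn.functional.mse_loss(model(X[idx]), y[idx])
    lrs.append(lr)
    losses.append(loss.item())
    # Stop once the loss diverges (arbitrary threshold: 4x the best loss so far)
    if not math.isfinite(loss.item()) or loss.item() > 4 * min(losses):
        break
    opt.zero_grad()
    loss.backward()
    opt.step()

cliff_lr = lrs[-1]
suggested_lr = cliff_lr / 10
print(f"steps run: {len(lrs)}, cliff LR ≈ {cliff_lr:.2e}, suggested LR ≈ {suggested_lr:.2e}")
```

On a real model, plot losses against lrs on a log axis and read the cliff from the curve; the automatic threshold is only a crude substitute for looking at the plot. Reinitialise the model afterwards, since the sweep leaves the weights in a diverged state.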

Every augmentation operation encodes a domain assumption. Horizontal flip says "left-right orientation does not matter." Mixup says "predictions should interpolate smoothly between classes." Label smoothing says "your training labels are not perfectly correct." Choosing augmentations is choosing inductive biases.
Architectures 07–08
07

CNN Architectures, ResNet Skip Connections & Transfer Learning

A new engineer asks why ResNet trains reliably at 152 layers while a plain 152-layer network with identical parameters completely fails to converge. The answer — the skip connection — is one of the most impactful architectural insights in deep learning. Understanding it at the gradient level is what lets you debug deep networks rather than restarting and hoping.

Convolutions, Pooling, Receptive Field & ResNet Skip Connections
Output size: (H-K+2P)/S+1; RF accumulates with depth; MaxPool vs AdaptiveAvgPool; skip connection gradient identity term

Convolution operation:
  Out[c_out, i, j] = Σ_{c_in,k,l} W[c_out,c_in,k,l]·X[c_in,i+k,j+l] + b[c_out]   (stride 1, no padding)
  Output size: H_out = ⌊(H_in − K + 2P) / S⌋ + 1   (K=kernel, P=padding, S=stride)

Key properties:
  - Parameter sharing: same K×K kernel at every position → translation equivariance
  - Local connectivity: each output sees only a local region of the input
  - K=3, P=1, S=1: H_out = H_in (same-padding, standard for ResNets)

Receptive Field (theoretical):
  Single layer, kernel K: RF = K
  L stride-1 layers: RF = 1 + L·(K-1)   (K=3, L=10 → RF=21)
  Stride-2 conv or MaxPool(2): doubles the RF growth contributed by every later layer
  ResNet-50: theoretical RF ≈ 483×483, larger than ImageNet images (224×224)

Pooling:
  MaxPool2d(2, 2): 2×2 window, stride 2; halves H,W; keeps strongest activations
  AdaptiveAvgPool2d((1,1)): collapses to 1×1 regardless of input size → global average
  Standard ResNet head: ... → AdaptiveAvgPool → flatten → Linear(num_classes)

ResNet skip connection:
  y = F(x, {Wᵢ}) + x      (identity shortcut, same dimensions)
  y = F(x, {Wᵢ}) + Wₛx    (projection shortcut, different dims)
  ∂L/∂x = ∂L/∂y · (∂F(x)/∂x + I)
  Identity term I keeps a direct gradient path open even if ∂F/∂x → 0.

  ResNet basic block:

  input x ───────────────────────────────────────┐
     │                                             │  (identity path)
     ▼                                             │
  Conv2d(K=3, P=1, S=1)                          │
  BatchNorm2d  →  ReLU                            │
     │                                             │
  Conv2d(K=3, P=1, S=1)                          │
  BatchNorm2d                                     │
     │                                             │
     ▼                                             │
     +  ←─────────────────────────────────────────┘
     │   element-wise add
     ▼
  ReLU  →  output h = F(x) + x

  Gradient:  ∂L/∂x = ∂L/∂h · (∂F(x)/∂x + I)
  +I term: gradient flows even when ∂F/∂x → 0 (vanishing branch)
PyTorch — conv output size, ResNet BasicBlock, gradient flow check, AdaptiveAvgPool head
import torch
import torch.nn as nn
import torch.nn.functional as F

# ── Conv output size formula ──────────────────────────────
def conv_out(H, K, P, S):
    return (H - K + 2*P) // S + 1

for K, P, S, label in [(3,1,1,'same'), (3,0,2,'stride-2'), (1,0,1,'pointwise'), (7,3,2,'stem')]:
    print(f"K={K} P={P} S={S} ({label}): 224 → {conv_out(224,K,P,S)}")

# ── ResNet BasicBlock ─────────────────────────────────────
class BasicBlock(nn.Module):
    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_ch)
        self.skip  = nn.Identity()
        if stride != 1 or in_ch != out_ch:
            self.skip = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + self.skip(x))     # skip connection

# Same-dimension block
blk1 = BasicBlock(64, 64, stride=1)
x    = torch.randn(4, 64, 56, 56)
print(f"\nSame-dim:  {x.shape} → {blk1(x).shape}")

# Downsample block (stride=2, channel doubling)
blk2 = BasicBlock(64, 128, stride=2)
print(f"Downsample: {x.shape} → {blk2(x).shape}")

# ── Verify skip connection passes gradient ────────────────
x_g = torch.randn(2, 64, 56, 56, requires_grad=True)
blk1(x_g).sum().backward()
print(f"Gradient at input (skip block): {x_g.grad.norm():.4f}  (non-zero = gradient flows)")

# ── Standard classification head ─────────────────────────
pool  = nn.AdaptiveAvgPool2d((1, 1))
head  = nn.Linear(128, 10)
feat  = torch.randn(4, 128, 7, 7)
logit = head(pool(feat).flatten(1))
print(f"\nHead: {feat.shape} → pool → flatten → {logit.shape}")
Trap Using stride-1 convolutions everywhere (no downsampling), causing memory explosion

Deep CNNs require progressive spatial downsampling to keep memory tractable. A 224×224 feature map at 256 channels is 224×224×256×4 bytes ≈ 51MB per sample in float32. Without stride-2 convolutions or pooling to reduce spatial dimensions, a 20-layer CNN holds about 1GB of activations for a single 224×224 image (20 × 51MB), all of which must be kept for the backward pass. At any practical batch size, training will OOM on any reasonable GPU.

Fix Follow the standard ResNet pattern: stem conv at stride=2, one stride-2 MaxPool, then three more stride-2 stages, for a total 32× spatial reduction from 224 to 7. At each stride-2 stage, double the channel count to compensate. AdaptiveAvgPool at the end collapses spatial dimensions completely before the classifier.
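The downsampling schedule can be traced with the output-size formula. The sketch below models each stage by a single representative layer and uses ResNet-50-style output channel counts; both are simplifications made for illustration.

```python
def conv_out(h: int, k: int, p: int, s: int) -> int:
    """Spatial output size of a conv or pooling layer: floor((H - K + 2P) / S) + 1."""
    return (h - k + 2 * p) // s + 1

# (name, kernel, padding, stride, output channels)
stages = [
    ("stem conv 7x7", 7, 3, 2, 64),
    ("maxpool 3x3",   3, 1, 2, 64),
    ("layer1",        3, 1, 1, 256),
    ("layer2",        3, 1, 2, 512),
    ("layer3",        3, 1, 2, 1024),
    ("layer4",        3, 1, 2, 2048),
]

h = 224
sizes = []
for name, k, p, s, ch in stages:
    h = conv_out(h, k, p, s)
    sizes.append(h)
    mb = h * h * ch * 4 / 1e6     # float32 bytes for one feature map, per sample
    print(f"{name:14s} -> {h:3d}x{h:<3d} x {ch:4d} ch  ({mb:6.2f} MB)")

print(f"total spatial reduction: {224 // sizes[-1]}x")
```

Memory per feature map falls at each stride-2 stage even though the channel count doubles, because halving both spatial dimensions cuts the element count by four.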
Trap Computing receptive field wrong and believing your model "sees" the full image

A 5-layer CNN with K=3 has theoretical RF = 1+5×2 = 11 pixels — nowhere near the full image. Engineers assume their model integrates global context but it only sees a small local patch at each output position. This explains why models fail on tasks requiring global reasoning (image-level classification vs patch classification). The effective RF is often smaller than the theoretical RF due to weight magnitude patterns.

Fix Calculate RF before designing the architecture. Use stride-2 convolutions and pooling to grow the RF efficiently. For global context: add a self-attention layer or use a global pooling operation. Alternatively: use ViT which has global receptive field from the first layer via attention.

In a plain 152-layer network, gradients must propagate through all 152 layers multiplicatively. Each layer's Jacobian can shrink gradients (vanishing) or grow them (exploding), and even with ReLU, He init and BatchNorm, very deep plain stacks become hard to optimise. In ResNet-152, each block adds a gradient highway: ∂L/∂x = ∂L/∂y(∂F/∂x + I). The +I identity term gives the gradient a direct path that does not depend on how the residual branch F(x) behaves. Along the identity paths, gradients flow from the output to any layer without being multiplied by every block's Jacobian. Empirically: He et al (2015) showed that a plain 34-layer network had higher training error than a plain 18-layer one on ImageNet, and plain-56 was worse than plain-20 on CIFAR-10 (more layers hurts without skips), while ResNet-34 beat ResNet-18 and ResNet-56 and ResNet-110 beat ResNet-20.
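The effect of the identity term is easy to see on a toy stack. The sketch below is not a ResNet: it uses 50 Linear + tanh layers of width 64 with PyTorch's default init (all arbitrary choices) and compares the gradient norm reaching the input with and without a skip connection around each layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
depth, width = 50, 64

def input_grad_norm(residual: bool) -> float:
    """Gradient norm at the input of a deep Linear+tanh stack, with or without skips."""
    layers = [nn.Linear(width, width) for _ in range(depth)]
    x = torch.randn(8, width, requires_grad=True)
    h = x
    for layer in layers:
        f = torch.tanh(layer(h))
        h = h + f if residual else f     # residual: h + F(h); plain: F(h)
    h.sum().backward()
    return x.grad.norm().item()

plain_grad = input_grad_norm(residual=False)
residual_grad = input_grad_norm(residual=True)
print(f"plain    stack: |dL/dx| = {plain_grad:.3e}")
print(f"residual stack: |dL/dx| = {residual_grad:.3e}")
```

The plain stack multiplies 50 Jacobians whose typical gain is below 1, so the input gradient collapses; the residual stack keeps an identity path through every layer.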

Theoretical RF: the total region of the input that can influence a given output neuron, computed as 1 + L(K-1) for L stride-1 layers of kernel size K. Effective RF (Luo et al, 2017): the region that actually has significant influence on the output, weighted by gradient magnitude. Due to weight initialisation and the central limit theorem, central input positions have disproportionately large influence: the effective RF is approximately Gaussian-shaped and covers only a fraction of the theoretical RF. Implication: your model does not uniformly "see" its entire theoretical RF. To increase effective RF: use dilated (atrous) convolutions (expands RF without increasing parameters), depthwise separable convolutions at multiple scales, or self-attention, which has a global receptive field in a single layer (every position attends to every other).
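The theoretical RF can be computed for any layer sequence once stride is accounted for: each layer widens the RF by (K−1) times the current input-space step, and each stride multiplies that step for all later layers. A small sketch follows; the helper name receptive_field is illustrative, and the ResNet-50 layer list assumes the original layout with stride on the first 1×1 conv of each stage and omits the other 1×1 convs because they add nothing to the RF.

```python
def receptive_field(layers):
    """layers: list of (kernel, stride). Returns the theoretical RF in input pixels."""
    rf, jump = 1, 1
    for k, s in layers:
        rf += (k - 1) * jump   # widen by (k-1) steps of the current input-space spacing
        jump *= s              # stride multiplies the spacing seen by all later layers
    return rf

rf_5 = receptive_field([(3, 1)] * 5)                     # matches 1 + 5*(3-1)
rf_10 = receptive_field([(3, 1)] * 10)                   # matches 1 + 10*(3-1)
rf_strided = receptive_field([(3, 2)] + [(3, 1)] * 4)    # one early stride-2 layer

# ResNet-50-style sequence: stem, maxpool, then 3/4/6/3 blocks with one 3x3 conv each
resnet50_layers = (
    [(7, 2), (3, 2)]
    + [(3, 1)] * 3
    + [(1, 2)] + [(3, 1)] * 4
    + [(1, 2)] + [(3, 1)] * 6
    + [(1, 2)] + [(3, 1)] * 3
)
rf_resnet50 = receptive_field(resnet50_layers)

print(f"5 layers,  K=3, stride 1: RF = {rf_5}")
print(f"10 layers, K=3, stride 1: RF = {rf_10}")
print(f"5 layers, first stride 2: RF = {rf_strided}")
print(f"ResNet-50-style stack:    RF = {rf_resnet50}")
```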

Transfer Learning — Feature Extraction vs Fine-tuning
Freeze backbone for small datasets; layer-wise LR decay for fine-tuning; gradual unfreezing; torchvision weights API

Transfer learning: adapt a model pretrained on large data (ImageNet, LAION) to your task.

Feature Extraction (frozen backbone):
  for p in backbone.parameters(): p.requires_grad = False
  Train ONLY the new classification head.
  Use when: small dataset (<1000 images per class); similar domain.
  Fast: only the head is updated; avoids catastrophic forgetting.
  The backbone is a fixed feature extractor (frozen knowledge).

Fine-tuning (unfrozen backbone):
  Keep pretrained weights as initialisation; train ALL layers.
  Apply lower LR to early layers (they contain general features worth preserving):
  backbone_lr = 1e-5; head_lr = 1e-3 (layer-wise LR decay)
  Use when: ≥1000 samples per class; domain differs from pretraining.
  Risk: catastrophic forgetting if LR is too high.

Gradual Unfreezing:
  Epoch 1: head only → Epoch 2-3: last 2 blocks + head → Epoch 4+: all layers
  Prevents early large updates that corrupt pretrained features.
  fastai discriminative fine-tuning recipe uses this approach.

torchvision weights API:
  weights = ResNet50_Weights.IMAGENET1K_V2 (80.9% top-1)
  model = resnet50(weights=weights)
  transform = weights.transforms() (correct preprocessing built-in)

PyTorch — feature extraction pattern, layer-wise LR groups for fine-tuning, built-in transforms
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import ResNet50_Weights

# ── Feature extraction: freeze everything, train only head ─
backbone = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
for p in backbone.parameters():
    p.requires_grad = False

# Replace final classifier (in_features=2048 for ResNet50)
backbone.fc = nn.Linear(backbone.fc.in_features, 5)   # 5-class task
# Only backbone.fc has requires_grad=True

trainable_params = sum(p.numel() for p in backbone.parameters() if p.requires_grad)
total_params     = sum(p.numel() for p in backbone.parameters())
print(f"Feature extraction: {trainable_params:,} / {total_params:,} trainable")

# ── Fine-tuning: layer-wise LR decay ─────────────────────
backbone2 = models.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
backbone2.fc = nn.Linear(backbone2.fc.in_features, 5)

base_lr = 1e-4
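# Stem (conv1 + bn1) gets its own group: parameters left out of the optimizer are never updated
stem_params = list(backbone2.conv1.parameters()) + list(backbone2.bn1.parameters())
param_groups = [
    {'params': stem_params,                   'lr': base_lr * 0.01},
    {'params': backbone2.layer1.parameters(), 'lr': base_lr * 0.01},
    {'params': backbone2.layer2.parameters(), 'lr': base_lr * 0.1},
    {'params': backbone2.layer3.parameters(), 'lr': base_lr * 0.5},
    {'params': backbone2.layer4.parameters(), 'lr': base_lr * 1.0},
    {'params': backbone2.fc.parameters(),     'lr': base_lr * 10.},  # head: highest LR
]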
optimizer = torch.optim.AdamW(param_groups, weight_decay=0.01)
print("\nLayer-wise LR groups:")
for i, pg in enumerate(optimizer.param_groups):
    print(f"  group {i}: lr={pg['lr']:.6f}")

# ── Use built-in preprocessing — matches training exactly ─
transform = ResNet50_Weights.IMAGENET1K_V2.transforms()
print(f"\nBuilt-in preprocessing: {transform}")

# Verify forward pass (eval mode so BatchNorm running stats are not updated by random data)
x = torch.randn(2, 3, 224, 224)
backbone2.eval()
with torch.no_grad():
    out = backbone2(x)
print(f"\nForward pass: {x.shape} → {out.shape}")
Trap Using the same learning rate for backbone and head during fine-tuning

Pretrained backbone weights encode valuable low-level features (edges, textures) in early layers and high-level semantics in later layers. Using a large uniform LR destroys these features in the first few gradient steps — the model forgets what it learned and must relearn from scratch, negating the benefit of pretraining. This is catastrophic forgetting.

Fix Apply layer-wise LR decay: early layers at 1/100 of the head LR; final backbone layers at 1/10; head at base LR. A simpler version: freeze the first 2/3 of the backbone and only fine-tune the last 1/3 plus the head. Start with feature extraction for 1-2 epochs then gradually unfreeze.
Trap Using custom normalisation instead of the model's built-in preprocessing

Pretrained ResNet weights expect inputs normalised with ImageNet mean=[0.485,0.456,0.406] and std=[0.229,0.224,0.225]. Applying a different normalisation (e.g., mean=0, std=1 or no normalisation) shifts the input distribution away from what the pretrained features expect. Early layers have learned to respond to specific input ranges — wrong normalisation causes the feature extractor to output garbage, producing poor accuracy even with a perfectly trained head.

Fix For inference and validation, use weights.transforms() from the torchvision weights object: it bundles the inference-time preprocessing that matches pretraining, including resize, crop, and normalisation. For training pipelines with your own augmentations, keep the same normalisation mean and std as the final step. Retrieve the preset from the weights object rather than retyping it: ResNet50_Weights.IMAGENET1K_V2.transforms().

The decision matrix: (1) Small dataset + similar domain (e.g., ImageNet-pretrained on medical images of similar texture): feature extraction. The risk of overfitting during fine-tuning exceeds the gain from adapting features. (2) Large dataset + similar domain: fine-tuning with low LR. Pretrained features are a good starting point; you can refine them without catastrophic forgetting. (3) Large dataset + very different domain (e.g., satellite imagery, X-rays, industrial defects): fine-tune ALL layers with a cosine schedule. The domain gap means pretrained high-level features may not transfer well, but pretrained low-level features (edges, textures) still help as initialisation. (4) Very large domain-specific dataset: consider training from scratch — the cost of pretraining on domain data may be lower than the mismatch from ImageNet features.

Catastrophic forgetting: when fine-tuning a pretrained model, large gradient updates overwrite the pretrained weights with task-specific adjustments, destroying previously learned general representations. Early layers (which encode reusable features like edges and textures) are corrupted first because their gradients from the new task are large relative to their learned values. Gradual unfreezing (Howard & Ruder, 2018): start by training only the classification head (backbone frozen). Then unfreeze the last backbone block and train for a few epochs. Then unfreeze the second-to-last block. Continue until the full network is unfrozen. This gives each layer time to adapt incrementally — by the time a layer is unfrozen, the downstream layers have already adapted and gradient signals are more targeted. Layer-wise LR decay is complementary: lower LR in early layers = smaller updates = less forgetting.
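The unfreezing schedule can be expressed as a small helper that toggles requires_grad by stage. The sketch below uses a tiny stand-in model rather than a real pretrained network, and the names set_stage and unfreeze_order are hypothetical.

```python
import torch.nn as nn

# Stand-in for a pretrained network: three "backbone" blocks plus a head
model = nn.ModuleDict({
    "block1": nn.Linear(32, 32),
    "block2": nn.Linear(32, 32),
    "block3": nn.Linear(32, 32),
    "head":   nn.Linear(32, 5),
})

# Unfreeze the head first, then backbone blocks from last to first
unfreeze_order = ["head", "block3", "block2", "block1"]

def set_stage(model: nn.ModuleDict, stage: int) -> int:
    """Make the first `stage + 1` entries of unfreeze_order trainable; freeze the rest.

    Returns the number of trainable parameters after the change.
    """
    trainable = set(unfreeze_order[: stage + 1])
    for name, module in model.items():
        for p in module.parameters():
            p.requires_grad = name in trainable
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

counts = [set_stage(model, s) for s in range(len(unfreeze_order))]
for s, n in enumerate(counts):
    print(f"stage {s}: trainable = {unfreeze_order[: s + 1]}  ({n:,} params)")
```

In a real training loop, call set_stage at the epoch boundaries of your schedule. If the optimizer was built only from the parameters that were trainable at the time, rebuild it or add a param group for the newly unfrozen block, ideally with a lower learning rate.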

The convolution operation encodes two inductive biases simultaneously: translation equivariance (same kernel everywhere) and locality (each output depends on a bounded receptive field). These biases make CNNs sample-efficient for vision tasks. Transfer learning works because these biases produce reusable features — edges, textures, shapes — that generalise across datasets.
08

Transformer Architecture — Attention, BERT vs GPT & Flash Attention

You need to choose between BERT and GPT as the backbone for a new text classification system. Both are "Transformers," but they differ fundamentally in what each token can attend to and how they were pretrained. Understanding the architecture at the attention-mask level is what separates engineers who choose the right model from engineers who copy a tutorial without knowing why.

Scaled Dot-Product Attention, Multi-Head Attention & Positional Encoding
Attention(Q,K,V)=softmax(QKᵀ/√d_k)·V; why √d_k scaling prevents softmax saturation; MHA runs h parallel heads; sinusoidal vs learned vs RoPE positions

Scaled Dot-Product Attention:
  Inputs: Q (queries), K (keys), V (values), each of shape (seq_len, d_k)
  Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V
    QKᵀ      (seq_len, seq_len) similarity matrix: one score for every pair of positions
    /√d_k    scaling to prevent softmax saturation
    softmax  attention weights; each row sums to 1
    ·V       weighted average of values

Why √d_k?
  If the components of q and k are independent N(0,1): qᵀk has mean 0 and variance d_k.
  Without scaling, variance grows with d_k.
  At large d_k, dot products are large → softmax saturates → near-one-hot output.
  Near-one-hot softmax = near-zero gradient (saturation kills learning).
  /√d_k normalises variance back to O(1).

Multi-Head Attention (MHA):
  Run h heads in parallel, each with smaller d_k = d_model/h:
    headᵢ = Attention(QWᵢ^Q, KWᵢ^K, VWᵢ^V)     Wᵢ^Q, Wᵢ^K, Wᵢ^V are learned per-head projections
    MHA(Q,K,V) = Concat(head₁,...,headₕ) · Wᴼ   Wᴼ is the learned output projection
  Each head attends to different positions/patterns simultaneously.

Positional Encoding:
  Self-attention is permutation-equivariant: order is ignored without PE.
  Sinusoidal (Vaswani 2017):       PE[pos,2i] = sin(pos/10000^(2i/d_model))
  Learned absolute (BERT, GPT-2):  nn.Embedding(max_len, d_model)
  RoPE (LLaMA, Mistral):           rotary; encodes relative positions, no learned max_len table

  Scaled dot-product attention (seq_len=3, d_k=4):

  Q=(3,4)  K=(3,4)  V=(3,4)

    Step 1: QKᵀ / √d_k         Step 2: softmax       Step 3: × V
  ┌──────────────────┐      ┌──────────────────┐     ┌──────────┐
  │ q₁·k₁  q₁·k₂  q₁·k₃│  │ a₁₁  a₁₂  a₁₃  │     │ output₁  │
  │ q₂·k₁  q₂·k₂  q₂·k₃│→ │ a₂₁  a₂₂  a₂₃  │  →  │ output₂  │
  │ q₃·k₁  q₃·k₂  q₃·k₃│  │ a₃₁  a₃₂  a₃₃  │     │ output₃  │
  └──────────────────┘      └──────────────────┘     └──────────┘
   similarity scores       rows sum to 1 (softmax)   weighted V

  Complexity: O(n²·d_k) time,  O(n²) memory — bottleneck at long sequences.
  Flash Attention: same result, O(n) memory via SRAM tiling.
PyTorch — attention from scratch, √d_k scaling demo, nn.MultiheadAttention, F.scaled_dot_product_attention
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# ── Attention from scratch ────────────────────────────────
def attn(Q, K, V, mask=None):
    d_k    = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(~mask, float('-inf'))
    return F.softmax(scores, dim=-1) @ V

B, T, d_k = 2, 8, 64
Q = torch.randn(B, T, d_k)
K = torch.randn(B, T, d_k)
V = torch.randn(B, T, d_k)
out = attn(Q, K, V)
print(f"Attention: Q={Q.shape} → {out.shape}")

# ── Why √d_k matters: softmax saturation ─────────────────
print("\nSoftmax saturation without √d_k scaling:")
for d in [4, 16, 64, 256]:
    q_ = torch.randn(100, d)
    k_ = torch.randn(100, d)
    scores_raw    = q_ @ k_.T
    scores_scaled = scores_raw / math.sqrt(d)
    # Entropy of softmax: lower = more saturated
    def entropy(s):
        p = F.softmax(s, dim=-1)
        return -(p * p.log().clamp(-20)).sum(dim=-1).mean().item()
    print(f"  d_k={d:4d}: raw entropy={entropy(scores_raw):.2f}, scaled={entropy(scores_scaled):.2f}")

# ── nn.MultiheadAttention ─────────────────────────────────
d_model, nhead = 256, 8
mha = nn.MultiheadAttention(d_model, nhead, batch_first=True, dropout=0.0)
x   = torch.randn(4, 32, d_model)
out_mha, attn_w = mha(x, x, x)        # self-attention
print(f"\nMHA: {x.shape} → {out_mha.shape}")
print(f"Attn weights: {attn_w.shape}  rows sum: {attn_w[0,0].sum():.4f}")

# ── Causal mask for GPT decoder ───────────────────────────
T = 8
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))
print(f"\nCausal mask:\n{causal.int()}")

out_causal, _ = mha(x[:,:T], x[:,:T], x[:,:T], attn_mask=~causal)
print(f"Causal MHA: {out_causal.shape}")

# ── PyTorch 2.0 Flash Attention dispatch ─────────────────
q = torch.randn(4, 8, 64, 32)   # (B, heads, T, d_k)
k = torch.randn(4, 8, 64, 32)
v = torch.randn(4, 8, 64, 32)
out_fa = F.scaled_dot_product_attention(q, k, v, is_causal=False)
print(f"\nF.scaled_dot_product_attention: {q.shape} → {out_fa.shape}")
print("(On CUDA with FP16/BF16 inputs this can dispatch to a Flash Attention kernel: O(n) memory)")
Trap Forgetting to apply the causal mask in a decoder, causing information leakage

Without a causal mask, each token at position i can attend to tokens at positions i+1, i+2, ... (future tokens). During training the model sees the next token directly in its attention context and learns to copy rather than predict. At inference, future tokens do not exist — the model gets garbage inputs and collapses. The symptom: training loss near 0, inference output is random or repetitive.

Fix In any GPT-style decoder: pass is_causal=True to F.scaled_dot_product_attention, or attn_mask=~torch.tril(torch.ones(T, T, dtype=torch.bool)) to nn.MultiheadAttention (True = blocked; the bool dtype matters because ~ is not defined for float tensors). Verify the mask by printing the attention weight matrix for a sequence: the upper triangle (above the diagonal) should be 0 after softmax.
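The verification step can be automated. The sketch below, with made-up sizes, checks two things on a small `nn.MultiheadAttention`: that no attention weight lands above the diagonal, and that changing the last token leaves every earlier output untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, d_model, nhead = 6, 32, 4
mha = nn.MultiheadAttention(d_model, nhead, batch_first=True).eval()
x = torch.randn(2, T, d_model)

# nn.MultiheadAttention convention for bool masks: True = NOT allowed to attend.
blocked = ~torch.tril(torch.ones(T, T, dtype=torch.bool))

with torch.no_grad():
    out, weights = mha(x, x, x, attn_mask=blocked)     # weights: (B, T, T), averaged over heads
    # Check 1: the strict upper triangle carries no attention mass.
    leak = weights.triu(diagonal=1).abs().max().item()

    # Check 2: perturb only the LAST token; outputs at earlier positions must not move.
    x2 = x.clone()
    x2[:, -1] += 1.0
    out2, _ = mha(x2, x2, x2, attn_mask=blocked)
    max_change_before_last = (out2[:, :-1] - out[:, :-1]).abs().max().item()

print(f"max weight on future tokens: {leak}")
print(f"max change at positions < T-1 after editing token T-1: {max_change_before_last:.2e}")
```

The second check is the more robust one in practice: it tests the whole module end to end and also catches leaks that do not show up in the averaged weight matrix.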
Trap Not applying the key_padding_mask for padded sequences in batches

When batching sequences of different lengths, shorter sequences are padded with [PAD] tokens. Without a key_padding_mask, the model attends to these padding positions and computes attention weights over them. Padding tokens carry no meaning and attending to them dilutes the signal, reduces effective context, and can cause the model to learn spurious patterns tied to padding positions.

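Fix Pass key_padding_mask to nn.MultiheadAttention: a bool tensor of shape (batch, seq_len) where True indicates padding positions to ignore. F.scaled_dot_product_attention has no key_padding_mask argument: pass a boolean attn_mask broadcastable to (batch, heads, seq, seq) instead, and note the opposite convention (True = position may be attended). Always derive the mask from the tokeniser's attention_mask output (1 = real token, 0 = padding): key_padding_mask = (attention_mask == 0).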
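The two APIs use opposite boolean conventions, which is easy to get wrong. A small sketch, assuming a Hugging Face-style `attention_mask` (the values below are invented for the example), shows both paths and checks that masking padded keys matches physically removing them.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, T, d_model, nhead = 2, 5, 16, 4

# Tokeniser-style attention_mask: 1 = real token, 0 = [PAD]. Sequence 1 has 2 pads.
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])
x = torch.randn(B, T, d_model)

# nn.MultiheadAttention: True = ignore this key.
key_padding_mask = attention_mask == 0
mha = nn.MultiheadAttention(d_model, nhead, batch_first=True).eval()
with torch.no_grad():
    _, w = mha(x, x, x, key_padding_mask=key_padding_mask)    # w: (B, T, T)
pad_weight = w[1, :, 3:].abs().max().item()                   # mass on the two pad keys

# F.scaled_dot_product_attention: bool attn_mask, True = key MAY be attended.
q = torch.randn(B, nhead, T, d_model // nhead)
k = torch.randn(B, nhead, T, d_model // nhead)
v = torch.randn(B, nhead, T, d_model // nhead)
keep = attention_mask.bool()[:, None, None, :]     # (B, 1, 1, T): broadcasts over heads and queries
out = F.scaled_dot_product_attention(q, k, v, attn_mask=keep)

# Reference for sequence 1: drop the padded keys/values entirely.
ref = F.scaled_dot_product_attention(q[1:, :, :3], k[1:, :, :3], v[1:, :, :3])
matches_unpadded = torch.allclose(out[1:, :, :3], ref, atol=1e-6)

print(f"MHA weight on pad keys: {pad_weight}")
print(f"SDPA with mask == SDPA on unpadded sequence: {matches_unpadded}")
```

Only keys are masked here, so every query row still has at least one valid key; a row with all keys masked would produce NaNs after softmax.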

If Q and K have entries drawn independently from N(0,1), the dot product qᵀk is a sum of d_k products of standard normals, giving variance d_k. With d_k = 64 (as in BERT-base), dot products have standard deviation √64 = 8. When softmax receives inputs of that magnitude (±8), it saturates: nearly all probability mass concentrates on the maximum element, producing a near-one-hot distribution p. The softmax Jacobian is diag(p) − ppᵀ, and every entry of it goes to 0 as p approaches one-hot, so almost no gradient flows back to the scores. Scaling by 1/√d_k brings the standard deviation back to 1, keeping softmax in its "soft" regime where gradients are non-negligible. In the limit d_k → ∞ without scaling, attention degenerates to a hard argmax, whose gradient is zero almost everywhere.
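The vanishing gradient can be seen directly from the softmax Jacobian, J[i][j] = p_i·(δ_ij − p_j). The sketch below uses four made-up logits and multiplies them by a scale factor to mimic unscaled scores at larger d_k; the helper names are illustrative only.

```python
import math

def softmax(z):
    m = max(z)                                  # subtract max for numerical stability
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def jacobian_max_abs(p):
    """Largest |entry| of the softmax Jacobian J[i][j] = p_i * (delta_ij - p_j)."""
    n = len(p)
    return max(abs(p[i] * ((1.0 if i == j else 0.0) - p[j]))
               for i in range(n) for j in range(n))

base = [1.0, 0.5, -0.5, -1.0]       # made-up logits with roughly unit spread
for scale in (1, 8, 32):            # 8 = sqrt(64): the std of unscaled scores at d_k = 64
    p = softmax([scale * v for v in base])
    print(f"scale={scale:2d}  max p={max(p):.6f}  max |J|={jacobian_max_abs(p):.2e}")
```

As the scale grows, the largest probability approaches 1 and the largest Jacobian entry collapses by orders of magnitude, which is the gradient starvation the 1/√d_k factor avoids.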

Sinusoidal (Vaswani 2017): PE[pos, 2i] = sin(pos/10000^(2i/d)), PE[pos, 2i+1] = cos(pos/10000^(2i/d)). Fixed: no parameters to learn. Relative distances are encoded via trigonometric identities: PE[pos+k] can be expressed as a linear transform of PE[pos] that depends only on k. The encoding is defined for every position, so it can be applied to sequences longer than those seen in training, although quality usually degrades when extrapolating. Learned absolute (BERT, GPT-2): each position gets a learned embedding vector from nn.Embedding(max_len, d_model). More flexible than sinusoidal but cannot generalise beyond max_len: the embedding for position 513 does not exist if max_len=512. RoPE (Rotary Positional Embedding, Su et al 2021, used in LLaMA, Mistral, GPT-NeoX): encodes position by rotating pairs of query and key dimensions by a position-dependent angle. Key advantage: the dot product QKᵀ then depends only on the relative offset between positions, without adding anything to the token embedding. Context can be extended beyond the training length via position interpolation or NTK-aware scaling, usually with some further fine-tuning. Now the dominant approach in modern LLMs.
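The "linear transform" claim for sinusoidal encodings follows from the angle-addition identities: each (sin, cos) pair is rotated by an angle that depends on the offset k but not on pos. A minimal check in plain Python, with an illustrative d_model of 8 and hypothetical helper names:

```python
import math

D_MODEL = 8   # small, made-up model width for the demonstration

def freq(i):
    return 1.0 / (10000 ** (2 * i / D_MODEL))

def sinusoidal_pe(pos):
    """PE[pos] laid out as [sin(w0*pos), cos(w0*pos), sin(w1*pos), cos(w1*pos), ...]."""
    pe = []
    for i in range(D_MODEL // 2):
        pe += [math.sin(pos * freq(i)), math.cos(pos * freq(i))]
    return pe

def shift(pe, k):
    """Map PE[pos] to PE[pos + k] with a block-diagonal rotation that depends only on k."""
    out = []
    for i in range(D_MODEL // 2):
        s, c = pe[2 * i], pe[2 * i + 1]
        cos_k, sin_k = math.cos(k * freq(i)), math.sin(k * freq(i))
        out += [s * cos_k + c * sin_k,      # sin(a + b) = sin a cos b + cos a sin b
                c * cos_k - s * sin_k]      # cos(a + b) = cos a cos b - sin a sin b
    return out

max_err = max(
    abs(a - b)
    for pos in (0, 7, 100)
    for a, b in zip(shift(sinusoidal_pe(pos), k=5), sinusoidal_pe(pos + 5))
)
print(f"max |shift(PE[pos], 5) - PE[pos+5]| = {max_err:.2e}")
```

RoPE applies the same kind of per-pair rotation, but to the query and key vectors themselves rather than to an additive encoding.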

Encoder vs Decoder, BERT vs GPT Paradigm & Flash Attention
Bidirectional encoder (BERT/MLM) vs causal decoder (GPT/CLM); encoder-decoder (T5/BART); Flash Attention O(n) memory motivation

Transformer variants differ in their attention mask and pretraining objective.

Transformer block (standard):
  x → [LN → MultiHeadAttention → Add] → [LN → FFN → Add] → output
  (Pre-LN ordering: more stable than original post-LN)
  FFN: Linear(d_model, 4·d_model) → GELU → Linear(4·d_model, d_model)

BERT (encoder-only, bidirectional):
  All tokens attend to ALL other tokens: no causal mask (padding is still masked).
  Pretrained with Masked Language Model (MLM): predict [MASK] tokens.
  Best for: classification, NER, QA, embeddings, retrieval.
  Model sees full context → rich representations but cannot generate.

GPT (decoder-only, causal):
  Token i can only attend to tokens 0, 1, ..., i (lower-triangular mask).
  Pretrained with next-token prediction (CLM: Causal Language Modeling).
  Best for: generation, completion, instruction following, reasoning.
  Training and inference use the same attention pattern → no architectural train/test mismatch.

Encoder-Decoder (T5, BART, original Transformer):
  Encoder: full bidirectional attention over source sequence.
  Decoder: causal self-attention + cross-attention to encoder output.
  Best for: translation, summarisation, seq2seq tasks.

Flash Attention (Dao et al, 2022):
  Problem: standard attention materialises a (seq, seq) matrix per head in GPU HBM → O(n²) memory.
    For seq=4096 in FP32: 4096×4096×4 bytes = 64 MB per head per sequence
    (×32 heads = 2 GB per layer per sequence).
  Solution: tile Q,K,V in SRAM, compute attention block-by-block without writing the matrix to HBM.
  Result: O(n) memory, typically 2-4× faster due to reduced HBM read/write.

PyTorch — full pre-LN Transformer block, BERT-style encoder, GPT causal decoder, size comparison
import torch
import torch.nn as nn
import torch.nn.functional as F

# ── Pre-LN Transformer encoder block ─────────────────────
class TransformerBlock(nn.Module):
    def __init__(self, d_model=256, nhead=8, d_ff=1024, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn  = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.ff    = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_ff, d_model), nn.Dropout(dropout),
        )
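    def forward(self, x, attn_mask=None, key_padding_mask=None):
        # Pre-LN: normalise BEFORE each sublayer; the residual path stays un-normalised
        h = self.norm1(x)                     # compute once, reuse as query, key and value
        x = x + self.attn(h, h, h,
                          attn_mask=attn_mask,
                          key_padding_mask=key_padding_mask,
                          need_weights=False)[0]   # weights are unused here; skip computing them
        x = x + self.ff(self.norm2(x))
        return x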

# ── BERT-style encoder (bidirectional, full attention) ────
class BERTEncoder(nn.Module):
    def __init__(self, vocab=30522, d=256, h=8, n=6, max_len=512):
        super().__init__()
        self.token_emb = nn.Embedding(vocab, d)
        self.pos_emb   = nn.Embedding(max_len, d)   # learned position
        self.blocks    = nn.ModuleList([TransformerBlock(d, h) for _ in range(n)])
        self.norm      = nn.LayerNorm(d)
    def forward(self, ids):
        B, T = ids.shape
        pos  = torch.arange(T, device=ids.device).unsqueeze(0)
        x    = self.token_emb(ids) + self.pos_emb(pos)
        for blk in self.blocks:
            x = blk(x)                # no mask — full bidirectional attention
        return self.norm(x)

# ── GPT-style decoder (causal, lower-triangular mask) ─────
class GPTDecoder(nn.Module):
    def __init__(self, vocab=50257, d=256, h=8, n=6, max_len=1024):
        super().__init__()
        self.token_emb = nn.Embedding(vocab, d)
        self.pos_emb   = nn.Embedding(max_len, d)
        self.blocks    = nn.ModuleList([TransformerBlock(d, h) for _ in range(n)])
        self.norm      = nn.LayerNorm(d)
        self.lm_head   = nn.Linear(d, vocab, bias=False)
        self.lm_head.weight = self.token_emb.weight  # weight tying
    def forward(self, ids):
        B, T = ids.shape
        causal_mask = ~torch.tril(torch.ones(T, T, dtype=torch.bool, device=ids.device))
        pos = torch.arange(T, device=ids.device).unsqueeze(0)
        x   = self.token_emb(ids) + self.pos_emb(pos)
        for blk in self.blocks:
            x = blk(x, attn_mask=causal_mask)
        return self.lm_head(self.norm(x))   # (B, T, vocab) — next-token logits

# ── Compare ───────────────────────────────────────────────
bert = BERTEncoder(d=256, h=8, n=6)
gpt  = GPTDecoder(d=256, h=8, n=6)
print(f"BERT params: {sum(p.numel() for p in bert.parameters()):,}")
print(f"GPT  params: {sum(p.numel() for p in gpt.parameters()):,}  (lm_head adds none: weight tying)")

ids_b = torch.randint(0, 30522, (2, 32))
ids_g = torch.randint(0, 50257, (2, 16))
print(f"\nBERT: {ids_b.shape} → {bert(ids_b).shape}  (contextualised, bidirectional)")
print(f"GPT:  {ids_g.shape} → {gpt(ids_g).shape}  (next-token logits, causal)")

# ── Flash Attention memory comparison ─────────────────────
for seq in [512, 2048, 4096]:
    attn_mat_mb = seq * seq * 4 / 1e6   # float32 QKᵀ per head
    print(f"seq={seq:5d}: standard attention QKᵀ = {attn_mat_mb:.1f} MB/head "
          f"({'OOM risk' if attn_mat_mb > 100 else 'ok'})")
Trap Using BERT (bidirectional encoder) for text generation tasks

BERT's full bidirectional attention means every token sees all other tokens including future ones. This makes BERT excellent for understanding tasks (classification, NER) but fundamentally broken for generation — at inference time, you generate one token at a time and future tokens do not exist, so the model's assumptions are violated. Engineers who try to use BERT for text completion get incoherent outputs because the model was never trained to predict the next token autoregressively.

Fix For generation: use a GPT-family (causal decoder) model. For understanding: use a BERT-family (bidirectional encoder). For seq2seq: use an encoder-decoder (T5, BART). The pretraining objective (MLM vs CLM vs span corruption) matches the architecture — do not mix them.
Trap Materialising the full attention matrix for long-sequence tasks (seq_len > 2048)

Standard attention stores a (batch, heads, seq, seq) float tensor in GPU HBM. For seq=4096, batch=2, heads=32 in FP32: 2×32×4096×4096×4 bytes = 4 GB for the attention matrix of a single layer. Training a 32-layer model (7B scale) keeps every layer's matrix for the backward pass: 32 × 4 GB = 128 GB just for attention, more than any single GPU offers, including an 80 GB A100.

Fix Use F.scaled_dot_product_attention(): on CUDA with FP16/BF16 inputs it can dispatch to a Flash Attention kernel, which computes attention in SRAM tiles without materialising the full (seq, seq) matrix. O(n) memory instead of O(n²). For inference at seq>8192: additionally use sliding window attention (Mistral) or grouped-query attention (GQA) to further reduce KV cache size.
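The arithmetic behind the trap is worth keeping as a helper. The sketch below reproduces the figures above (FP32, batch=2, heads=32, seq=4096) and compares them with the size of the attention output; `attn_matrix_gib`, `attn_output_gib`, and the head dimension of 128 are assumptions made for the illustration.

```python
def attn_matrix_gib(batch, heads, seq, bytes_per_el=4):
    """Size of the materialised (batch, heads, seq, seq) score matrix for ONE layer."""
    return batch * heads * seq * seq * bytes_per_el / 2**30

def attn_output_gib(batch, heads, seq, d_head, bytes_per_el=4):
    """Size of the (batch, heads, seq, d_head) attention output for ONE layer."""
    return batch * heads * seq * d_head * bytes_per_el / 2**30

per_layer = attn_matrix_gib(batch=2, heads=32, seq=4096)
all_layers = 32 * per_layer                      # kept for backward during training
output = attn_output_gib(batch=2, heads=32, seq=4096, d_head=128)   # d_head is assumed

print(f"score matrix, one layer : {per_layer:.1f} GiB")
print(f"score matrix, 32 layers : {all_layers:.0f} GiB")
print(f"attention output, 1 layer: {output:.3f} GiB")

# Quadratic growth: doubling seq quadruples the matrix, but only doubles the output.
for seq in (2048, 4096, 8192):
    print(f"seq={seq:5d}: matrix={attn_matrix_gib(2, 32, seq):6.1f} GiB  "
          f"output={attn_output_gib(2, 32, seq, 128):.3f} GiB")
```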

BERT uses bidirectional full attention: every token attends to every other token. It is pretrained with Masked Language Modeling — predict randomly masked tokens using the full left AND right context. This makes representations rich in contextual information but makes the model unsuitable for generation (it has always seen future context during training; at generation time that context is absent). GPT uses causal (unidirectional) attention: token i can only attend to positions 0 through i. It is pretrained with next-token prediction — predict the next token given only the left context. Training and inference use identical attention patterns, making GPT a natural fit for generation. The architecture choice mirrors the task: BERT is a reading comprehension model; GPT is a writing model. For downstream tasks: classification, NER, QA → BERT (fine-tune with [CLS] token representation); completion, generation, instruction following → GPT.

Standard attention writes the full (n, n) attention weight matrix to GPU HBM (high-bandwidth memory) for the softmax computation, requiring O(n²) memory. Flash Attention (Dao et al, 2022) uses the online softmax algorithm and kernel fusion to compute attention block-by-block in SRAM (faster but much smaller on-chip memory) without ever materialising the full (n, n) matrix in HBM. For each tile of Q, it streams tiles of K and V from HBM, computes partial attention scores, and accumulates the output while updating a running row maximum and softmax normaliser. The final output is identical to standard attention (exact, not approximate, up to floating-point rounding). Memory: O(n·d) for the output plus O(n) softmax statistics, instead of O(n²). Speed: typically 2-4× faster than standard attention, because attention is memory-bound: SRAM bandwidth is roughly an order of magnitude higher than HBM bandwidth, and HBM traffic falls from Θ(n·d + n²) to Θ(n²·d²/M) for on-chip SRAM of size M. The arithmetic is still O(n²·d); only the data movement shrinks. Flash Attention 2 improves parallelism and work partitioning across thread blocks and supports multi-query and grouped-query attention; Flash Attention 3 targets Hopper GPUs with asynchronous execution and FP8.
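The online softmax at the heart of this can be shown for a single query in plain Python. This is only a sketch of the numerical idea (running max, running normaliser, rescaled accumulator), not of the CUDA kernel; the function names, block size, and data are invented for the example.

```python
import math

def attention_row(q, keys, values):
    """Reference: one query against ALL keys at once (every score is materialised)."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, values)) / z for j in range(len(values[0]))]

def attention_row_tiled(q, keys, values, block=2):
    """Same result, visiting keys/values one block at a time with running statistics."""
    scale = 1.0 / math.sqrt(len(q))
    m = float("-inf")                   # running max of all scores seen so far
    z = 0.0                             # running softmax denominator (relative to m)
    acc = [0.0] * len(values[0])        # running un-normalised output (relative to m)
    for start in range(0, len(keys), block):
        k_blk = keys[start:start + block]
        v_blk = values[start:start + block]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_blk]
        m_new = max(m, max(scores))
        correction = math.exp(m - m_new)            # re-expresses old stats w.r.t. the new max
        w = [math.exp(s - m_new) for s in scores]
        z = z * correction + sum(w)
        acc = [a * correction + sum(wi * v[j] for wi, v in zip(w, v_blk))
               for j, a in enumerate(acc)]
        m = m_new
    return [a / z for a in acc]

q = [0.5, -1.0, 2.0, 0.1]
keys = [[0.1 * i - 0.2 * j for j in range(4)] for i in range(5)]
values = [[float(i), float(i * i)] for i in range(5)]

full = attention_row(q, keys, values)
tiled = attention_row_tiled(q, keys, values, block=2)
print("full :", [round(x, 6) for x in full])
print("tiled:", [round(x, 6) for x in tiled])
```

Only one block of scores exists at any time, which is why the full row (and, across all queries, the full (n, n) matrix) never needs to be stored.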

The Transformer's power comes from two orthogonal insights: (1) attention learns which positions to aggregate information from, with no locality constraint; (2) residual connections ensure gradient flow and allow depth. Every variant — BERT, GPT, T5, ViT — is a different configuration of attention masks, positional encodings, and layer ordering.
Training at Scale 09–10
09

Mixed Precision, Gradient Accumulation & Distributed Training

Your model trains fine on one GPU but crashes with OOM on a larger batch — or you need to scale to 8 GPUs without changing the loss landscape. Training at scale means solving memory, communication, and numerical stability simultaneously.

Mixed Precision (FP16/BF16) & Gradient Accumulation
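torch.autocast runs most ops in 16-bit (roughly half the activation memory, Tensor Core speed); GradScaler guards FP16 gradients against underflow; gradient accumulation simulates large batches without OOM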

Mixed precision trains using 16-bit floats for most ops, keeping an FP32 master copy for optimizer updates.

FP16 vs BF16:
  FP16: sign(1) + exponent(5) + mantissa(10); range ±65504; overflow risk with large activations
  BF16: sign(1) + exponent(8) + mantissa(7); same range as FP32; no overflow on forward pass
  → BF16 preferred for LLM training (Ampere+); FP16 + GradScaler for older V100/T4 hardware

torch.autocast wraps the forward pass:
  • casts matmul/conv to the chosen 16-bit dtype automatically; keeps softmax/loss in FP32
  • GradScaler multiplies loss by a dynamic scale (2¹⁶ initially) before backward to prevent FP16 gradient underflow

Gradient accumulation (when target batch size doesn't fit in memory):
  effective_batch = per_device_batch × grad_accum_steps
  Accumulate N mini-batch gradients before one optimizer.step()
  → same gradient as a single large batch (for layers without batch statistics), no OOM

  GPU Memory Breakdown — 7B param model:

  ┌────────────────────────────────────────────────────┐
  │  FP32 full precision                               │
  │  Parameters:      7B × 4 bytes  =  28 GB          │
  │  Gradients:       7B × 4 bytes  =  28 GB          │
  │  AdamW m + v:     7B × 8 bytes  =  56 GB          │
  │  Activations:     varies (batch-dependent)         │
  │  TOTAL (excl. act):            ≈ 112 GB            │
  └────────────────────────────────────────────────────┘
  ┌────────────────────────────────────────────────────┐
  │  AMP (BF16 fwd + FP32 optimizer)                   │
  │  BF16 params:     7B × 2 bytes  =  14 GB          │
  │  BF16 grads:      7B × 2 bytes  =  14 GB          │
  │  FP32 master:     7B × 4 bytes  =  28 GB          │
  │  AdamW m + v:     7B × 8 bytes  =  56 GB          │
  │  TOTAL (excl. act):            ≈ 112 GB            │
  │  Win = Tensor Core throughput + smaller activations│
  └────────────────────────────────────────────────────┘
PyTorch — AMP with GradScaler and gradient accumulation loop
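import torch, torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

device = torch.device('cuda')
model  = nn.Sequential(nn.Linear(512, 2048), nn.ReLU(), nn.Linear(2048, 10)).to(device)
opt    = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

# BF16 (Ampere+) needs no loss scaling; FP16 (V100/T4) does.
amp_dtype = torch.bfloat16
# Disabled scaler = pass-through, so the loop below works for both dtypes.
# torch.amp.GradScaler needs PyTorch >= 2.3; older versions: torch.cuda.amp.GradScaler.
scaler = torch.amp.GradScaler('cuda', enabled=(amp_dtype == torch.float16))

GRAD_ACCUM = 4              # effective batch = batch_size * 4

# Synthetic stand-in data so the loop runs end to end; swap in your own DataLoader.
dataloader = DataLoader(
    TensorDataset(torch.randn(256, 512), torch.randint(0, 10, (256,))), batch_size=32
)

for step, (x, y) in enumerate(dataloader):
    x, y = x.to(device), y.to(device)

    with torch.autocast(device_type='cuda', dtype=amp_dtype):
        logits = model(x)
        loss   = nn.functional.cross_entropy(logits, y)
    loss = loss / GRAD_ACCUM                # normalise before accumulating

    scaler.scale(loss).backward()           # scales for FP16; identity when disabled

    if (step + 1) % GRAD_ACCUM == 0:
        scaler.unscale_(opt)                # so clipping sees true gradient magnitudes
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(opt)                    # plain opt.step() when disabled
        scaler.update()
        opt.zero_grad(set_to_none=True)     # only at the accumulation boundary

# Memory audit after the loop:
print(torch.cuda.memory_summary(abbreviated=True))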
Trap GradScaler with BF16 autocast

Using GradScaler (designed for FP16 underflow prevention) together with torch.bfloat16 autocast — scaler adds unnecessary complexity and can interfere with gradient magnitudes.

Fix BF16 has the same exponent range as FP32 — gradient underflow does not occur. Use autocast(dtype=torch.bfloat16) with no GradScaler. Only use GradScaler when dtype=torch.float16.
Trap Gradient accumulation without normalising loss

Calling loss.backward() N times without dividing loss by N — gradients are N× too large, causing unstable updates or divergence at step boundaries.

Fix Divide loss by GRAD_ACCUM_STEPS before backward: loss = loss / N. This keeps the accumulated gradient magnitude equivalent to a single large-batch step.
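A quick CPU check makes the factor of N concrete. The sketch below compares the gradient from one batch of 16 with the gradient accumulated over four micro-batches of 4, with and without the division; the tiny linear model and random data are placeholders.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 3)
x, y = torch.randn(16, 8), torch.randint(0, 3, (16,))
loss_fn = nn.CrossEntropyLoss()          # mean reduction over the batch
N = 4                                    # accumulation steps (micro-batches of 4)

# (a) One large batch.
model.zero_grad()
loss_fn(model(x), y).backward()
g_full = model.weight.grad.clone()

# (b) Accumulation with the loss divided by N.
model.zero_grad()
for xb, yb in zip(x.chunk(N), y.chunk(N)):
    (loss_fn(model(xb), yb) / N).backward()      # .grad accumulates across calls
g_accum = model.weight.grad.clone()

# (c) Accumulation WITHOUT the division (the trap).
model.zero_grad()
for xb, yb in zip(x.chunk(N), y.chunk(N)):
    loss_fn(model(xb), yb).backward()
g_unscaled = model.weight.grad.clone()

print("normalised accumulation matches full batch:", torch.allclose(g_full, g_accum, atol=1e-6))
print("un-normalised / full-batch gradient ratio :",
      (g_unscaled.norm() / g_full.norm()).item())
```

The equivalence in (b) relies on equal-sized micro-batches and a mean-reduced loss; with a ragged final micro-batch, weight each loss by its share of the samples instead of 1/N.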

AMP keeps a FP32 master copy of weights for optimizer updates (to avoid precision loss in the weight update step), plus FP32 Adam first and second moments. Per-parameter memory: FP16 weights + FP16 grads + FP32 master + FP32 m + FP32 v = 2+2+4+4+4 = 16 bytes. Full FP32 training with Adam: 4+4+4+4 = 16 bytes. The weight memory is the same. The saving comes from: (1) activations stored in FP16 during the forward pass — 2× smaller, critical for large batch or long sequence; (2) Tensor Core utilisation — matrix ops are 2-4× faster in FP16/BF16, not just smaller.

Three checks in order: (1) Loss normalisation — confirm loss is divided by grad_accum_steps before backward. Without this, gradients are N× too large and the optimizer step is massive. (2) zero_grad timing — confirm optimizer.zero_grad() is called only at the accumulation boundary, not every mini-batch. Zeroing mid-accumulation discards partial gradients silently. (3) Scheduler stepping — confirm scheduler.step() is called at the outer step (after optimizer.step()), not at every mini-batch. Stepping N× too fast causes LR to spike or decay too aggressively. Log gradient norm at each inner and outer step to isolate which boundary triggers the explosion.

Distributed Training — DDP vs FSDP, Activation Checkpointing & torch.compile
DDP replicates full model; FSDP shards params+grads+optimizer across GPUs; checkpointing trades compute for activation memory

Two strategies for multi-GPU training. Choose based on model size relative to single-GPU memory.

DistributedDataParallel (DDP):
  • Each GPU holds a full model replica
  • Forward + backward run independently on each GPU
  • Gradient all-reduce (ring-allreduce) after backward: g_global = (1/N) Σᵢ gᵢ
  • Memory per GPU = full model params + grads + optimizer, unchanged vs single GPU

FullyShardedDataParallel (FSDP):
  • Shards parameters, gradients, and optimizer states across N GPUs
  • Memory per GPU ≈ total / N for params+grads+optimizer (plus communication buffers)
  • All-gather before each wrapped unit's forward (and again in backward); reduce-scatter of gradients after backward
  • Enables training models whose states exceed one GPU: a 13B model (≈ 208 GB at 16 bytes/param) fits across 8 × A100 80GB; 70B-class models (≈ 1.1 TB) need more GPUs or CPU offload

Activation Checkpointing:
  • Default: all intermediate activations stored for backward (O(L·batch·seq·d) memory)
  • With checkpointing: store only at checkpoint boundaries; recompute during backward
  • Memory saving: roughly 30–40%; compute cost: about +33% (one extra forward pass)

torch.compile (PyTorch 2.0+):
  • Traces graph, fuses kernels, reduces Python interpreter overhead
  • Typical speedup: 15–30% training, up to 2× inference (workload-dependent)
  • mode='reduce-overhead' uses CUDA graphs (helps most at small batch sizes); 'max-autotune' spends longer compiling to search for faster kernels
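Activation checkpointing can be tried on CPU without any distributed setup. The sketch below uses `torch.utils.checkpoint.checkpoint` on a toy stack of blocks (sizes invented for the example) and confirms that the recomputed backward pass yields the same gradients as the standard one.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
blocks = nn.ModuleList([nn.Sequential(nn.Linear(32, 32), nn.GELU()) for _ in range(4)])
x = torch.randn(8, 32)

def forward(inp, use_checkpoint):
    h = inp
    for blk in blocks:
        if use_checkpoint:
            # Keep only the block input; recompute the block's internals during backward.
            h = checkpoint(blk, h, use_reentrant=False)
        else:
            h = blk(h)
    return h.sum()

def grads(use_checkpoint):
    blocks.zero_grad()
    forward(x, use_checkpoint).backward()
    return [p.grad.clone() for p in blocks.parameters()]

g_plain = grads(use_checkpoint=False)
g_ckpt = grads(use_checkpoint=True)
max_diff = max((a - b).abs().max().item() for a, b in zip(g_plain, g_ckpt))
print(f"max gradient difference with checkpointing: {max_diff:.2e}")
```

The blocks here are deterministic; with dropout inside a checkpointed block, PyTorch restores the RNG state during recomputation by default so that the same mask is used.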

PyTorch — torchrun DDP setup, FSDP wrapping, and activation checkpointing
# Run: torchrun --nproc_per_node=4 train.py
import functools, os
import torch, torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper, apply_activation_checkpointing,
)

USE_FSDP = True   # wrap a given model with ONE of DDP or FSDP, never both

# ── 1. Process group init ─────────────────────────────────
dist.init_process_group(backend='nccl')
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True), num_layers=12
)

if not USE_FSDP:
    # ── 2. DDP (model fits on one GPU) ───────────────────
    model = DDP(model.to(local_rank), device_ids=[local_rank],
                find_unused_parameters=False)
else:
    # ── 3. FSDP (model exceeds single GPU) ───────────────
    mp_policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.bfloat16,
    )
    # One FSDP unit per Transformer layer, so only one layer is all-gathered at a time.
    # Without a wrap policy the whole model becomes a single unit and is gathered in full.
    wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={nn.TransformerEncoderLayer},
    )
    model = FSDP(
        model,
        auto_wrap_policy=wrap_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,   # ZeRO-3
        mixed_precision=mp_policy,
        device_id=local_rank,
    )
    # ── 4. Activation checkpointing on FSDP ──────────────
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=checkpoint_wrapper,
        check_fn=lambda m: isinstance(m, nn.TransformerEncoderLayer),
    )

# ── 5. torch.compile (PyTorch 2.0+) ─────────────────────
# Default mode shown; validate 'reduce-overhead' (CUDA graphs) separately with FSDP.
compiled_model = torch.compile(model)

# ── 6. Memory report after first step ────────────────────
print(torch.cuda.memory_summary(device=local_rank, abbreviated=True))

dist.destroy_process_group()
Trap DDP find_unused_parameters=True on large models

Enabling find_unused_parameters=True makes DDP traverse the autograd graph from the forward outputs on every iteration to find parameters that will not receive gradients. When all parameters are used, this is pure overhead and can cost 15–30% of throughput on large models.

Fix Set find_unused_parameters=False (DDP default). Only enable it for models with genuine conditional forward paths where some parameters are skipped in certain inputs.
Trap FSDP FULL_SHARD on a small model

Applying FULL_SHARD to a 1B model across 8 GPUs — each forward layer triggers an all-gather that dominates compute time, turning what should be fast into a communication-bound job.

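Fix Use ShardingStrategy.SHARD_GRAD_OP for models that fit per-GPU. It is ZeRO-2-style: gradients and optimizer states stay sharded, while parameters are all-gathered once before forward and kept unsharded until backward finishes, which removes the second all-gather. Reserve FULL_SHARD for models that genuinely do not fit on one GPU with AMP + activation checkpointing.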

DDP: memory per GPU = full model copy regardless of how many GPUs you add. Adding GPUs increases throughput, not per-GPU capacity. For a 7B model with AdamW in FP32: 4 params + 4 grads + 8 optimizer = 16 bytes/param × 7B = 112 GB/GPU, which is impossible on a single 80 GB A100. FSDP FULL_SHARD: memory per GPU ≈ (params + grads + optimizer) / N GPUs + communication buffers. On 8 × A100 that is ≈ 112 / 8 = 14 GB/GPU of model state for a 7B model with BF16 + FP32 optimizer, leaving room for activations. FSDP becomes necessary when the model + optimizer states do not fit on one GPU even with AMP and activation checkpointing. For a 7B model in BF16 without an FP32 master copy: params (14 GB) + grads (14 GB) + FP32 AdamW moments (56 GB) = 84 GB, which already exceeds an 80 GB A100 before any activations; with the master copy it is 112 GB. Either way, full fine-tuning requires FSDP or another sharding or offload scheme.
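The per-GPU figures above follow from a one-line formula. The sketch below encodes it with hypothetical helper names and the byte counts used in this section; it covers model state only and ignores activations and communication buffers.

```python
# Bytes per parameter for each training setup (model state only, no activations).
BYTES_PER_PARAM = {
    "fp32_adamw": 4 + 4 + 8,              # params + grads + AdamW m,v
    "amp_bf16_adamw": 2 + 2 + 4 + 8,      # BF16 params + BF16 grads + FP32 master + m,v
    "pure_bf16_adamw": 2 + 2 + 8,         # no FP32 master copy, FP32 m,v
}

def state_gb(params_billion, setup):
    """Total model-state memory in GB (1e9 params x bytes/param = GB)."""
    return params_billion * BYTES_PER_PARAM[setup]

def per_gpu_gb(params_billion, setup, n_gpus, strategy):
    total = state_gb(params_billion, setup)
    if strategy == "ddp":                 # every rank holds a full replica
        return total
    if strategy == "fsdp_full_shard":     # state is split evenly across ranks
        return total / n_gpus
    raise ValueError(f"unknown strategy: {strategy}")

for setup in BYTES_PER_PARAM:
    ddp = per_gpu_gb(7, setup, 8, "ddp")
    fsdp = per_gpu_gb(7, setup, 8, "fsdp_full_shard")
    print(f"7B {setup:16s}: DDP {ddp:6.1f} GB/GPU   FSDP(8) {fsdp:5.1f} GB/GPU")
```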

Step 1: add torch.cuda.memory_summary() after the first forward and first backward to isolate where memory spikes. The peak is usually during backward (activations + gradients both present). Step 2: use torch.cuda.memory._record_memory_history() and the PyTorch Memory Visualizer to see each allocation with its stack trace. Step 3: apply in order: activation checkpointing (trade compute for activation memory), a smaller micro-batch combined with gradient accumulation to keep the effective batch size (accumulation itself saves nothing; the smaller micro-batch does), then BF16 AMP (smaller activations). Step 4: if model parameters plus optimizer state alone exceed GPU capacity, switch to FSDP. Step 5: profile the communication/compute ratio with torch.profiler to confirm the memory savings did not introduce a communication bottleneck.

Distributed training doesn't just speed things up — it changes numerical behaviour. Gradient all-reduce, mixed precision, and batch size scaling all interact. Understand them before you need them in a production fire.
10

Efficient Fine-tuning — LoRA, QLoRA & PEFT

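You need to fine-tune a 7B LLM on a custom task but have only one GPU. Full fine-tuning with AdamW needs about 112 GB of VRAM for weights, gradients, and optimizer state alone. LoRA brings the total to roughly 40 GB and QLoRA to roughly 12 GB (both depend on sequence length, batch size, and checkpointing), while typically retaining over 95% of full fine-tuning performance.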

LoRA — Low-Rank Adaptation of Large Language Models
Inject trainable rank-r bypass ΔW = BA alongside frozen W₀; train only B and A — 99% fewer parameters at similar quality

LoRA (Hu et al, 2021) freezes pretrained weights W₀ and adds a low-rank bypass:

  W_adapted = W₀ + ΔW = W₀ + B·A
    W₀ ∈ ℝ^(d×k)   frozen pretrained weights
    B  ∈ ℝ^(d×r)   trainable, init = 0 → ΔW = 0 at step 0
    A  ∈ ℝ^(r×k)   trainable, init ~ N(0, σ²)
    r ≪ min(d,k)   rank, typically 8–64

Parameter reduction (d=k=4096, r=16):
  Full: 4096 × 4096 = 16,777,216 params
  LoRA: 4096×16 + 16×4096 = 131,072 params → 99.2% reduction

Forward:   h = W₀x + (α/r)·BAx        [α = scaling hyperparameter]
Inference: merge W_merged = W₀ + (α/r)·BA → zero latency overhead

QLoRA (Dettmers et al, 2023):
  • Load W₀ in 4-bit NF4 (Normal Float 4) quantization
  • Dequantize to BF16 on-the-fly during forward pass
  • Train B and A in BF16 as usual
  • Memory: ≈ 3.5 GB of base weights for a 7B model (7B × 0.5 bytes), before adapters, quantization metadata, and activations

  LoRA decomposition (d=4096, k=4096, r=16):

  Input x ──────────────────────────────────────────► × W₀ ──► h_base
     │                                               (frozen)      │
     │                                                            (+)──► h_out
     └──► × Aᵀ ──► rank-r bottleneck ──► × Bᵀ ──► h_lora ───────►
          (16×4096)    (batch, 16)         (4096×16)

  W₀:  full rank 4096×4096  ← frozen, zero gradient stored
  B·A: rank-16  4096×4096   ← only B (4096×16) and A (16×4096) updated

  Trainable params per weight: 2 × 4096 × 16 = 131,072
  vs full fine-tune:                 4096 × 4096 = 16,777,216   (128× reduction)
PyTorch — PEFT LoRA and QLoRA setup with bitsandbytes
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
from transformers import BitsAndBytesConfig
import torch

# ── LoRA on FP16/BF16 model ───────────────────────────────
model = AutoModelForCausalLM.from_pretrained(
    'meta-llama/Llama-2-7b-hf', torch_dtype=torch.bfloat16, device_map='auto'
)
lora_cfg = LoraConfig(
    r=16,
    lora_alpha=32,           # effective scale = alpha / r = 2.0
    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],
    lora_dropout=0.05,
    bias='none',
    task_type=TaskType.CAUSAL_LM,
)
model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()
# expect about 16.8M trainable (4 modules × 32 layers × 131,072) out of ≈ 6.76B total, ≈ 0.25%

# ── QLoRA — 4-bit base + BF16 adapters ───────────────────
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,   # nested quantization of metadata
)
model_4bit = AutoModelForCausalLM.from_pretrained(
    'meta-llama/Llama-2-7b-hf',
    quantization_config=bnb_config,
    device_map='auto',
)
model_4bit = prepare_model_for_kbit_training(model_4bit)  # enable grad checkpointing
model_4bit = get_peft_model(model_4bit, lora_cfg)

# ── Merge adapter for zero-latency inference ───────────────
from peft import PeftModel
base   = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf', torch_dtype=torch.bfloat16)
peft_m = PeftModel.from_pretrained(base, './lora-checkpoint')
merged = peft_m.merge_and_unload()   # W = W₀ + (α/r)·BA; removes adapter overhead
merged.save_pretrained('./merged-7b')
Trap Wrong target_modules for the architecture

Applying LoRA only to q_proj and v_proj — skipping k_proj, o_proj, and MLP projections — may give mediocre results because key and output projections are not adapted.

Fix For instruction following: target q, k, v, o projections. For domain adaptation: also include gate_proj, up_proj, down_proj. List all linear layer names via model.named_modules() and pick by parameter count. More modules at low rank often beats fewer at high rank.
Trap Serving LoRA without merging

Running inference with the PEFT adapter still attached requires loading base model + adapter and executes the extra B·A path (plus wrapper overhead) on every forward pass, which typically increases latency by around 5–15%.

Fix For production: merge with merge_and_unload() before serving. The merged model is architecturally identical to base — no adapter overhead. For multi-adapter serving (one base, many tasks), use vLLM's LoRA serving which hot-swaps adapters at the CUDA kernel level.

At step 0 we want ΔW = B·A = 0 so that the model starts exactly from the pretrained weights, a known-good initialisation point. If both A and B were random, ΔW would be non-zero at step 0, corrupting the pretrained representations before any fine-tuning signal arrives. Initialising B=0 guarantees B·A=0 regardless of A. A must still be random rather than zero: with h_lora = B·A·x, the gradients are ∂L/∂B = (∂L/∂h_lora)·(Ax)ᵀ and ∂L/∂A = Bᵀ·(∂L/∂h_lora)·xᵀ. With B=0 and random A, ∂L/∂B is non-zero at step 0 while ∂L/∂A is zero, so B moves first and A receives a gradient signal from the next step onward. If A were also zero, both gradients would vanish and the adapter would never train.
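The initialisation argument can be checked numerically on a toy layer. The following is a minimal sketch in plain Python, assuming hypothetical tiny shapes (d=3, r=2), a made-up input x, and a made-up upstream gradient g standing in for ∂L/∂h_lora; it is not how PEFT implements LoRA internally.

```python
import random

def matvec(M, v):
    """Matrix-vector product for a list-of-rows matrix."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def outer(u, v):
    """Outer product u v^T as a list of rows."""
    return [[ui * vj for vj in v] for ui in u]

def transpose(M):
    return [list(col) for col in zip(*M)]

random.seed(0)
d, r = 3, 2                                   # hypothetical toy sizes
A = [[random.gauss(0, 0.1) for _ in range(d)] for _ in range(r)]   # (r x d), random init
B = [[0.0] * r for _ in range(d)]                                   # (d x r), zero init
x = [1.0, -2.0, 0.5]                          # made-up input
g = [0.3, -0.1, 0.2]                          # made-up dL/dh_lora

h_lora = matvec(B, matvec(A, x))              # B(Ax): the adapter output at step 0
grad_B = outer(g, matvec(A, x))               # dL/dB = g (Ax)^T
grad_A = outer(matvec(transpose(B), g), x)    # dL/dA = (B^T g) x^T

print("adapter output at step 0:", h_lora)
print("grad_B has non-zero entries:", any(v != 0 for row in grad_B for v in row))
print("grad_A is all zero:", all(v == 0 for row in grad_A for v in row))
```

Setting A to zeros as well makes `grad_B` vanish too, which is the degenerate case the random initialisation of A avoids.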

Start with r=16, lora_alpha=32 (scale=2.0). Too low (r=1–4): insufficient capacity to represent the distribution shift required by the fine-tuning task; loss plateaus early and fine-tuned behaviour is weak or inconsistent. Too high (r=256+): at r=256 each adapted 4096×4096 matrix already trains 12.5% as many parameters as full fine-tuning, so memory and compute savings shrink, and the risk of catastrophic forgetting rises if LR is not reduced. Empirical guidance: r=8–16 for instruction following or light domain adaptation; r=32–64 for heavy domain adaptation or tasks very different from pretraining. Run a rank sweep over {8, 16, 32, 64} on a small held-out set; the cost is small relative to the full training run. Also check that lora_alpha scales with r: alpha=2r is a safe default.
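To make the rank trade-off concrete, the parameter arithmetic can be tabulated. This is a small sketch; the helper names are illustrative, and the whole-model figure assumes 32 layers with four 4096×4096 attention projections each (the LLaMA-7B-style shapes used in this stage).

```python
def lora_params(d_in, d_out, r):
    """Trainable parameters of one LoRA pair: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)

def full_params(d_in, d_out):
    """Parameters of the dense weight matrix being adapted."""
    return d_in * d_out

d = 4096
for r in (4, 8, 16, 32, 64, 256):
    n = lora_params(d, d, r)
    frac = n / full_params(d, d)
    print(f"r={r:>3}: {n:>9,} params per matrix ({frac:.2%} of full fine-tuning)")

# Whole model, assuming 32 layers x 4 attention projections (q, k, v, o):
total_r16 = 32 * 4 * lora_params(d, d, 16)
print(f"r=16 on q/k/v/o across 32 layers: {total_r16:,} trainable params")
```

The per-matrix cost grows linearly in r, so doubling the rank doubles adapter memory and optimizer state.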

Adapter Layers, Prefix Tuning & the PEFT Library
Alternative PEFT methods: bottleneck adapters, prefix/prompt tuning, IA³; PEFT library unifies all under one API

LoRA is one of several parameter-efficient fine-tuning strategies; knowing the alternatives helps pick the right tool:

Adapter Layers (Houlsby et al, 2019):
  Insert a small bottleneck MLP after each transformer sub-layer:
    h → Linear(d→r) → GeLU → Linear(r→d) → residual add
  Only adapter weights trained; base weights frozen
  Disadvantage: adapters are sequential and increase latency even at inference (unlike LoRA, which merges)

Prefix Tuning (Li & Liang, 2021):
  Prepend L learnable "prefix" tokens to K and V at every attention layer:
    new_K = concat(P_K, K), new_V = concat(P_V, V), with P_K, P_V ∈ ℝ^(L×d)
  Model weights fully frozen; only prefix tensors trained (~0.1% params)
  Con: prefix tokens consume context window; L=20 on a 512-token model = 4% context loss

Prompt Tuning (Lester et al, 2021):
  Like prefix tuning but only at the input embedding layer: simpler, fewer params
  Works well for large models (11B+) but underperforms LoRA on smaller models

IA³ (Infused Adapter by Inhibiting and Amplifying Inner Activations):
  Learns per-element rescaling vectors for keys, values, and FFN activations
  Fewest parameters of the methods listed here; good for many-task few-shot settings

PEFT Library (HuggingFace):
  Unified API: LoraConfig, PrefixTuningConfig, PromptTuningConfig, IA3Config
  model = get_peft_model(model, config) for any supported architecture
  save_pretrained saves only adapter weights (~50–200 MB vs 14 GB for full 7B)

PyTorch — PEFT library with LoRA, Prefix Tuning, and adapter hot-swap pattern
from peft import (
    get_peft_model, PeftModel,
    LoraConfig, PrefixTuningConfig, PromptTuningConfig, PromptTuningInit,
    TaskType,
)
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

base_id = 't5-base'

def fresh_base():
    # get_peft_model can modify the model it wraps in place (LoRA injects layers),
    # so each option below gets its own freshly loaded base model.
    return AutoModelForSeq2SeqLM.from_pretrained(base_id)

# ── Option A: LoRA (most common) ─────────────────────────
lora_model = get_peft_model(fresh_base(), LoraConfig(
    r=8, lora_alpha=16, task_type=TaskType.SEQ_2_SEQ_LM
))
lora_model.print_trainable_parameters()   # ~0.4% of params with the default q/v targets

# ── Option B: Prefix Tuning ───────────────────────────────
prefix_model = get_peft_model(fresh_base(), PrefixTuningConfig(
    num_virtual_tokens=20,
    task_type=TaskType.SEQ_2_SEQ_LM,
))

# ── Option C: Prompt Tuning ───────────────────────────────
prompt_model = get_peft_model(fresh_base(), PromptTuningConfig(
    num_virtual_tokens=8,
    task_type=TaskType.SEQ_2_SEQ_LM,
    prompt_tuning_init=PromptTuningInit.TEXT,      # without this the init text is ignored
    prompt_tuning_init_text='Summarise the following:',
    tokenizer_name_or_path=base_id,                # needed to tokenise the init text
))

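# ── Save adapter only (~MBs): ────────────────────────────
lora_model.save_pretrained('./lora-summarise')
# ./lora-summarise/adapter_config.json         ← config
# ./lora-summarise/adapter_model.safetensors   ← weights only (adapter_model.bin in older PEFT releases)

# ── Load and hot-swap: ────────────────────────────────────
base   = AutoModelForSeq2SeqLM.from_pretrained(base_id)
loaded = PeftModel.from_pretrained(base, './lora-summarise', adapter_name='summarise')
# Further adapters can share the same base and be switched per request, e.g.
# loaded.load_adapter('./lora-translate', adapter_name='translate')   # hypothetical second adapter
# loaded.set_adapter('translate')

# Merge for zero-latency inference:
merged = loaded.merge_and_unload()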
Trap Prefix tuning consuming context window

Setting num_virtual_tokens=50 on a 512-token context model — 10% of context is consumed by prefix tokens that are invisible to the user but counted toward max length, degrading long-document task performance.

Fix Keep prefix length ≤ 5% of context window. Most task-specific signal comes from the first 10–20 prefix tokens — diminishing returns beyond that. Profile accuracy vs prefix length and pick the knee of the curve.
Trap task_type mismatch

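Setting task_type=SEQ_2_SEQ_LM for a GPT-2/LLaMA causal model: PEFT selects its wrapper class from task_type, so the model is wrapped with encoder-decoder forward/generate handling (decoder inputs, virtual-token placement for prefix and prompt tuning) that a decoder-only model does not expect. Depending on the method, this either fails at call time or trains an adapter that behaves incorrectly without an obvious error.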

Fix Set task_type=CAUSAL_LM for decoder-only models (GPT, LLaMA, Mistral), SEQ_2_SEQ_LM for encoder-decoder (T5, BART). Always call model.print_trainable_parameters() and inspect the output to verify that the expected layers are targeted.

Adapter layers are sequential: the adapter MLP is inserted inside the transformer block and executes on every forward pass, even at inference, adding latency proportional to the adapter size. LoRA is parallel: B·A runs alongside W₀, but crucially, the adapter can be merged back: W_merged = W₀ + (alpha/r)·B·A. The merged model is architecturally identical to the base, with no adapter overhead at inference and no latency penalty. This is LoRA's decisive advantage for production. Bottleneck adapters cannot be folded into W₀ in the same way, because the nonlinearity between their down- and up-projections means they are not a linear update to an existing weight matrix. Additionally, LoRA at r=8–16 achieves comparable or better quality than adapters at similar parameter counts, while being simpler to implement and serving-friendly.

Keep the 7B base model loaded once in GPU memory (14 GB for BF16). Store all 10 LoRA adapters as weight delta files (~50–200 MB each = 500 MB–2 GB total). At inference time, apply the relevant adapter dynamically: h = W₀x + BAx. This is how vLLM LoRA serving works. Do NOT pre-merge all adapters: that produces 10 separate 14 GB models (140 GB total). The per-request overhead of dynamic LoRA application is ~5–10% vs a merged model. For the dominant adapter (>70% of traffic): optionally pre-merge and serve as a separate endpoint for minimal latency. The tradeoff: merged = zero adapter overhead but requires a separate model slot; dynamic = ~5–10% overhead but shares the base model memory across all tasks.
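The memory comparison above is simple arithmetic, sketched below. The function name and the decimal MB-to-GB conversion are illustrative choices; the sizes are the ones quoted in the text.

```python
def serving_memory_gb(base_gb, adapter_mb, n_adapters):
    """Return (dynamic, merged) weight memory in GB.

    dynamic: one shared base model plus n small adapters
    merged:  one full-size merged model per adapter
    """
    dynamic = base_gb + n_adapters * adapter_mb / 1000
    merged = n_adapters * base_gb
    return dynamic, merged

for adapter_mb in (50, 200):
    dynamic, merged = serving_memory_gb(base_gb=14, adapter_mb=adapter_mb, n_adapters=10)
    print(f"adapter={adapter_mb} MB: dynamic={dynamic:.1f} GB, merged={merged} GB")
```

This counts weights only; KV cache and activation memory come on top in both setups.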

LoRA exploits the observation that weight updates during fine-tuning have low intrinsic rank: ΔW tends to be close to low-rank even when trained without the constraint. The rank-r constraint is therefore a well-matched prior for most fine-tuning tasks rather than a lossy shortcut.
Model Compression & MLOps 11–12
11

Model Compression — Quantization, Pruning & Export

Your 7B model is accurate but takes 14 GB of VRAM and 200 ms per inference. Production needs sub-8 GB and sub-50 ms. Model compression is the bridge between research accuracy and deployment economics.

Quantization — INT8, INT4, GPTQ & AWQ
Map FP32 weights to lower-bit integers using scale+zero_point; GPTQ/AWQ for near-lossless LLM 4-bit quantization

Quantization maps floating-point weights to integers, reducing memory and enabling faster integer arithmetic on hardware.

Linear (affine) quantization:
  scale      = (W_max − W_min) / (2^bits − 1)
  zero_point = round(−W_min / scale)
  W_quant    = clamp(round(W / scale + zero_point), 0, 2^bits−1)
  W_dequant  = (W_quant − zero_point) × scale

Post-training quantization (PTQ):
  Dynamic INT8: quantize weights offline; quantize activations on the fly at runtime, no calibration → fastest to apply
  Static INT8: calibrate activation ranges on ≥128 representative samples → better latency

GPTQ (Frantar et al, 2022), 4-bit, near-lossless:
  • Quantizes one weight at a time; uses Hessian H = 2XᵀX to compensate errors in remaining weights
  • Protects weights that influence high-variance activations
  • 4-bit GPTQ on LLaMA-2 7B: ≈0.5–1 perplexity point loss vs FP16

AWQ (Lin et al, 2023), activation-aware:
  • Identifies "salient" channels (large activation magnitudes) and scales them before quantization
  • Reduces quantization error for the most important weights
  • Faster calibration than GPTQ; similar or better quality

  FP32 → INT8 quantization (asymmetric/affine, per-tensor):

  FP32 weights: [-1.8,  0.4, -0.1,  2.1, -0.9,  1.5]
  W_min = -1.8,  W_max = 2.1,  range = 3.9

  scale      = 3.9 / 255  ≈ 0.0153
  zero_point = round(1.8 / 0.0153)  = 118

  ┌──────────────────────────────────────────────────────┐
  │  FP32:     -1.8    0.4   -0.1    2.1   -0.9    1.5  │
  │  INT8:       0    144    111    255     59    216    │
  │              ↑                   ↑                   │
  │           min→0             max→255                  │
  │                                                      │
  │  Dequant: -1.805  0.398  -0.107  2.095  -0.902  1.499│
  │  Max err: ≈ 0.007 (≈0.2% of range; acceptable)       │
  └──────────────────────────────────────────────────────┘

  INT4 (16 bins): coarser grid → 3–8% error without calibration
  GPTQ compensates: after quantizing wᵢ, update remaining wⱼ using Hessian
  so cumulative error stays bounded across the full weight matrix
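The worked example in the diagram can be reproduced in a few lines. This is a plain-Python sketch of the affine formulas from this section (unsigned 0..2^bits−1 grid, per-tensor scale); the function name is illustrative and real libraries add per-channel scales and packed storage.

```python
def affine_quantize(weights, bits=8):
    """Per-tensor affine quantization to the unsigned range [0, 2^bits - 1]."""
    qmax = 2 ** bits - 1
    w_min, w_max = min(weights), max(weights)
    scale = (w_max - w_min) / qmax
    zero_point = round(-w_min / scale)
    quantized = [min(max(round(w / scale + zero_point), 0), qmax) for w in weights]
    dequantized = [(q - zero_point) * scale for q in quantized]
    return scale, zero_point, quantized, dequantized

weights = [-1.8, 0.4, -0.1, 2.1, -0.9, 1.5]

scale, zero_point, q, deq = affine_quantize(weights, bits=8)
max_err = max(abs(w - d) for w, d in zip(weights, deq))
print(f"INT8: scale={scale:.4f} zero_point={zero_point}")
print("quantized:  ", q)
print("dequantized:", [round(d, 3) for d in deq])
print(f"max abs error: {max_err:.4f}")

# Same weights on a 16-level grid, to see how much coarser INT4 is:
scale4, zp4, q4, deq4 = affine_quantize(weights, bits=4)
max_err4 = max(abs(w - d) for w, d in zip(weights, deq4))
print(f"INT4: scale={scale4:.4f} max abs error: {max_err4:.4f}")
```

The rounding error of any unclamped value is bounded by half a quantization step, which is why the INT4 error is roughly 17× the INT8 error on the same range.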
PyTorch — dynamic INT8 quantization and bitsandbytes 4-bit loading with ONNX export
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic   # torch.quantization is the older alias

# ── 1. Dynamic INT8 (CPU inference, no calibration) ───────
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True), num_layers=6
        )
        self.head = nn.Linear(512, 10)

    def forward(self, x):
        return self.head(self.layers(x).mean(1))

model = Encoder().eval()

q_model = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

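# Size comparison via the serialised state_dict (m.parameters() would skip the packed INT8 weights):
import io

def size_mb(m):
    buf = io.BytesIO()
    torch.save(m.state_dict(), buf)
    return buf.getbuffer().nbytes / 1e6

print(f"FP32: {size_mb(model):.1f} MB   INT8: {size_mb(q_model):.1f} MB")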

# ── 2. 4-bit loading with bitsandbytes (GPU) ──────────────
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16,
)
llm_4bit = AutoModelForCausalLM.from_pretrained(
    'meta-llama/Llama-2-7b-hf', quantization_config=bnb_cfg, device_map='auto'
)
# Memory: FP16 ≈ 14 GB → 4-bit NF4 ≈ 3.5 GB

# ── 3. ONNX export ────────────────────────────────────────
dummy = torch.randn(1, 32, 512)   # (batch, seq, d_model)
torch.onnx.export(
    model, dummy, 'encoder.onnx',
    input_names=['input'], output_names=['logits'],
    dynamic_axes={'input': {0: 'batch', 1: 'seq'}, 'logits': {0: 'batch'}},
    opset_version=17,
)
# Convert to TensorRT (command line):
# trtexec --onnx=encoder.onnx --fp16 --saveEngine=encoder.trt
Trap Uniform INT4 quantization across all layers

Quantizing every layer — including the embedding table and first/last linear layers — to 4-bit. These boundary layers are most sensitive to quantization error and contribute disproportionately to accuracy loss.

Fix Keep the embedding table and output head in FP16 (common GPTQ tooling leaves them unquantized by default), and consider INT8 or FP16 for the first and last linear layers. Profile per-layer quantization error on a calibration set: layers with high ‖W − W_q‖_F are the bottleneck. The top 5–10% of sensitive layers account for most accuracy degradation.
Trap ONNX export with data-dependent control flow

Exporting a model with Python-level if/for branches conditioned on tensor values — torch.onnx.export traces a single execution path and silently ignores branches not taken during the trace.

Fix Use torch.jit.script for data-dependent control flow: script captures the full Python logic as a typed IR. For ONNX: refactor the model to use torch.where, torch.clamp, and mask operations instead of Python if/else on tensor values.

PTQ: take a trained FP32 model, quantize weights offline with a calibration dataset — no retraining. Time: minutes. INT8 PTQ typically loses <1% accuracy; INT4 PTQ loses 2–5% on long-tail distributions without GPTQ/AWQ. QAT: simulate quantization during training using "fake quantize" ops (round then dequantize in FP32) so the model learns to be robust to quantization noise — gradients flow through via the straight-through estimator. Slower (full training run), but INT4 QAT can match INT8 PTQ quality. Use PTQ first — almost always sufficient for INT8, often sufficient for INT4 with GPTQ/AWQ. Only invest in QAT when PTQ accuracy is insufficient and training compute budget is available.

GPTQ builds on the Optimal Brain Quantization procedure from the Optimal Brain Compression framework: it quantizes weights one at a time within a layer, and after quantizing weight w_j (introducing error δ_j = w_j − quant(w_j)), it updates all remaining unquantized weights in that row to compensate: ΔW_remaining = −δ_j · (H⁻¹)_:,j / (H⁻¹)_j,j where H = 2XᵀX is the Hessian of the layer's reconstruction loss. GPTQ processes columns in the same fixed order for every row, which is what lets it scale to billions of parameters; with the optional activation-order heuristic, columns with a large Hessian diagonal (high influence on output) are quantized first, while the most unquantized weights remain to absorb their error. The result: quantization error is redistributed across the weight matrix rather than accumulating, achieving 4-bit with <1 perplexity point loss on LLaMA where naive INT4 loses 3–5 points.
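The compensation step can be illustrated on a single row with two weights. This is a toy sketch of the update formula above, not the GPTQ implementation: the calibration inputs X, the weights, and the 0.5 grid step are made up, and the 2×2 inverse is written out by hand.

```python
def hessian(X):
    """H = 2 X^T X for calibration inputs X given as a list of 2-element rows."""
    h = [[0.0, 0.0], [0.0, 0.0]]
    for x in X:
        for i in range(2):
            for j in range(2):
                h[i][j] += 2.0 * x[i] * x[j]
    return h

def recon_error(X, w_ref, w_new):
    """Squared output error sum_x (x . (w_ref - w_new))^2."""
    return sum(sum(xi * (a - b) for xi, a, b in zip(x, w_ref, w_new)) ** 2 for x in X)

def quantize_to_grid(w, step=0.5):
    return round(w / step) * step

X = [[1.0, 0.8], [0.5, 0.6], [-1.0, -0.7], [0.2, 0.4]]   # made-up calibration inputs
w = [0.37, -0.62]                                          # made-up weights

H = hessian(X)
det = H[0][0] * H[1][1] - H[0][1] * H[1][0]
H_inv = [[H[1][1] / det, -H[0][1] / det],
         [-H[1][0] / det, H[0][0] / det]]

w0_q = quantize_to_grid(w[0])
delta0 = w[0] - w0_q                      # quantization error of weight 0

naive = [w0_q, w[1]]                      # quantize w0, leave w1 untouched
compensated = [w0_q, w[1] - delta0 * H_inv[1][0] / H_inv[0][0]]   # apply the update to w1

err_naive = recon_error(X, w, naive)
err_comp = recon_error(X, w, compensated)
print(f"output error without compensation: {err_naive:.5f}")
print(f"output error with compensation:    {err_comp:.5f}")
```

Because the two input features are correlated, shifting the second weight absorbs most of the error introduced by rounding the first. In the full algorithm the second weight would then be quantized in turn.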

Pruning, Knowledge Distillation & Model Export
Magnitude pruning removes unimportant weights; KD trains a small student on teacher soft logits; TorchScript/ONNX for production deployment

Three compression techniques complementary to quantization:

Pruning:
  Unstructured (magnitude): zero weights with |w| < threshold
    sparsity = (zero weights) / (total weights); 50–80% is often reachable with <2% accuracy loss
    Con: sparse matrix on standard hardware ≠ faster, because BLAS/cuBLAS is dense-optimised
  Structured: remove entire filters, channels, or attention heads → reduces actual FLOPs
    30–50% channel pruning → roughly 1.5–2× inference speedup

Knowledge Distillation (KD; Hinton et al, 2015):
  Train a small student to match a large teacher's soft output distribution:
    L_KD = α·CE(z_student, y_labels) + (1−α)·T²·KL(softmax(z_teacher/T) ‖ softmax(z_student/T))
  T = temperature (typically 3–5); softens the teacher distribution
  T² scaling keeps gradient magnitude consistent regardless of T
  "Dark knowledge": soft logits reveal inter-class similarities the one-hot labels hide

TorchScript (Python-free deployment):
  torch.jit.trace(model, example): records one execution path; no branching
  torch.jit.script(model): compiles full Python logic to typed IR
  Serialised .pt file runs in C++/Java without Python runtime

ONNX (Open Neural Network Exchange):
  Framework-agnostic IR; export from PyTorch, run with ONNX Runtime or TensorRT
  Typical reported speedups: ONNX Runtime 2–3× vs PyTorch on CPU; TensorRT 3–5× on GPU with INT8

PyTorch — magnitude pruning, knowledge distillation loop, and TorchScript export
import torch, torch.nn as nn, torch.nn.utils.prune as prune
import torch.nn.functional as F

# ── 1. Magnitude pruning ──────────────────────────────────
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc    = nn.Linear(64 * 7 * 7, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        return self.fc(x.flatten(1))

model = CNN()
# Prune 40% of lowest-magnitude weights in conv layers:
prune.l1_unstructured(model.conv1, name='weight', amount=0.4)
prune.l1_unstructured(model.conv2, name='weight', amount=0.4)

sparsity = (model.conv1.weight == 0).float().mean()
print(f"Conv1 sparsity: {sparsity:.1%}")

prune.remove(model.conv1, 'weight')   # make permanent
prune.remove(model.conv2, 'weight')

# ── 2. Knowledge distillation ─────────────────────────────
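teacher = CNN(); teacher.eval()   # illustrative; in practice load trained teacher weights
student = CNN()
opt = torch.optim.Adam(student.parameters(), lr=1e-3)
T, alpha = 4.0, 0.5

# dataloader: any iterable of (images, labels) batches, assumed defined elsewhere
for x, y in dataloader: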
    with torch.no_grad():
        t_logits = teacher(x)                              # teacher soft targets

    s_logits = student(x)
    loss_ce  = F.cross_entropy(s_logits, y)               # task loss

    soft_t = F.softmax(t_logits / T, dim=-1)
    soft_s = F.log_softmax(s_logits / T, dim=-1)
    loss_kd = F.kl_div(soft_s, soft_t, reduction='batchmean') * (T ** 2)

    loss = alpha * loss_ce + (1 - alpha) * loss_kd
    loss.backward(); opt.step(); opt.zero_grad()

# ── 3. TorchScript export ─────────────────────────────────
model.eval()
scripted = torch.jit.script(model)          # full Python → typed IR
scripted.save('model_scripted.pt')

traced = torch.jit.trace(model, torch.randn(1, 1, 28, 28))
traced.save('model_traced.pt')

# Load the serialised module (from C++, use torch::jit::load on the same file):
loaded = torch.jit.load('model_scripted.pt')
Trap Knowledge distillation at T=1

Not applying temperature (T=1) — teacher output is a near-one-hot distribution with all probability mass on the top class. KD loss is equivalent to cross-entropy with teacher predictions — the "dark knowledge" about inter-class similarity is lost.

Fix Use T=3–5 for most classification tasks. Higher T spreads teacher probability across related classes, teaching the student the teacher's learned similarity structure. Tune T on a validation set: too high (T>10) makes the distribution uniform, losing all signal.
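The effect of temperature is easy to see on a single example. The sketch below is plain Python with made-up teacher and student logits; `kd_loss` implements only the T²·KL term of the distillation loss, not the cross-entropy term.

```python
import math

def softmax(logits, T=1.0):
    """Temperature-scaled softmax, computed stably by subtracting the max."""
    scaled = [l / T for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kd_loss(student_logits, teacher_logits, T):
    """T^2 * KL(softmax(teacher/T) || softmax(student/T))."""
    p_t = softmax(teacher_logits, T)
    p_s = softmax(student_logits, T)
    kl = sum(pt * (math.log(pt) - math.log(ps)) for pt, ps in zip(p_t, p_s) if pt > 0)
    return T * T * kl

teacher = [8.0, 5.0, 1.0, -2.0]   # made-up logits
student = [4.0, 3.5, 0.5, 0.0]

for T in (1.0, 4.0, 20.0):
    probs = softmax(teacher, T)
    print(f"T={T:>4}: teacher probs = {[round(p, 3) for p in probs]}, "
          f"KD term = {kd_loss(student, teacher, T):.4f}")
```

At T=1 nearly all teacher mass sits on the top class; at T=4 the second and third classes carry visible probability; at very high T the distribution flattens towards uniform.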
Trap TorchScript failing on dynamic Python

torch.jit.script errors on Python constructs it cannot compile: dicts with mixed-type values, closures, list comprehensions with dynamic types, or calls to non-scripted functions.

Fix Annotate return types and argument types explicitly. Replace Python-level branching with torch.where, torch.clamp, and masked operations. Apply torch.jit.script to sub-modules first to isolate failures. Use torch.jit.trace for models with no data-dependent control flow — it is simpler and more compatible.

Standard GPU computation uses cuBLAS dense matrix multiplication kernels that are highly optimised for contiguous memory access patterns. A "sparse" weight matrix stored in dense format (most values are zero, but the tensor shape is unchanged) still executes a full dense matmul — the zero values are multiplied and added, consuming FLOPS and memory bandwidth exactly as if they were non-zero. Latency improvement from sparsity only materialises in three cases: (1) hardware with native sparsity support — Nvidia A100 2:4 structured sparsity (exactly 2 non-zeros per 4 values) is accelerated by Tensor Cores; (2) sparse tensor format + sparse kernel (cuSPARSE) — effective only at very high sparsity (>95%) due to format overhead; (3) structured pruning — removing entire channels/heads/layers produces a smaller dense matrix, which cuBLAS handles efficiently. For practical latency reduction without special hardware, prefer structured pruning or quantization.
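The 2:4 pattern mentioned in case (1) can be sketched as a magnitude-based rule: in every group of four consecutive weights, keep the two with the largest absolute value. This is a plain-Python illustration with hypothetical helper names; it says nothing about how vendor libraries select or store the pattern.

```python
def prune_2_of_4(row):
    """Zero the two smallest-magnitude values in each group of four."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    out = []
    for start in range(0, len(row), 4):
        group = row[start:start + 4]
        keep = sorted(range(4), key=lambda k: abs(group[k]), reverse=True)[:2]
        out.extend(v if k in keep else 0.0 for k, v in enumerate(group))
    return out

def is_2_of_4(row):
    """True if every group of four has at most two non-zero values."""
    return all(sum(v != 0 for v in row[s:s + 4]) <= 2 for s in range(0, len(row), 4))

row = [0.9, -0.1, 0.05, -0.7, 0.2, 0.3, -0.4, 0.1]
pruned = prune_2_of_4(row)
print("pruned:", pruned)
print("2:4 pattern satisfied:", is_2_of_4(pruned))
```

The result is exactly 50% sparse with a regular layout, which is what allows hardware to skip the zeros; unstructured 50% sparsity gives no such guarantee per group.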

Three-stage approach: (1) Teacher: use a fine-tuned BERT-large (no further training needed) — freeze it completely during distillation. (2) Student: initialise from DistilBERT or BERT-base rather than random — a pretrained student converges much faster and achieves better final quality. (3) Combined loss: L = α·CE(student_logits, labels) + (1−α)·T²·KL(soft_teacher, soft_student) + β·MSE(student_CLS, teacher_CLS) — the third term aligns intermediate [CLS] representations, which helps more than logit-only distillation for classification. Set T=4, α=0.5, β=0.01. Use LR=1e-5 (lower than standard fine-tuning since the student is already reasonable). Expected outcome: 95–97% of BERT-large accuracy at 40% of inference latency.

Quantization and pruning change the numerical representation — not the architecture. Done carefully, they lose less than 1% accuracy while achieving 2–4× speedup and 50–75% memory reduction.
12

Training Stability & MLOps

Your model trains for 3 days then diverges at step 15,000 with loss = NaN. Or two researchers run the identical config and get different results. Training stability and MLOps are what separate a research script from a reliable production training pipeline.

Gradient Debugging & Training Stability
Monitor per-layer gradient norms; detect NaN/Inf early; gradient clipping; diagnose loss spikes

Unstable training manifests in three failure modes; each has a distinct signature and fix:

1. Exploding gradients:
  Symptom: loss spikes suddenly (5–10× in one step), then NaN
  Signature: per-layer gradient norm spikes by orders of magnitude
  Cause: LR too high, a corrupted batch, AMP FP16 overflow, insufficient clipping
  Fix: clip_grad_norm_(params, max_norm=1.0); reduce LR; switch to BF16

2. Vanishing gradients:
  Symptom: loss decreases slowly then plateaus; early layers stop learning
  Signature: gradient norms near-zero for layers 1–N/2, normal for final layers
  Cause: sigmoid/tanh in deep networks, wrong init, no residual connections, no LayerNorm
  Fix: switch to ReLU/GELU, add residual connections, LayerNorm, He/Xavier init

3. NaN in loss or activations:
  Cause A: log(0) in cross-entropy. A hand-written log(softmax(z)) underflows to log(0) = −inf, and a fully masked row (every logit at −inf) gives NaN
  Fix: pass raw logits to F.cross_entropy / log_softmax (log-sum-exp stable); mask with a large finite negative value instead of −inf. label_smoothing=0.1 is a regulariser, not a NaN guard
  Cause B: FP16 overflow: activations exceed ±65504
  Fix: switch to BF16, or check whether the GradScaler scale factor has collapsed towards 1
  Cause C: corrupted input: NaN/Inf in training data
  Fix: add torch.isnan(x).any() guard before forward pass

Gradient norm monitoring:
  total_norm = sqrt(sum(p.grad.norm()^2 for all params))
  A stable run keeps total_norm in a narrow band, for example 0.3–1.0 with ±20% variance (the absolute level is model-specific)
  Spike to 100+: investigate that batch; steady growth: reduce LR or lower max_norm
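The total-norm formula and the clipping rule can be written out without any framework. This is a minimal sketch, assuming gradients are given as a dict of flat lists per layer; the scaling factor max_norm / (total_norm + 1e-6), capped at 1, mirrors the rule used by clip_grad_norm_.

```python
import math

def global_norm(grads):
    """sqrt of the sum of squared gradient entries over all layers."""
    return math.sqrt(sum(g * g for layer in grads.values() for g in layer))

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale all gradients by one common factor so the global norm is at most max_norm.

    Returns (clipped_grads, total_norm_before_clipping).
    """
    total = global_norm(grads)
    coef = min(1.0, max_norm / (total + eps))
    clipped = {name: [g * coef for g in layer] for name, layer in grads.items()}
    return clipped, total

grads = {"layer1.weight": [3.0, 4.0], "layer2.weight": [12.0]}   # made-up gradients
clipped, total = clip_by_global_norm(grads, max_norm=1.0)

print(f"norm before clipping: {total:.3f}")
print(f"norm after clipping:  {global_norm(clipped):.3f}")
print("per-layer norms before clipping:",
      {name: round(math.sqrt(sum(g * g for g in layer)), 3) for name, layer in grads.items()})
```

Clipping rescales every layer by the same factor, so it preserves the direction of the update and the relative per-layer norms; gradients already below max_norm pass through unchanged.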

PyTorch — gradient norm logging with WandB, NaN detection, and gradient clipping
import torch, torch.nn as nn, wandb

wandb.init(project='dl-training', config={'lr': 1e-4, 'clip_norm': 1.0})

model = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=256, nhead=8, batch_first=True), num_layers=6
).cuda()
opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

for step, (x, y) in enumerate(dataloader):
    x, y = x.cuda(), y.cuda()

    # ── Guard: corrupted input ────────────────────────────
    if torch.isnan(x).any() or torch.isinf(x).any():
        print(f"Step {step}: corrupt input — skipping")
        continue

    logits = model(x)
    # F.cross_entropy applies a numerically stable log-softmax internally;
    # label_smoothing=0.1 is kept here as a regulariser:
    loss = nn.functional.cross_entropy(logits.mean(1), y, label_smoothing=0.1)

    # ── Guard: NaN loss ───────────────────────────────────
    if torch.isnan(loss) or torch.isinf(loss):
        # trace the exact op that produced NaN (debug mode only):
        with torch.autograd.detect_anomaly():
            nn.functional.cross_entropy(model(x).mean(1), y).backward()
        break

    loss.backward()

    # ── Per-layer norms (taken before clipping) for collapse detection ──
    layer_norms = {
        f'grad_norm/{name}': p.grad.norm().item()
        for name, p in model.named_parameters()
        if p.grad is not None
    }

    # ── Clip; the return value is the total norm before clipping ──
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    opt.step(); opt.zero_grad()

    wandb.log({
        'train/loss':      loss.item(),
        'train/grad_norm': grad_norm.item(),
        **layer_norms,                     # one flat scalar series per parameter
    }, step=step)

wandb.finish()
Trap detect_anomaly left on in production

torch.autograd.detect_anomaly() records a Python stack trace for every forward op and checks every backward result for NaN, which can slow training by 5–10×. Left on in a multi-day training run, it multiplies total training time by the same factor.

Fix Wrap detect_anomaly in a debug flag: if cfg.debug_anomaly: with detect_anomaly(): .... Enable for 1–2 batches around the suspect step only, then disable. Normal training: always off.
Trap Logging only scalar training loss

Training loss looks healthy (steadily decreasing) but model shows poor validation performance — early layers have vanishing gradients and only the last 1–2 layers are learning. Scalar loss does not reveal this.

Fix Log per-layer gradient norms to WandB as a grouped line plot. A healthy model has similar norm magnitude across layers (or slightly decreasing from output to input). Near-zero norms in early layers signal vanishing gradients; fix with residual connections, better init, or normalisation layers.

Step 1, check AMP: if using FP16, gradient underflow or activation overflow is common. Log scaler.get_scale(); if it keeps shrinking towards 1 or below, the scaler is seeing Inf/NaN gradients on most steps. Switch to BF16 or investigate what caused the scale reduction. Step 2, inspect inputs at the failing step: add logging of x.max(), x.min(), x.isnan().any() for the 10 steps before step 15,000. A data pipeline bug (corrupted shard, all-zero image from a failed augmentation) is a common cause. Step 3, enable torch.autograd.detect_anomaly() and re-run from the last checkpoint before step 15,000; it reports the backward operation that first produced NaN, with the stack trace of the forward call that created it. Step 4, check the loss function: a hand-written log(softmax(z)) returns −inf once a probability underflows to 0, and a row whose logits are all −inf (for example a fully masked position) gives NaN. Compute the loss from raw logits with F.cross_entropy and mask with a large finite negative value. Step 5, check gradient norm at step 14,999: if it spiked to 500+ before the NaN, the update was catastrophic. Reduce max_norm from 1.0 to 0.3, or reduce LR by 10×.

Five sources of non-determinism even with a fixed seed: (1) Incomplete seeding: you must seed torch, torch.cuda (all devices), numpy, and Python random, and set PYTHONHASHSEED. Seeding only torch misses numpy/random used in data augmentation. (2) cuDNN benchmark: cudnn.benchmark=True auto-selects the fastest algorithm, which can vary between runs. Set benchmark=False. (3) Non-deterministic CUDA kernels: floating-point atomicAdd in scatter/index-add style operations accumulates in a run-dependent order. Set torch.backends.cudnn.deterministic=True and torch.use_deterministic_algorithms(True); on CUDA 10.2+ the latter also needs the CUBLAS_WORKSPACE_CONFIG environment variable (for example :4096:8). (4) DataLoader workers: each worker process has its own RNG state, and NumPy in particular is not reseeded per worker, so augmentation randomness is not controlled by the main seed. Use worker_init_fn to seed each worker explicitly. (5) Distributed training: the reduction order of all-reduce can depend on GPU count and timing, so single-GPU and multi-GPU results will differ even with identical seeds.

WandB Sweeps, Experiment Tracking & Reproducibility
WandB for run tracking and automated hyperparameter search; seed strategy for full reproducibility; fault-tolerant checkpointing

Production ML requires experiments that are comparable, resumable, and reproducible across machines and time.

WandB Experiment Tracking:
  wandb.init(project, config) → run object with unique ID
  wandb.log({'loss': val, 'acc': val}, step=step) → dashboard, tables, media
  wandb.Artifact: version datasets and model checkpoints with full lineage

WandB Sweeps (automated hyperparameter search):
  Define search space + strategy:
    grid: exhaustive; only for ≤3 HPs with small discrete ranges
    random: better coverage for large spaces; efficiency plateaus after ~50 runs
    bayes: Gaussian Process surrogate; ~3–5× more efficient than random for ≤5 HPs
  Agent: pulls (HP config, run_id) from sweep controller; runs one experiment; repeats
  Early termination: Hyperband kills clearly underperforming runs early

Full reproducibility checklist:
  torch.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
  numpy.random.seed(seed)
  random.seed(seed)
  os.environ['PYTHONHASHSEED'] = str(seed)
  torch.backends.cudnn.deterministic = True
  torch.backends.cudnn.benchmark = False
  torch.use_deterministic_algorithms(True)
  Cost: 5–20% compute overhead from deterministic algorithms

Fault-tolerant checkpoint strategy:
  Save: best (by val_loss), latest (for resume), every N steps
  Alternating slots A/B: never overwrite the sole checkpoint mid-write
  Resume: restore model, optimizer, scheduler, epoch, step, and dataloader seed

PyTorch — WandB sweep config with Bayes search, seed_everything, and fault-tolerant checkpointing
import wandb, torch, numpy as np, random, os
import torch.nn as nn

# ── 1. Full reproducibility ───────────────────────────────
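def seed_everything(seed: int = 42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Only affects subprocesses started later; to fix hashing in this
    # interpreter, export PYTHONHASHSEED before launching Python.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Required by cuBLAS for deterministic matmuls on CUDA 10.2+:
    os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':4096:8')
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark     = False
    torch.use_deterministic_algorithms(True)   # raises on ops with no deterministic implementation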

seed_everything(42)

# ── 2. WandB sweep definition ─────────────────────────────
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'val/loss', 'goal': 'minimize'},
    'parameters': {
        'lr':           {'min': 1e-5,  'max': 1e-3, 'distribution': 'log_uniform_values'},
        'weight_decay': {'values': [0.0, 0.01, 0.1]},
        'dropout':      {'min': 0.0,   'max': 0.5},
        'batch_size':   {'values': [16, 32, 64]},
    },
    'early_terminate': {'type': 'hyperband', 'min_iter': 3},
}
sweep_id = wandb.sweep(sweep_config, project='dl-experiments')

# ── 3. Training function (called by sweep agent) ──────────
def train_sweep():
    with wandb.init() as run:
        cfg = run.config
        seed_everything(42)
        model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(),
                               nn.Dropout(cfg.dropout), nn.Linear(256, 10)).cuda()
        opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
        for epoch in range(10):
            val_loss = 0.42   # placeholder — replace with real eval
            wandb.log({'val/loss': val_loss, 'epoch': epoch})

wandb.agent(sweep_id, function=train_sweep, count=20)

# ── 4. Fault-tolerant checkpointing ──────────────────────
CKPT = './checkpoints'
os.makedirs(CKPT, exist_ok=True)

def save_ckpt(model, opt, sched, epoch, step, val_loss, is_best):
    state = {
        'epoch': epoch, 'step': step, 'val_loss': val_loss,
        'model': model.state_dict(), 'opt': opt.state_dict(),
        'sched': sched.state_dict() if sched else None,
    }
    slot = 'A' if epoch % 2 == 0 else 'B'     # alternating — never lose both
    torch.save(state, f'{CKPT}/ckpt_{slot}.pt')
    if is_best:
        torch.save(state, f'{CKPT}/best.pt')

def load_ckpt(path, model, opt, sched):
    s = torch.load(path, map_location='cpu')
    model.load_state_dict(s['model'])
    opt.load_state_dict(s['opt'])
    if sched and s['sched']:
        sched.load_state_dict(s['sched'])
    return s['epoch'], s['step'], s['val_loss']
Trap DataLoader workers not seeded

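Setting torch.manual_seed(42) but using num_workers > 0 in DataLoader: PyTorch gives each worker its own torch seed (base_seed + worker_id), but NumPy's global RNG is not reseeded. With fork-based workers, every worker inherits the same NumPy state, so NumPy-driven augmentations repeat across workers and are not tied to your main seed.

Fix Pass worker_init_fn to DataLoader: def wif(worker_id): s = torch.initial_seed() % 2**32; np.random.seed(s); random.seed(s). Inside a worker, torch.initial_seed() already differs per worker and per epoch, so augmentations stay varied but reproducible. Also pass generator=torch.Generator().manual_seed(42) to control the DataLoader's own shuffle seed.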
Trap Bayes sweeps with too many hyperparameters

Running a Bayes sweep over 8+ HPs — the Gaussian Process surrogate model becomes unreliable in high-dimensional spaces; the agent degrades to near-random search efficiency.

Fix Bayesian sweeps work best for ≤5 HPs. For more: first run a fast random sweep (20–30 runs) to identify which HPs have the most impact via WandB's built-in parameter importance chart, then do a focused Bayes sweep on the top 3–5.
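For the fast random sweep, the learning rate is usually sampled log-uniformly, which is what the log_uniform_values distribution in the sweep config expresses. The sketch below shows that sampling rule in plain Python; the helper name is illustrative and this is not WandB's implementation.

```python
import math
import random

def sample_log_uniform(lo, hi, rng):
    """Draw a value whose logarithm is uniform on [log(lo), log(hi)]."""
    return math.exp(rng.uniform(math.log(lo), math.log(hi)))

rng = random.Random(0)          # fixed seed so the sweep proposals are repeatable
n = 2000

log_samples = [sample_log_uniform(1e-5, 1e-3, rng) for _ in range(n)]
lin_samples = [rng.uniform(1e-5, 1e-3) for _ in range(n)]

share_log = sum(s < 1e-4 for s in log_samples) / n
share_lin = sum(s < 1e-4 for s in lin_samples) / n
print(f"share of samples below 1e-4, log-uniform: {share_log:.2f}")
print(f"share of samples below 1e-4, uniform:     {share_lin:.2f}")
```

Log-uniform sampling spends about half of the budget on each decade of the range, whereas plain uniform sampling would put most proposals in the top decade.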

Check in order:
(1) Incomplete seeding: confirm all five are set: torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed, and the PYTHONHASHSEED env var (which must be set before the Python process starts).
(2) torch.backends.cudnn.benchmark=True: cuDNN auto-selects convolution algorithms at runtime based on input size, and the choice can differ between runs; disable with benchmark=False.
(3) Non-deterministic CUDA ops: set torch.backends.cudnn.deterministic=True and torch.use_deterministic_algorithms(True). Ops that have no deterministic implementation then raise RuntimeError; each one must be identified and replaced or isolated.
(4) DataLoader workers unseeded: use worker_init_fn to give each worker its own reproducible seed.
(5) External RNG dependencies: augmentation libraries (Albumentations, torchvision), tokenizers, or sampling code with their own uncontrolled RNG.
(6) Binary search approach: progressively re-introduce components (data loading, augmentation, model forward, backward) until the divergence point is identified.
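Items (1) to (3) can be bundled into one helper called at the very top of the training script. This is a sketch, not a guarantee of bit-exact runs: the function name seed_everything and the CUBLAS_WORKSPACE_CONFIG value are assumptions based on PyTorch's reproducibility notes.

```python
import os
import random

import numpy as np
import torch


def seed_everything(seed: int = 42, deterministic: bool = True) -> None:
    # PYTHONHASHSEED only affects interpreters started after this point
    # (for example spawned subprocesses); to cover the current process,
    # export it in the shell before launching Python.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable

    if deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        # Some cuBLAS routines need this setting to be deterministic.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
        # Ops without a deterministic implementation now raise RuntimeError,
        # which pinpoints the offending layer.
        torch.use_deterministic_algorithms(True)
```

Calling seed_everything(42) twice and comparing a few random draws from each library is a cheap sanity check before moving on to items (4) to (6).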

Requirements: survive preemption at any moment; resume within minutes; never lose more than 30 minutes of training. Strategy:
(1) Save every 30 minutes to cloud storage (GCS/S3), not local disk, because local disk is typically lost on preemption.
(2) Save: model.state_dict(), optimizer.state_dict(), scheduler.state_dict(), epoch, global_step, best_val_loss, config dict, and the DataLoader random seed at the current step so shuffling resumes from the correct position.
(3) Alternating slots: write to ckpt_A.pt and ckpt_B.pt in turn, so that if a save is interrupted mid-write the other slot remains intact. Never overwrite both simultaneously.
(4) Best checkpoint saved separately as best.pt: never overwrite it with a periodic checkpoint, regardless of slot rotation.
(5) WandB run ID: store it in the checkpoint state and call wandb.init(id=saved_run_id, resume="allow") on restart so the loss curve appears continuous in the dashboard.
(6) On startup: check for an existing checkpoint; if found, restore all states and log "Resumed from step X" to WandB.
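The startup check in step (6) has to cope with a slot that was cut off mid-write. One possible sketch, assuming the checkpoint dict carries a 'step' key as in save_ckpt above; the function name latest_valid_checkpoint is hypothetical.

```python
import os

import torch


def latest_valid_checkpoint(ckpt_dir: str, slots=("ckpt_A.pt", "ckpt_B.pt")):
    """Return (path, state) for the readable slot with the highest step, or None."""
    best = None
    for name in slots:
        path = os.path.join(ckpt_dir, name)
        if not os.path.exists(path):
            continue
        try:
            state = torch.load(path, map_location="cpu")
        except Exception as err:
            # A truncated or corrupted file from an interrupted save lands here.
            print(f"Skipping unreadable checkpoint {path}: {err}")
            continue
        if best is None or state["step"] > best[1]["step"]:
            best = (path, state)
    return best
```

On restart, a None result means a fresh run; otherwise the returned state can be fed to the model, optimizer, and scheduler load_state_dict calls, and a stored run ID (if the checkpoint includes one) passed to wandb.init(id=..., resume="allow").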

Loss NaN is almost always gradient explosion, AMP overflow, or corrupted input data — not a model architecture bug. Check in that order. And reproducibility is not a nicety — it is the foundation of trustworthy ML experimentation.

Deep Learning Systems — Complete

From perceptrons to production.

12 stages across neural net foundations, training mechanics, regularisation, architectures, scale, and deployment.