28  Lab 12 — Sheaf Trainer

Anchor chapter: Chapter 12 — Backprop-Free Training, Batches, and Timescale Separation.

Goal. Implement the joint (cochain, parameter) gradient flow on the paraboloid \(f(x_1, x_2) = x_1^2 + x_2^2\) and compare its convergence with that of vanilla SGD.

Implement Algorithm 12.4 end-to-end on the [2, 30, 1] paraboloid task. Train with sheaf-based updates and an SGD baseline, both with the same initialisation, batch size, and learning rate. Plot training loss vs epoch for both, measure wall-clock time, and confirm the \(1/\lambda_{\min}(L_{\text{free}})\) slowdown predicted by Theorem 12.5. Extend the lab with a batch-parallelism demonstration: run \(B\) fast-phase cochain solves concurrently (e.g., via joblib) and time the speedup.

This lab uses torch for the SGD baseline and to build the [2, 30, 1] model. PyTorch is not yet available in Pyodide, so the lab cannot run in the browser; run it locally with the Plan B notebook below.

Prefer a local Jupyter environment? Download lab-12-sheaf-trainer.ipynb

Install dependencies: pip install torch numpy matplotlib joblib

28.1 Setup

import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim

rng = np.random.default_rng(0)
torch.manual_seed(0)
plt.rcParams['figure.figsize'] = (6, 4)

28.2 1. Build the object

Algorithm 12.4 separates training into two timescales. In the fast phase (cochain solve), the current weights are held fixed and, for each training sample, we find the cochain \(c^*\) that minimises \(\|\delta_\sigma c - \tilde{b}\|^2\); in the code below, the block of \(\delta_\sigma\) acting on the free vertices is delta_free and \(\tilde{b}\) is rhs. In the slow phase (parameter update), the solved cochains are used to compute the gradient of the residual energy \(R(\theta)\), and one step is taken in weight space. In Algorithm 12.4 the fast phase is a single triangular back-substitution per sample; at this size we simply use a dense least-squares solve (np.linalg.lstsq). The slow phase is a standard gradient step. We implement both and compare them with vanilla SGD on the paraboloid regression task \(f(x_1, x_2) = x_1^2 + x_2^2 - 2/3\).
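For this [2, 30, 1] network the fast-phase problem can be written out explicitly. The derivation below is a sketch under two assumptions that mirror fast_phase: the ReLU gating \(\sigma\) is frozen at its forward-pass pattern, and the output vertex is pinned to the target \(y\), so that edge e2 contributes the residual \(y - (W_2\sigma z_1 + b_2)\). The sign convention is our reading of delta_free, not quoted from Chapter 12.

```latex
\begin{aligned}
u &= W_1 x + b_1, \qquad \sigma = \operatorname{diag}\big(\mathbf{1}[u > 0]\big),\\
E(z_1) &= \tfrac12\|z_1 - u\|^2 + \tfrac12\|y - (W_2\sigma z_1 + b_2)\|^2
        = \tfrac12\|\delta_{\text{free}}\, z_1 - \tilde b\|^2,\\
\delta_{\text{free}} &= \begin{bmatrix} I_{n_1} \\ -W_2\sigma \end{bmatrix},
\qquad \tilde b = \begin{bmatrix} u \\ b_2 - y \end{bmatrix},\\
0 &= \nabla E(z_1^*) \iff L_{\text{free}}\, z_1^* = u + \sigma W_2^{\top}(y - b_2),
\qquad L_{\text{free}} = \delta_{\text{free}}^{\top}\delta_{\text{free}} = I_{n_1} + \sigma W_2^{\top} W_2 \sigma.
\end{aligned}
```

Since \(L_{\text{free}}\) is the identity plus a positive-semidefinite matrix, it is always invertible, so each per-sample fast-phase solve has a unique minimiser \(z_1^*\).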

# ── Paraboloid dataset ─────────────────────────────────────────────────────
N = 500
X_np = rng.uniform(-2, 2, (N, 2))
y_np = (X_np[:, 0] ** 2 + X_np[:, 1] ** 2 - 2/3).reshape(-1, 1)
X_t  = torch.tensor(X_np, dtype=torch.float32)
y_t  = torch.tensor(y_np, dtype=torch.float32)

# ── [2, 30, 1] ReLU network ────────────────────────────────────────────────
def make_net(seed=0):
    # Seed before constructing the layers so every call returns identical weights.
    torch.manual_seed(seed)
    return nn.Sequential(nn.Linear(2, 30), nn.ReLU(), nn.Linear(30, 1))

# ── Fast phase: cochain solve (NumPy, per sample) ──────────────────────────
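def fast_phase(x_np, y_np, W1, b1, W2, b2):
    """
    Fast phase: given input x and target y, solve the two-sided Dirichlet
    problem for the free pre-activations z1 (input pinned to x, output
    pinned to y), with the ReLU pattern frozen at the forward pass.
    Returns z1_star (n1,) and the diagonal gating matrix sig (n1, n1).
    """
    n1 = W1.shape[0]
    z1 = W1 @ x_np + b1                         # forward-pass pre-activations
    sig = np.diag((z1 > 0).astype(float))       # frozen ReLU pattern σ

    # δ_free: (n1+n2) × n1 matrix with z1-columns from edges e1 and e2
    #   edge e1 residual: z1 - (W1 x + b1)
    #   edge e2 residual: y - (W2 σ z1 + b2) = (-W2 σ) z1 - (b2 - y)
    delta_free = np.vstack([np.eye(n1), -W2 @ sig])
    rhs = np.concatenate([z1, b2 - y_np])
    z1_star, *_ = np.linalg.lstsq(delta_free, rhs, rcond=None)
    return z1_star, sig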


def sheaf_param_grad(x_np, y_np, W1, b1, W2, b2):
    """
    Slow-phase gradient w.r.t. W2 and b2 of the output-edge residual energy
    R_e2 = ½‖W2 a1 + b2 - y‖², evaluated at the solved cochain z1* (held
    fixed) with a1 = ReLU(z1*).  Returns (grad_W2, grad_b2, R_e2).
    """
    z1_star, _ = fast_phase(x_np, y_np, W1, b1, W2, b2)
    a1 = np.maximum(0.0, z1_star)   # post-ReLU activations at z1*
    z2_hat = W2 @ a1 + b2           # predicted output from z1*
    r = z2_hat - y_np               # residual on edge e2
    grad_W2 = np.outer(r, a1)       # ∂R_e2/∂W2 = r a1ᵀ
    grad_b2 = r.copy()              # ∂R_e2/∂b2 = r
    return grad_W2, grad_b2, float(0.5 * r @ r)

print("Fast phase / parameter gradient test:")
net_test = make_net()
W1_np = net_test[0].weight.detach().numpy()
b1_np = net_test[0].bias.detach().numpy()
W2_np = net_test[2].weight.detach().numpy()
b2_np = net_test[2].bias.detach().numpy()
gW2, gb2, R = sheaf_param_grad(X_np[0], y_np[0], W1_np, b1_np, W2_np, b2_np)
print(f"  R(θ; x₀, y₀) = {R:.4f}")
print(f"  grad_W2 shape: {gW2.shape},  grad_b2 shape: {gb2.shape}")
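As a sanity check on the fast phase, the sketch below builds a random [2, n1, 1] instance, solves the stacked system with np.linalg.lstsq, and compares the result with a direct solve of the normal equations \(L_{\text{free}} z_1 = \delta_{\text{free}}^{\top}\tilde b\). It assumes the edge-e2 residual is \(y - (W_2\sigma z_1 + b_2)\), so the second block of the right-hand side is \(b_2 - y\); the helper check_fast_phase is hypothetical and not part of the lab.

```python
import numpy as np


def check_fast_phase(n1=30, seed=0):
    """Solve one random fast-phase problem two ways and return diagnostics.

    Hypothetical helper (not part of the lab). Assumes the edge residuals are
    z1 - (W1 x + b1) and y - (W2 σ z1 + b2), with the ReLU pattern σ frozen
    at the forward pass, so the right-hand side is [W1 x + b1, b2 - y].
    """
    rng = np.random.default_rng(seed)
    W1, b1 = rng.normal(size=(n1, 2)), rng.normal(size=n1)
    W2, b2 = rng.normal(size=(1, n1)), rng.normal(size=1)
    x, y = rng.uniform(-2, 2, size=2), rng.normal(size=1)

    u = W1 @ x + b1                        # forward-pass pre-activations
    sig = np.diag((u > 0).astype(float))   # frozen ReLU gating σ
    delta_free = np.vstack([np.eye(n1), -W2 @ sig])
    rhs = np.concatenate([u, b2 - y])

    # (1) least squares on the stacked (n1 + 1) × n1 system
    z_lstsq, *_ = np.linalg.lstsq(delta_free, rhs, rcond=None)
    # (2) normal equations: L_free z = δ_freeᵀ rhs, with L_free = δ_freeᵀ δ_free
    L_free = delta_free.T @ delta_free
    z_normal = np.linalg.solve(L_free, delta_free.T @ rhs)

    err_fwd = abs(y - (W2 @ sig @ u + b2)).item()         # output error of the plain forward pass
    err_star = abs(y - (W2 @ sig @ z_lstsq + b2)).item()  # output error implied by z1*
    return z_lstsq, z_normal, err_fwd, err_star


z_lstsq, z_normal, err_fwd, err_star = check_fast_phase()
print("max |z_lstsq - z_normal| =", np.max(np.abs(z_lstsq - z_normal)))
print(f"output error: forward pass {err_fwd:.4f}, solved cochain {err_star:.4f}")
```

Because the fast-phase energy at \(z_1^*\) is no larger than at \(z_1 = W_1x + b_1\), the second printed error should never exceed the first, which makes this inequality a cheap regression test for the signs in rhs.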

28.3 2. Verify a theorem / run an experiment

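We train both the sheaf-based algorithm (gradient on \(R\) via the fast-phase solve) and vanilla SGD (backprop on MSE) for 200 epochs with matched learning rates, and report the training loss and wall-clock time. Both runs are full-batch: each epoch makes one update from all \(N\) samples. The comparison is not strictly like-for-like: SGD updates all four parameter tensors while this sheaf trainer updates only \(W_2\) and \(b_2\), and because the MSE gradient carries a factor of 2 that the \(\tfrac{1}{2}\)-squared residual does not, SGD takes roughly twice the step for comparable residuals at the same learning rate. Theorem 12.5 predicts a \(1/\lambda_{\min}(L_{\text{free}})\) slowdown for the sheaf-based method relative to SGD in the initial transient; we estimate \(\lambda_{\min}\) from the assembled \(L_{\text{free}}\) at the initial weights.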

n_epochs  = 200
eta_sheaf = 5e-3   # sheaf learning rate (slow phase)
eta_sgd   = 5e-3   # SGD learning rate

# ── SGD baseline ───────────────────────────────────────────────────────────
net_sgd  = make_net()
opt_sgd  = optim.SGD(net_sgd.parameters(), lr=eta_sgd)
loss_sgd = []

t0 = time.time()
for epoch in range(n_epochs):
    opt_sgd.zero_grad()
    pred = net_sgd(X_t)
    loss = nn.functional.mse_loss(pred, y_t)
    loss.backward()
    opt_sgd.step()
    loss_sgd.append(loss.item())
t_sgd = time.time() - t0

# ── Sheaf trainer (slow-phase updates on W2, b2 only) ─────────────────────
net_sh = make_net()   # same init (torch.manual_seed(0))
W1_s   = net_sh[0].weight.detach().numpy().copy()
b1_s   = net_sh[0].bias.detach().numpy().copy()
W2_s   = net_sh[2].weight.detach().numpy().copy()
b2_s   = net_sh[2].bias.detach().numpy().copy()

loss_sheaf = []

t0 = time.time()
for epoch in range(n_epochs):
    epoch_loss = 0.0
    gW2_acc = np.zeros_like(W2_s)
    gb2_acc = np.zeros_like(b2_s)
    for i in range(N):
        gW2, gb2, R_i = sheaf_param_grad(X_np[i], y_np[i], W1_s, b1_s, W2_s, b2_s)
        gW2_acc += gW2; gb2_acc += gb2
        epoch_loss += R_i
    W2_s -= eta_sheaf * gW2_acc / N
    b2_s -= eta_sheaf * gb2_acc / N
    loss_sheaf.append(epoch_loss / N)
t_sheaf = time.time() - t0

# ── L_free estimate at the initial weights ────────────────────────────────
# W2_s was updated during training, but net_sh still holds the initial W2
# (W2_s was a copy). W1 and b1 are never updated by this sheaf trainer.
W2_0 = net_sh[2].weight.detach().numpy()
n1 = W1_s.shape[0]
sample_lams = []
for i in range(50):
    z1 = W1_s @ X_np[i] + b1_s
    sig = np.diag((z1 > 0).astype(float))
    delta_free = np.vstack([np.eye(n1), -W2_0 @ sig])
    L_free = delta_free.T @ delta_free                  # n1 × n1 free Laplacian
    sample_lams.append(np.linalg.eigvalsh(L_free)[0])   # eigvalsh sorts ascending
lam_min_mean = np.mean(sample_lams)

# ── Plots ─────────────────────────────────────────────────────────────────
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].semilogy(loss_sgd,   label=f'SGD  ({t_sgd:.1f}s)')
axes[0].semilogy(loss_sheaf, label=f'Sheaf ({t_sheaf:.1f}s)', linestyle='--')
axes[0].set_xlabel('epoch'); axes[0].set_ylabel('MSE / residual energy')
axes[0].set_title('Loss vs epoch: SGD vs Sheaf Trainer')
axes[0].legend()

axes[1].hist(sample_lams, bins=15, color='steelblue', edgecolor='white')
axes[1].axvline(lam_min_mean, color='red', linestyle='--',
                label=f'mean λ_min = {lam_min_mean:.3f}')
axes[1].set_xlabel('λ_min(L_free)'); axes[1].set_ylabel('count')
axes[1].set_title('λ_min(L_free) across 50 training samples\n(governs slow-phase convergence)')
axes[1].legend()
plt.tight_layout(); plt.show()

print(f"SGD wall-clock:   {t_sgd:.2f}s,  final loss: {loss_sgd[-1]:.4f}")
print(f"Sheaf wall-clock: {t_sheaf:.2f}s,  final loss: {loss_sheaf[-1]:.4f}")
print(f"Mean λ_min(L_free): {lam_min_mean:.4f}")
print(f"Predicted slowdown factor 1/λ_min: {1/lam_min_mean:.1f}×")
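The λ_min histogram has a simple explanation for this architecture. With δ_free stacked from the identity and \(-W_2\sigma\), \(L_{\text{free}} = \delta_{\text{free}}^{\top}\delta_{\text{free}} = I + \sigma W_2^{\top} W_2 \sigma\). Whenever \(n_2 < n_1\) (here \(n_2 = 1\)), the second term has rank at most \(n_2\), so at least \(n_1 - n_2\) eigenvalues equal exactly 1 and \(\lambda_{\min}(L_{\text{free}}) = 1\) up to round-off, whatever the weights. For \(n_2 = 1\) the remaining eigenvalue is \(1 + \|W_2\sigma\|^2\). Expect the histogram to collapse to a spike at 1 and the printed factor to be about 1.0×; for Exercise 3, λ_min as assembled here does not change with the hidden width. The sketch below checks this with random weights and a random gating pattern (illustrative draws, not the lab's network); free_laplacian_spectrum is a hypothetical helper.

```python
import numpy as np


def free_laplacian_spectrum(n1, n2=1, seed=0):
    """Return (λ_min, λ_max, ‖W2 σ‖_F²) of L_free = I + σ W2ᵀ W2 σ for random weights."""
    rng = np.random.default_rng(seed)
    W2 = rng.normal(size=(n2, n1))
    sig = np.diag((rng.normal(size=n1) > 0).astype(float))  # random ReLU pattern
    delta_free = np.vstack([np.eye(n1), -W2 @ sig])
    L_free = delta_free.T @ delta_free
    lams = np.linalg.eigvalsh(L_free)          # eigenvalues in ascending order
    return lams[0], lams[-1], np.linalg.norm(W2 @ sig) ** 2


for n1 in (10, 30, 100):
    lam_min, lam_max, g = free_laplacian_spectrum(n1)
    print(f"n1={n1:3d}  λ_min={lam_min:.12f}  λ_max={lam_max:.4f}  1+‖W2σ‖²={1 + g:.4f}")
```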

28.4 Exercises

  1. Layer-wise updates. The implementation above updates only \(W_2\) and \(b_2\) (the output layer). Extend the slow phase to also update \(W_1\) and \(b_1\) by differentiating \(R\) with respect to those parameters (this requires the chain rule through the fast-phase solve). Compare the convergence of the full sheaf trainer with that of the partial one.

  2. Batch parallelism. Use concurrent.futures.ThreadPoolExecutor (or joblib.Parallel) to run the fast-phase cochain solves for all \(N\) samples concurrently. Measure wall-clock time for batch sizes \(B \in \{1, 10, 50, 200\}\) and plot the speedup.

  3. Timescale separation. Theorem 12.5 predicts a \(1/\lambda_{\min}\) slowdown. Plot the ratio of initial loss-decrease rates (SGD vs sheaf) against \(1/\lambda_{\min}\) computed at the initial weights for networks with different hidden widths \(n_1 \in \{10, 30, 100\}\).

  4. Regularisation as sheaf augmentation. Adding \(\ell_2\) regularisation \(\tfrac{\lambda}{2}\|\theta\|^2\) to the loss is equivalent to adding a “reluctance edge” from each weight vertex to a zero anchor. Implement this by modifying the fast-phase RHS and verify that the regularised sheaf trainer converges to a smaller weight norm.
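For Exercise 2, one possible timing harness is sketched below using the standard library's concurrent.futures plus NumPy, interpreting \(B\) as the number of solves in flight at once (worker threads). It is a stand-in, not the lab's solver: solve_one is a hypothetical per-sample task that solves a random least-squares system of fast-phase size, and you would swap in a call to fast_phase. Threads only help if the per-sample work releases the GIL and BLAS is not already using every core, so treat the measured speedup as an empirical result rather than a given.

```python
import time
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def solve_one(seed, n1=30):
    """Hypothetical stand-in for one fast-phase solve: a random (n1 + 1) × n1 least-squares system."""
    rng = np.random.default_rng(seed)
    A = np.vstack([np.eye(n1), rng.normal(size=(1, n1))])
    b = rng.normal(size=n1 + 1)
    z, *_ = np.linalg.lstsq(A, b, rcond=None)
    return z


def time_solves(n_samples, workers):
    """Run n_samples independent solves on `workers` threads; return (seconds, results in input order)."""
    t0 = time.perf_counter()
    if workers == 1:
        results = [solve_one(s) for s in range(n_samples)]
    else:
        with ThreadPoolExecutor(max_workers=workers) as pool:
            results = list(pool.map(solve_one, range(n_samples)))
    return time.perf_counter() - t0, results


n_samples = 500
t_serial, reference = time_solves(n_samples, workers=1)
for B in (1, 10, 50, 200):
    t_B, results = time_solves(n_samples, workers=B)
    # Every configuration must reproduce the serial answers (same seeds, same order).
    assert all(np.allclose(r, s) for r, s in zip(reference, results))
    print(f"B={B:3d}  time={t_B:.3f}s  speedup vs serial={t_serial / t_B:.2f}x")
```

joblib.Parallel with a process-based backend is the usual alternative when threads give no speedup, at the cost of pickling the weights for each worker.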