28 Lab 12 — Sheaf Trainer
Anchor chapter: Chapter 12 — Backprop-Free Training, Batches, and Timescale Separation.
Goal. Implement the joint (cochain, parameter) gradient flow on the paraboloid \(f(x,y) = x^2 + y^2\) and compare convergence with vanilla SGD.
Implement Algorithm 12.4 end-to-end on the [2, 30, 1] paraboloid task. Train with sheaf-based updates and an SGD baseline, both with the same initialisation, batch size, and learning rate. Plot training loss vs epoch for both, measure wall-clock time, and confirm the \(1/\lambda_{\min}^{\text{free}}\) slowdown predicted by Thm. 12.5. Extend the lab with a batch-parallelism demonstration: run \(B\) fast-phase cochain solves concurrently (e.g., via joblib) and time the speedup.
This lab uses torch for the SGD baseline and for loading the [2, 30, 1] model. PyTorch is not yet available in Pyodide. Run locally with the Plan B notebook below.
Prefer a local Jupyter environment? Download lab-12-sheaf-trainer.ipynb
Install dependencies: pip install torch numpy matplotlib joblib
28.1 Setup
import time
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
rng = np.random.default_rng(0)
torch.manual_seed(0)
plt.rcParams['figure.figsize'] = (6, 4)
28.2 1. Build the object
Algorithm 12.4 separates training into two timescales. In the fast phase (cochain solve), the weights are held fixed and, for each training sample, we find the cochain \(c^*\) that minimises \(\|\delta_\sigma c - \tilde{b}\|^2\). In the slow phase (parameter update), the solved cochains are used to compute the gradient of the residual energy \(R(\theta)\), and we take one step in weight space. The fast phase is a single triangular back-substitution per sample; the slow phase is a standard gradient step. The code below solves the fast phase with a generic least-squares call (np.linalg.lstsq) rather than a hand-written back-substitution, and its slow phase updates only the output layer \((W_2, b_2)\). We implement both phases and compare to vanilla SGD on the paraboloid regression task \(f(x_1, x_2) = x_1^2 + x_2^2 - 2/3\).
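As a small sanity check on the fast phase described above, the sketch below compares a least-squares solve with the equivalent normal-equations solve \(L_{\text{free}}\, z = \delta_{\text{free}}^\top \tilde{b}\), where \(L_{\text{free}} = \delta_{\text{free}}^\top \delta_{\text{free}}\). The sizes, random weights, ReLU mask, and right-hand side are all hypothetical placeholders chosen for illustration; only the block layout of `delta_free` follows the lab code.

```python
import numpy as np

rng = np.random.default_rng(0)
n1, n2 = 5, 1                                   # hypothetical small layer sizes
W2 = rng.normal(size=(n2, n1))                  # placeholder output weights
sig = np.diag((rng.normal(size=n1) > 0).astype(float))  # placeholder ReLU mask

# Same block layout as the lab: identity block for edge e1, -W2 Σ block for edge e2.
delta_free = np.vstack([np.eye(n1), -W2 @ sig])
b_tilde = rng.normal(size=n1 + n2)              # arbitrary right-hand side

# Route 1: generic least squares.
z_lstsq, *_ = np.linalg.lstsq(delta_free, b_tilde, rcond=None)

# Route 2: normal equations with the free Laplacian (symmetric positive definite here).
L_free = delta_free.T @ delta_free
z_normal = np.linalg.solve(L_free, delta_free.T @ b_tilde)

gap = float(np.max(np.abs(z_lstsq - z_normal)))
print(f"max |z_lstsq - z_normal| = {gap:.2e}")
```

Because the identity block gives `delta_free` full column rank, both routes should agree up to rounding; which one is faster for the lab's sizes is something to measure rather than assume.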
# ── Paraboloid dataset ─────────────────────────────────────────────────────
N = 500
X_np = rng.uniform(-2, 2, (N, 2))
y_np = (X_np[:, 0] ** 2 + X_np[:, 1] ** 2 - 2/3).reshape(-1, 1)
X_t = torch.tensor(X_np, dtype=torch.float32)
y_t = torch.tensor(y_np, dtype=torch.float32)
# ── [2, 30, 1] ReLU network ────────────────────────────────────────────────
def make_net():
    # Seed *before* constructing the layers so every call returns identical weights.
    torch.manual_seed(0)
    net = nn.Sequential(nn.Linear(2, 30), nn.ReLU(), nn.Linear(30, 1))
    return net
# ── Fast phase: cochain solve (NumPy, per sample) ──────────────────────────
def fast_phase(x_np, y_np, W1, b1, W2, b2):
    """
    Given input x and target y, solve the two-sided Dirichlet problem
    for the free pre-activations z1 by linear least squares.
    Returns (z1_star, sig): z1_star has shape (n1,), and sig is the
    (n1, n1) diagonal ReLU mask evaluated at the forward pre-activation.
    """
    n1 = W1.shape[0]
    z1 = W1 @ x_np + b1
    sig = np.diag((z1 > 0).astype(float))
    # δ_free: (n1+n2) × n1 matrix with z1-columns from edges e1 and e2
    delta_free = np.vstack([np.eye(n1), -W2 @ sig])
    # Edge e1 residual: z1 - (W1 x + b1). Edge e2 residual: (y - b2) - W2 Σ z1,
    # so the -W2 Σ block pairs with the right-hand side b2 - y.
    rhs = np.concatenate([z1, b2 - y_np])
    z1_star, _, _, _ = np.linalg.lstsq(delta_free, rhs, rcond=None)
    return z1_star, sig
def sheaf_param_grad(x_np, y_np, W1, b1, W2, b2):
    """
    Gradient of the residual energy w.r.t. W2 and b2, evaluated at the
    solved cochain z1*. Returns (grad_W2, grad_b2, R), where the reported
    energy R = ½‖r‖² uses the edge-e2 residual r at z1*.
    """
    z1_star, sig = fast_phase(x_np, y_np, W1, b1, W2, b2)
    a1 = np.maximum(0.0, z1_star)  # post-ReLU activations at z1*
    # Predicted output from z1*: z2_hat = W2 a1 + b2
    z2_hat = W2 @ a1 + b2
    # Residual on edge e2: z2_hat - y_target
    r = z2_hat - y_np
    # Gradient (a1 held fixed): ∂R/∂W2 = r aᵀ, ∂R/∂b2 = r
    grad_W2 = np.outer(r, a1)
    grad_b2 = r.copy()
    return grad_W2, grad_b2, float(0.5 * r @ r)
print("Fast phase / parameter gradient test:")
net_test = make_net()
W1_np = net_test[0].weight.detach().numpy()
b1_np = net_test[0].bias.detach().numpy()
W2_np = net_test[2].weight.detach().numpy()
b2_np = net_test[2].bias.detach().numpy()
gW2, gb2, R = sheaf_param_grad(X_np[0], y_np[0], W1_np, b1_np, W2_np, b2_np)
print(f" R(θ; x₀, y₀) = {R:.4f}")
print(f" grad_W2 shape: {gW2.shape}, grad_b2 shape: {gb2.shape}")
28.3 2. Verify a theorem / run an experiment
We train both the sheaf-based algorithm (gradient on \(R\) via the fast-phase solve) and vanilla SGD (backprop on MSE) for 200 epochs with matched learning rates and report the training loss and wall-clock time. Theorem 12.5 predicts a \(1/\lambda_{\min}(L_{\text{free}})\) slowdown for the sheaf-based method relative to SGD in the initial transient; we estimate \(\lambda_{\min}\) from the assembled \(L_{\text{free}}\) at the initial weights.
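One property of \(L_{\text{free}}\) may be worth checking before reading the histogram. If \(L_{\text{free}}\) is assembled as \(\delta_{\text{free}}^\top \delta_{\text{free}}\) with \(\delta_{\text{free}} = [I;\, -W_2 \Sigma]\), as in this lab's code, then \(L_{\text{free}} = I + (W_2\Sigma)^\top (W_2\Sigma)\). With a single output unit this is a rank-one update of the identity, so its eigenvalues are 1 (with multiplicity \(n_1 - 1\)) and \(1 + \|W_2\Sigma\|^2\). The sketch below verifies this numerically; the random weights and mask are hypothetical stand-ins, not the lab's network.

```python
import numpy as np

rng = np.random.default_rng(1)
n1, n2 = 30, 1                                   # sizes matching a [2, 30, 1] network
W2 = rng.normal(size=(n2, n1)) / np.sqrt(n1)     # placeholder output weights
mask = (rng.normal(size=n1) > 0).astype(float)   # placeholder ReLU activation pattern

A = W2 * mask                                    # same as W2 @ np.diag(mask)
delta_free = np.vstack([np.eye(n1), -A])
L_free = delta_free.T @ delta_free               # = I + Aᵀ A

lam = np.linalg.eigvalsh(L_free)                 # eigenvalues in ascending order
lam_min, lam_max = float(lam[0]), float(lam[-1])
expected_max = 1.0 + float(np.sum(A ** 2))       # 1 + ‖W2 Σ‖²

print(f"λ_min = {lam_min:.6f}, λ_max = {lam_max:.6f}, 1 + ‖W2Σ‖² = {expected_max:.6f}")
```

Under this construction \(\lambda_{\min}\) does not depend on the sample or the weights whenever \(n_1 > n_2\), which is useful context when interpreting the estimated slowdown factor.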
n_epochs = 200
eta_sheaf = 5e-3 # sheaf learning rate (slow phase)
eta_sgd = 5e-3 # SGD learning rate
# ── SGD baseline ───────────────────────────────────────────────────────────
net_sgd = make_net()
opt_sgd = optim.SGD(net_sgd.parameters(), lr=eta_sgd)
loss_sgd = []
t0 = time.time()
for epoch in range(n_epochs):
    opt_sgd.zero_grad()
    pred = net_sgd(X_t)
    loss = nn.functional.mse_loss(pred, y_t)
    loss.backward()
    opt_sgd.step()
    loss_sgd.append(loss.item())
t_sgd = time.time() - t0
# ── Sheaf trainer (slow-phase updates on W2, b2 only) ─────────────────────
net_sh = make_net() # same init (torch.manual_seed(0))
W1_s = net_sh[0].weight.detach().numpy().copy()
b1_s = net_sh[0].bias.detach().numpy().copy()
W2_s = net_sh[2].weight.detach().numpy().copy()
b2_s = net_sh[2].bias.detach().numpy().copy()
loss_sheaf = []
t0 = time.time()
for epoch in range(n_epochs):
    epoch_loss = 0.0
    gW2_acc = np.zeros_like(W2_s)
    gb2_acc = np.zeros_like(b2_s)
    for i in range(N):
        gW2, gb2, R_i = sheaf_param_grad(X_np[i], y_np[i], W1_s, b1_s, W2_s, b2_s)
        gW2_acc += gW2
        gb2_acc += gb2
        epoch_loss += R_i
    W2_s -= eta_sheaf * gW2_acc / N
    b2_s -= eta_sheaf * gb2_acc / N
    loss_sheaf.append(epoch_loss / N)
t_sheaf = time.time() - t0
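# ── L_free estimate at initial weights ────────────────────────────────────
# W2_s was updated in place during training, so reload the initial weights.
net_init = make_net()
W1_0 = net_init[0].weight.detach().numpy()
b1_0 = net_init[0].bias.detach().numpy()
W2_0 = net_init[2].weight.detach().numpy()
n1 = W1_0.shape[0]
sample_lams = []
for i in range(50):
    z1 = W1_0 @ X_np[i] + b1_0
    sig = np.diag((z1 > 0).astype(float))
    delta_free = np.vstack([np.eye(n1), -W2_0 @ sig])
    L_free = delta_free.T @ delta_free
    lam_min = np.linalg.eigvalsh(L_free)[0]  # eigvalsh returns ascending eigenvalues
    sample_lams.append(lam_min)
lam_min_mean = np.mean(sample_lams)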
# ── Plots ─────────────────────────────────────────────────────────────────
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
axes[0].semilogy(loss_sgd, label=f'SGD ({t_sgd:.1f}s)')
axes[0].semilogy(loss_sheaf, label=f'Sheaf ({t_sheaf:.1f}s)', linestyle='--')
axes[0].set_xlabel('epoch'); axes[0].set_ylabel('MSE / residual energy')
axes[0].set_title('Loss vs epoch: SGD vs Sheaf Trainer')
axes[0].legend()
axes[1].hist(sample_lams, bins=15, color='steelblue', edgecolor='white')
axes[1].axvline(lam_min_mean, color='red', linestyle='--',
                label=f'mean λ_min = {lam_min_mean:.3f}')
axes[1].set_xlabel('λ_min(L_free)'); axes[1].set_ylabel('count')
axes[1].set_title('λ_min(L_free) across 50 training samples\n(governs slow-phase convergence)')
axes[1].legend()
plt.tight_layout(); plt.show()
print(f"SGD wall-clock: {t_sgd:.2f}s, final loss: {loss_sgd[-1]:.4f}")
print(f"Sheaf wall-clock: {t_sheaf:.2f}s, final loss: {loss_sheaf[-1]:.4f}")
print(f"Mean λ_min(L_free): {lam_min_mean:.4f}")
print(f"Predicted slowdown factor 1/λ_min: {1/lam_min_mean:.1f}×")
28.4 Exercises
Layer-wise updates. The implementation above only updates \(W_2\) and \(b_2\) (the output layer). Extend the slow phase to also update \(W_1\) and \(b_1\) by differentiating \(R\) with respect to those parameters (requires chain rule through the fast-phase solve). Compare convergence of the full sheaf trainer to the partial one.
Batch parallelism. Use concurrent.futures.ThreadPoolExecutor (or joblib.Parallel) to run the fast-phase cochain solves for all \(N\) samples concurrently. Measure wall-clock time for batch sizes \(B \in \{1, 10, 50, 200\}\) and plot the speedup.
Timescale separation. Theorem 12.5 predicts a \(1/\lambda_{\min}\) slowdown. Plot the ratio of initial loss-decrease rates (SGD vs sheaf) against \(1/\lambda_{\min}\) computed at the initial weights for networks with different hidden widths \(n_1 \in \{10, 30, 100\}\).
Regularisation as sheaf augmentation. Adding \(\ell_2\) regularisation \(\tfrac{\lambda}{2}\|\theta\|^2\) to the loss is equivalent to adding a “reluctance edge” from each weight vertex to a zero anchor. Implement this by modifying the fast-phase RHS and verify that the regularised sheaf trainer converges to a smaller weight norm.
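For the batch-parallelism exercise above, one possible starting skeleton is sketched here. It times a serial loop against a thread pool and checks that both give the same solutions. The function `solve_one`, the random weights, the zero right-hand side on the output edge, and the choice of four workers are all hypothetical placeholders; swap in the lab's `fast_phase` and the real weights when adapting it.

```python
import time
from concurrent.futures import ThreadPoolExecutor

import numpy as np

rng = np.random.default_rng(0)
n1, n2, B = 30, 1, 200                           # hypothetical sizes; B is the batch size
W1 = rng.normal(size=(n1, 2))
b1 = np.zeros(n1)
W2 = rng.normal(size=(n2, n1))
X = rng.uniform(-2, 2, (B, 2))

def solve_one(x):
    """Stand-in for a per-sample fast-phase solve (placeholder right-hand side)."""
    z1 = W1 @ x + b1
    sig = np.diag((z1 > 0).astype(float))
    delta_free = np.vstack([np.eye(n1), -W2 @ sig])
    rhs = np.concatenate([z1, np.zeros(n2)])     # placeholder for the edge-e2 entry
    z_star, *_ = np.linalg.lstsq(delta_free, rhs, rcond=None)
    return z_star

# Serial reference run.
t0 = time.perf_counter()
serial = [solve_one(x) for x in X]
t_serial = time.perf_counter() - t0

# Thread-pool run; pool.map preserves the input order.
t0 = time.perf_counter()
with ThreadPoolExecutor(max_workers=4) as pool:
    threaded = list(pool.map(solve_one, X))
t_threaded = time.perf_counter() - t0

same = all(np.allclose(a, b) for a, b in zip(serial, threaded))
print(f"serial {t_serial:.3f}s, threaded {t_threaded:.3f}s, identical results: {same}")
```

For solves this small, scheduling overhead may outweigh any gain from threads, so the measured speedup could be below 1; a process-based backend such as joblib.Parallel is a natural alternative to compare against.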