30 Lab 14 — Residual-Block Sheaf
Anchor chapter: Chapter 14 — Frontiers: Deeper Architectures, Scaling, Open Questions.
Goal. Construct the cellular sheaf for a toy ResNet block (path graph plus one chord), and identify the unitriangular failure that prevents the harmonic-extension interpretation.
Build the sheaf of a single-residual-block network on the chord-augmented path graph. Compute \(\delta_\Omega\) (tall, \(4n \times 2n\), i.e. \(12 \times 6\) for \(n = 3\)), verify Prop. 14.8 numerically for random weights, and solve the pinned Dirichlet problem by least squares (since back-substitution fails). Compare the result to forward(x) of the equivalent PyTorch module. Bonus: stack two residual blocks and investigate whether the full-rank property persists.
This lab imports torch to construct the reference ResNet forward pass and compare against the sheaf solution. PyTorch is not yet available in Pyodide. Run locally with the Plan B notebook below.
Prefer a local Jupyter environment? Download lab-14-residual-block-sheaf.ipynb
Install dependencies: pip install torch numpy matplotlib networkx
30.1 Setup
```python
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import torch
import torch.nn as nn

rng = np.random.default_rng(0)
torch.manual_seed(0)
plt.rcParams['figure.figsize'] = (6, 4)
```

30.2 1. Build the object
A residual block computes \(\text{out} = \text{ReLU}(W_2\, \text{ReLU}(W_1 x + b_1) + b_2) + x\), which requires the input and output dimensions to match. In the path-graph picture this adds a chord edge from the input vertex (vertex 0) directly to the output vertex (vertex 3). The chord carries the skip connection: its tail restriction map is the identity \(I\) (copying \(x\) as-is), and its head restriction map is also \(I\). The full graph therefore has 4 vertices and 4 edges (3 path edges plus 1 chord). With both input and output pinned, the free vertices are \(z^{(1)}\) and \(z^{(2)}\) (the two hidden layers, each in \(\mathbb{R}^n\)). So \(\delta_\Omega\) has \(4n\) rows (one \(\mathbb{R}^n\) stalk per edge) and \(2n\) columns, i.e. it is \(12 \times 6\) for \(n = 3\). Because it is no longer square, the unitriangular structure and the \(\det = 1\) argument used for plain MLPs break down.
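For reference, here is the block layout that `build_resnet_delta_omega` assembles below. The block rows are the edges e1 to e4 and the block columns are \(z^{(1)}, z^{(2)}\):

\[
\delta_\Omega =
\begin{pmatrix}
I & 0 \\
-\Sigma^{(1)} & I \\
0 & -W_2 \\
0 & 0
\end{pmatrix}
\in \mathbb{R}^{4n \times 2n}.
\]

The first two block rows form a unit lower-triangular \(2n \times 2n\) block. This is why \(\delta_\Omega\) has full column rank for any weights and any activation pattern. The chord rows are identically zero because both endpoints of the chord are pinned.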
```python
# ── Architecture: single ResNet block [n, n, n, n] ─────────────────────────
# All dimensions are equal so the identity skip connection type-checks;
# n = 3 keeps the matrices small enough to read.
n = 3

class ResBlock(nn.Module):
    """out = ReLU(fc2(ReLU(fc1(x)))) + x, matching the formula in the text."""
    def __init__(self, n):
        super().__init__()
        self.fc1 = nn.Linear(n, n)
        self.fc2 = nn.Linear(n, n)

    def forward(self, x):
        out = torch.relu(self.fc1(x))
        out = self.fc2(out)
        return torch.relu(out) + x  # outer ReLU, then skip connection

torch.manual_seed(42)
block = ResBlock(n)
W1 = block.fc1.weight.detach().numpy()  # (n, n)
b1 = block.fc1.bias.detach().numpy()
W2 = block.fc2.weight.detach().numpy()
b2 = block.fc2.bias.detach().numpy()

# ── Vertices and edges ────────────────────────────────────────────────────
# Vertex 0: input  z^(0) ∈ ℝⁿ (pinned to x)
# Vertex 1: hidden z^(1) ∈ ℝⁿ (free; pre-activation W1 x + b1)
# Vertex 2: hidden z^(2) ∈ ℝⁿ (free; post-ReLU Σ1 z^(1), the input to fc2)
# Vertex 3: output z^(3) ∈ ℝⁿ (pinned to the block output)
#
# Edges:
#   e1 = (0,1): fc1 (affine)
#   e2 = (1,2): ReLU selection Σ1
#   e3 = (2,3): fc2 (affine)
#   e4 = (0,3): skip (chord); identity on both sides, edge stalk ℝⁿ
```
```python
def build_resnet_delta_omega(x_in, W1, b1, W2, b2):
    """
    Build δ_Ω for the ResNet sheaf with z^(0) = x_in and z^(3) pinned.
    Free cochain: (z1, z2), the two hidden vertices, each ℝⁿ.
    Convention: (δc)_e = head term - tail term; rhs holds the pinned
    contributions with the opposite sign, so δ_Ω c - rhs equals δc.
    Returns delta_Omega of shape (4n, 2n), rhs of shape (4n,), and the
    block output z3 used as the pinned value at vertex 3.
    """
    n = W1.shape[0]
    N_free = 2 * n   # z1 and z2
    N_edges = 4 * n  # 4 edges, each stalk ℝⁿ

    # NumPy forward pass: activation pattern and pinned output
    h1 = W1 @ x_in + b1                         # pre-activation, value at vertex 1
    a1 = np.maximum(0.0, h1)                    # post-ReLU, value at vertex 2
    sig1 = np.diag((h1 > 0).astype(float))      # Σ^(1)
    z3 = np.maximum(0.0, W2 @ a1 + b2) + x_in   # ReLU(W2 a1 + b2) + x

    # δ_Ω: rows = (e1, e2, e3, e4), each n; cols = (z1, z2), each n
    delta = np.zeros((N_edges, N_free))
    # e1 (0→1): (δc)_e1 = z1 - W1 z0 - b1; z0 pinned, so only z1 has a column
    delta[0:n, 0:n] = np.eye(n)
    # e2 (1→2): (δc)_e2 = z2 - Σ1 z1 (ReLU edge: restriction from z1 is Σ1)
    delta[n:2*n, 0:n] = -sig1
    delta[n:2*n, n:2*n] = np.eye(n)
    # e3 (2→3): (δc)_e3 = z3 - W2 z2 - b2; z3 pinned, so only z2 has a column
    delta[2*n:3*n, n:2*n] = -W2
    # e4 (0→3), chord: (δc)_e4 = z3 - z0; both ends pinned, so no free columns

    # Boundary RHS = minus the pinned terms of each edge
    rhs = np.concatenate([
        W1 @ x_in + b1,  # e1
        np.zeros(n),     # e2: both endpoints free, no boundary term
        b2 - z3,         # e3
        x_in - z3,       # e4
    ])
    return delta, rhs, z3


x_test = rng.standard_normal(n)
delta, rhs, z3_fwd = build_resnet_delta_omega(x_test, W1, b1, W2, b2)
print(f"δ_Ω shape: {delta.shape} ({4*n} rows × {2*n} cols: tall, not square)")
print(f"rank(δ_Ω) = {np.linalg.matrix_rank(delta)} (full column rank = {2*n}?)")
# det is undefined for a non-square matrix; check singular values instead
svd = np.linalg.svd(delta, compute_uv=False)
print(f"Singular values: {np.round(svd, 3)}")
print(f"Smallest singular value: {svd[-1]:.4f} (> 0 → full column rank)")
```

30.3 2. Verify a theorem / run an experiment
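Since \(\delta_\Omega\) is tall (overdetermined), triangular back-substitution no longer applies. Instead we use the least-squares solution \(\hat{c} = (\delta_\Omega^\top \delta_\Omega)^{-1} \delta_\Omega^\top \tilde{b}\), which is well defined because \(\delta_\Omega\) has full column rank. If the network's forward pass is consistent with the sheaf (Prop. 14.8), the residual \(\|\delta_\Omega \hat{c} - \tilde{b}\|\) should be zero and \(\hat{c}\) should reproduce the hidden activations. Note, however, that the chord rows have no free columns. Their block of the residual vector has norm \(\|z^{(3)} - z^{(0)}\|\) whatever \(\hat{c}\) is, so the residual can vanish only when the pinned output equals the pinned input. We check both quantities across 50 random inputs, plotting the gap between the least-squares hidden state and PyTorch's alongside the residual.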
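Rows with no free columns behave the same way in any least-squares problem. The toy sketch below uses random matrices unrelated to the lab's weights. It appends three all-zero rows to a random full-column-rank matrix, checks that the normal-equation formula agrees with `np.linalg.lstsq`, and shows that the residual cannot fall below the norm of the right-hand side on those rows.

```python
import numpy as np

rng = np.random.default_rng(7)
m, k, z = 6, 4, 3   # 6 ordinary rows, 4 unknowns, 3 rows with no free columns
A = np.vstack([rng.standard_normal((m, k)), np.zeros((z, k))])
b = rng.standard_normal(m + z)

# Normal-equation solution (valid because A has full column rank)
c_normal = np.linalg.solve(A.T @ A, A.T @ b)
# Library least-squares solution
c_lstsq, *_ = np.linalg.lstsq(A, b, rcond=None)

residual = np.linalg.norm(A @ c_lstsq - b)
floor = np.linalg.norm(b[m:])   # contribution of the all-zero rows, fixed for any c

print(np.allclose(c_normal, c_lstsq))
print(f"residual = {residual:.3f} >= zero-row floor = {floor:.3f}")
```

The squared residual splits as \(\|A_{\text{top}} c - b_{\text{top}}\|^2 + \|b_{\text{zero rows}}\|^2\). Least squares can only shrink the first term.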
```python
# ── Solve the overdetermined system by least squares ───────────────────────
c_hat, *_ = np.linalg.lstsq(delta, rhs, rcond=None)
z1_hat, z2_hat = c_hat[:n], c_hat[n:]  # vertex 1 (pre-activation), vertex 2 (post-ReLU)

# Compare to the PyTorch forward pass
with torch.no_grad():
    x_t = torch.tensor(x_test, dtype=torch.float32)
    z3_pt = block(x_t).numpy()
    h1_pt = block.fc1(x_t).numpy()               # pre-activation of hidden layer 1
    a1_pt = torch.relu(block.fc1(x_t)).numpy()   # post-ReLU hidden layer 1

print("Sheaf (least-squares) solution vs PyTorch:")
print(f"  z1 (pre-activation): sheaf={np.round(z1_hat, 4)}, torch={np.round(h1_pt, 4)}")
print(f"  z2 (post-ReLU):      sheaf={np.round(z2_hat, 4)}, torch={np.round(a1_pt, 4)}")
print(f"  z3 pinned (NumPy forward): {np.round(z3_fwd, 4)}")
print(f"  z3 PyTorch:                {np.round(z3_pt, 4)}")
print(f"  Residual ‖δ_Ω ĉ - rhs‖: {np.linalg.norm(delta @ c_hat - rhs):.2e}")

# ── Sweep 50 random inputs ─────────────────────────────────────────────────
out_errors, hid_errors, resids = [], [], []
for _ in range(50):
    xr = rng.standard_normal(n)
    dr, rr, z3r = build_resnet_delta_omega(xr, W1, b1, W2, b2)
    cr, *_ = np.linalg.lstsq(dr, rr, rcond=None)
    with torch.no_grad():
        xr_t = torch.tensor(xr, dtype=torch.float32)
        z3_pt_r = block(xr_t).numpy()
        a1_pt_r = torch.relu(block.fc1(xr_t)).numpy()
    out_errors.append(np.abs(z3r - z3_pt_r).max())     # pinned output vs PyTorch output
    hid_errors.append(np.abs(cr[n:] - a1_pt_r).max())  # least-squares z2 vs PyTorch hidden
    resids.append(np.linalg.norm(dr @ cr - rr))

fig, axes = plt.subplots(1, 2, figsize=(11, 4))
axes[0].semilogy(hid_errors, 'o', ms=4, color='steelblue')
axes[0].set_xlabel('input index'); axes[0].set_ylabel('max |z2_sheaf - z2_torch|')
axes[0].set_title('Least-squares hidden state vs PyTorch (50 inputs)')
axes[0].axhline(1e-5, color='gray', linestyle='--', alpha=0.5)
axes[1].semilogy(resids, 's', ms=4, color='coral')
axes[1].set_xlabel('input index'); axes[1].set_ylabel('‖δ_Ω ĉ - rhs‖')
axes[1].set_title('Least-squares residual')
plt.tight_layout(); plt.show()

print(f"\nMax output error (NumPy reference vs PyTorch): {max(out_errors):.2e}")
print(f"Max hidden-state error (least squares vs PyTorch): {max(hid_errors):.2e}")
print(f"Max LS residual: {max(resids):.2e}")
print(f"\nKey finding: δ_Ω is {4*n}×{2*n}: tall, not square.")
print("The square unitriangular structure (and det = 1) of plain MLPs is lost,")
print("so back-substitution no longer applies. Compare the residual and the")
print("hidden-state error above to judge how well least squares recovers the forward pass.")
```
```python
# ── Visualise the chord-augmented path graph ───────────────────────────────
G = nx.DiGraph()
G.add_edges_from([(0, 1), (1, 2), (2, 3), (0, 3)])
pos = {0: (0, 0), 1: (1, 0.3), 2: (2, 0.3), 3: (3, 0)}
# Colour/width lists must follow G.edges order, which for this DiGraph is
# (0,1), (0,3), (1,2), (2,3); derive them from the edges instead of hard-coding.
edge_colors = ['coral' if e == (0, 3) else 'gray' for e in G.edges]
edge_widths = [3 if e == (0, 3) else 2 for e in G.edges]
fig2, ax2 = plt.subplots(figsize=(7, 3))
nx.draw(G, pos=pos, ax=ax2, with_labels=True, node_color='steelblue',
        node_size=700, font_color='white', edge_color=edge_colors,
        width=edge_widths, arrows=True, arrowsize=20, connectionstyle='arc3,rad=0.2')
ax2.text(0.5, 0.1, 'fc1', ha='center', fontsize=9, color='gray')
ax2.text(1.5, 0.1, 'ReLU', ha='center', fontsize=9, color='gray')
ax2.text(2.5, 0.1, 'fc2', ha='center', fontsize=9, color='gray')
ax2.text(1.5, -0.35, 'skip (chord)', ha='center', fontsize=9, color='coral')
ax2.set_title('ResNet block: path graph + chord edge (coral)')
ax2.axis('off')
plt.tight_layout(); plt.show()
```

30.4 Exercises
Two-block stack. Stack two ResNet blocks and build the sheaf on the corresponding graph (7 vertices and 8 edges: two path triples plus two chords, with the blocks sharing their junction vertex). Count the free dimensions and the edge dimensions. Is \(\delta_\Omega\) still full column rank for generic weights?
Dimension mismatch. Change the ResNet architecture so the hidden dimension differs from the input dimension (e.g., input \(\mathbb{R}^2\), hidden \(\mathbb{R}^4\)). The skip connection now requires a projection matrix \(P : \mathbb{R}^2 \to \mathbb{R}^4\). How does this change the chord restriction map and the shape of \(\delta_\Omega\)?
Positive definiteness. For plain MLPs, \(L_{\text{free}} = \delta_\Omega^\top \delta_\Omega\) is always positive definite (Lemma 3.2). For the ResNet sheaf, compute \(\lambda_{\min}(\delta_\Omega^\top \delta_\Omega)\) for 50 random weight matrices. Is it always positive? Does it depend on the weight scale?
Open problem. Section 7 of the paper lists as an open problem: proving convergence of the sheaf heat equation for networks with skip connections. Simulate the state-dependent heat equation \(\dot{c} = -L_{\text{free}}(\sigma(c))\, c\) for the ResNet sheaf starting from a random cochain (with input pinned but output free). Does the Dirichlet energy decrease monotonically? Document any counterexamples or numerical failures.
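For the two-block stack exercise, the following hypothetical helper (`stacked_delta_omega` is not part of the lab) extends the edge conventions of `build_resnet_delta_omega` to k stacked blocks. It assumes three things: consecutive blocks share their junction vertex, only the first input and the last output are pinned, and each ReLU mask Σ is a random 0/1 diagonal standing in for the pattern a real forward pass would produce. It is a sketch for checking shapes and ranks, not a reference construction.

```python
import numpy as np

def stacked_delta_omega(k, n, rng):
    """Hypothetical helper: δ_Ω for k stacked residual blocks.

    Vertices 3j..3j+3 form block j; vertex 0 and vertex 3k are pinned.
    Each edge contributes +I on its head column and -R on its tail column,
    where R is the tail restriction map (W1, Σ, W2, or I for the chord).
    """
    pinned = {0, 3 * k}
    free = [v for v in range(3 * k + 1) if v not in pinned]
    col = {v: i for i, v in enumerate(free)}   # vertex -> block-column index

    edges = []  # (tail, head, tail restriction map)
    for j in range(k):
        v = 3 * j
        W1 = rng.standard_normal((n, n))
        W2 = rng.standard_normal((n, n))
        sig = np.diag(rng.integers(0, 2, n).astype(float))  # stand-in ReLU mask
        edges += [(v, v + 1, W1), (v + 1, v + 2, sig),
                  (v + 2, v + 3, W2), (v, v + 3, np.eye(n))]

    delta = np.zeros((n * len(edges), n * len(free)))
    for r, (tail, head, R) in enumerate(edges):
        rows = slice(r * n, (r + 1) * n)
        if head in col:
            delta[rows, col[head] * n:(col[head] + 1) * n] += np.eye(n)
        if tail in col:
            delta[rows, col[tail] * n:(col[tail] + 1) * n] -= R
    return delta


rng = np.random.default_rng(3)
for k in (1, 2):
    d = stacked_delta_omega(k, 3, rng)
    print(f"k={k}: shape {d.shape}, rank {np.linalg.matrix_rank(d)}")
```

With k = 1 the helper reproduces the \(4n \times 2n\) layout used in the lab. That serves as a quick consistency check before reading off the k = 2 counts.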
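For the open-problem exercise, one possible simulation harness is sketched below. Both helpers (`resnet_sheaf_system`, `simulate_heat`) are hypothetical, and the sketch rests on three assumptions:

- The output vertex is free, so the cochain is \(c = (z^{(1)}, z^{(2)}, z^{(3)}) \in \mathbb{R}^{3n}\).
- The edges follow the conventions of `build_resnet_delta_omega`, so e3 carries \(W_2\) without the outer ReLU.
- The pinned input and the biases enter as a boundary term. The flow integrated is therefore \(\dot c = -\delta_\Omega(\sigma(c))^\top(\delta_\Omega(\sigma(c))\, c - \tilde b)\), which is \(-L_{\text{free}}(\sigma(c))\, c\) plus the forcing \(\delta_\Omega^\top \tilde b\).

\(\Sigma^{(1)}\) is re-read from the current state at every explicit Euler step. The Dirichlet energy \(\tfrac12\|\delta_\Omega c - \tilde b\|^2\) is recorded so that monotonicity can be checked.

```python
import numpy as np

def resnet_sheaf_system(c, x, W1, b1, W2, b2):
    """Hypothetical helper: δ_Ω(σ(c)) and boundary vector b for the free
    cochain c = (z1, z2, z3), with only z0 = x pinned."""
    n = x.size
    I = np.eye(n)
    sig1 = np.diag((c[:n] > 0).astype(float))   # Σ1 re-read from the current state
    d = np.zeros((4 * n, 3 * n))
    d[0:n, 0:n] = I                    # e1: z1 - W1 x - b1
    d[n:2*n, 0:n] = -sig1              # e2: z2 - Σ1 z1
    d[n:2*n, n:2*n] = I
    d[2*n:3*n, n:2*n] = -W2            # e3: z3 - W2 z2 - b2 (no outer ReLU: assumption)
    d[2*n:3*n, 2*n:3*n] = I
    d[3*n:4*n, 2*n:3*n] = I            # e4 (chord): z3 - x
    b = np.concatenate([W1 @ x + b1, np.zeros(n), b2, x])
    return d, b

def simulate_heat(x, W1, b1, W2, b2, c0, dt=0.02, steps=500):
    """Explicit Euler on dc/dt = -δ^T(δ c - b), re-evaluating σ(c) every step.
    Returns the final cochain and the energy ½‖δ c - b‖² before each step."""
    c = c0.copy()
    energy = np.empty(steps)
    for t in range(steps):
        d, b = resnet_sheaf_system(c, x, W1, b1, W2, b2)
        r = d @ c - b
        energy[t] = 0.5 * r @ r
        c = c - dt * (d.T @ r)
    return c, energy


rng = np.random.default_rng(11)
n = 3
W1, W2 = rng.standard_normal((2, n, n)) / np.sqrt(n)   # scaled so dt = 0.02 is stable
b1, b2 = 0.1 * rng.standard_normal((2, n))
x = rng.standard_normal(n)
c_final, E = simulate_heat(x, W1, b1, W2, b2, rng.standard_normal(3 * n))
increases = np.flatnonzero(np.diff(E) > 1e-12)
print(f"energy: {E[0]:.3e} -> {E[-1]:.3e}; steps where it increased: {increases.size}")
```

Any positive count of increases is a candidate counterexample worth recording. Check it against the step size, because explicit Euler can itself make the energy rise when `dt` is too large relative to the largest eigenvalue of \(L_{\text{free}}\).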