23  Lab 07 — Build the Neural Sheaf

Anchor chapter: Chapter 7 — Building a Sheaf from a ReLU Network.

Goal. Given a torch.nn.Sequential with ReLU activations, programmatically construct the corresponding cellular sheaf on the path graph.

Construct the neural sheaf for an arbitrary \([n_0, n_1, \ldots, n_{k+1}]\) ReLU MLP as a data structure: vertex list, edge list, edge-type labels, restriction-map callables. Populate it from a PyTorch model (here a freshly initialised one; a trained model is handled the same way), evaluate the coboundary \(\delta_\sigma\) on a concrete input, and verify the dimensional coincidence \(\dim C^1 = \dim C^0 - n_0\) (Rem. 7.3). Includes a side-by-side figure: the sheaf's path-graph base annotated with stalk dimensions, next to the per-edge discord at the harmonic extension.
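The NeuralSheaf class in this lab works at the matrix level. The data-structure view named in the goal (vertices, edges, restriction-map callables) can be sketched on its own. In this minimal sketch, build_sheaf_spec and coboundary are hypothetical names, restriction maps are keyed by (vertex, edge) with edge \(e_\ell = (\ell-1, \ell)\), random weights stand in for a PyTorch model, and edge-type labels are left out. Biases stay out of the restriction maps, as in the rest of the lab, so evaluating \(\delta_\sigma\) on the forward-pass cochain \((x, z^{(1)}, \ldots, z^{(k+1)})\) should return the bias vectors edge by edge.

```python
import numpy as np

def build_sheaf_spec(dims, weights, sigmas):
    """Hypothetical helper: vertex list, edge list and restriction-map callables.

    dims    : [n_0, ..., n_{k+1}]
    weights : W^(1), ..., W^(k+1), with W^(l) of shape (n_l, n_{l-1})
    sigmas  : Σ^(0) = I, Σ^(1), ..., Σ^(k) as 0/1 diagonal matrices
    """
    vertices = list(range(len(dims)))
    edges = [(l - 1, l) for l in range(1, len(dims))]      # e_l = (l-1, l)
    restrictions = {}
    for (t, h) in edges:
        W, S = weights[h - 1], sigmas[t]
        # tail map F_{(l-1) ⊴ e_l}(c) = W^(l) Σ^(l-1) c ; default args freeze W, S per edge
        restrictions[(t, (t, h))] = lambda c, W=W, S=S: W @ (S @ c)
        # head map F_{l ⊴ e_l} = identity on R^{n_l}
        restrictions[(h, (t, h))] = lambda c: c
    return vertices, edges, restrictions

def coboundary(edges, restrictions, cochain):
    """(δc)_e = F_{h ⊴ e}(c_h) - F_{t ⊴ e}(c_t) for every edge e = (t, h)."""
    out = []
    for (t, h) in edges:
        e = (t, h)
        out.append(restrictions[(h, e)](cochain[h]) - restrictions[(t, e)](cochain[t]))
    return out

rng = np.random.default_rng(0)
dims = [2, 4, 3, 1]
weights = [rng.standard_normal((dims[l + 1], dims[l])) for l in range(len(dims) - 1)]
biases = [rng.standard_normal(dims[l + 1]) for l in range(len(dims) - 1)]

# Forward pass: z^(0) = x, z^(l) = W^(l) ReLU(z^(l-1)) + b^(l), no ReLU on the input
zs = [np.array([0.5, -1.2])]
for l, (W, b) in enumerate(zip(weights, biases)):
    zs.append(W @ (zs[-1] if l == 0 else np.maximum(0.0, zs[-1])) + b)
sigmas = [np.eye(dims[0])] + [np.diag((z > 0).astype(float)) for z in zs[1:-1]]

V, E, F = build_sheaf_spec(dims, weights, sigmas)
dC0 = sum(dims[v] for v in V)          # vertex stalks R^{n_v}
dC1 = sum(dims[h] for (_, h) in E)     # edge e_l carries R^{n_l}
print(dC0, dC1, dC1 == dC0 - dims[0])  # → 10 8 True   (Rem. 7.3)

delta_c = coboundary(E, F, zs)
print(all(np.allclose(d, b) for d, b in zip(delta_c, biases)))  # → True
```

The dictionary keyed by (vertex, edge) mirrors the incidence notation \(\mathcal{F}_{v \trianglelefteq e}\); the matrix form used by NeuralSheaf is the same data laid out block by block.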

This lab imports torch; PyTorch is not yet available in Pyodide. Run locally with the Plan B notebook below.

Prefer a local Jupyter environment? Download lab-07-build-neural-sheaf.ipynb

Install dependencies: pip install torch numpy matplotlib networkx

23.1 Setup

import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import torch
import torch.nn as nn

rng = np.random.default_rng(0)
torch.manual_seed(0)
plt.rcParams['figure.figsize'] = (6, 4)

23.2 1. Build the object

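We extract weight matrices and biases from a torch.nn.Sequential and build a NeuralSheaf object that stores vertex stalks \(\mathcal{F}(\ell) = \mathbb{R}^{n_\ell}\) and the activation-pattern-dependent coboundary \(\delta_\sigma\). For a \([n_0, n_1, \ldots, n_{k+1}]\) network, edge \(e_\ell = (\ell{-}1, \ell)\) has stalk \(\mathbb{R}^{n_\ell}\) with restriction maps \(\mathcal{F}_{\ell \trianglelefteq e_\ell} = I_{n_\ell}\) (identity) and \(\mathcal{F}_{(\ell-1) \trianglelefteq e_\ell}(c) = W^{(\ell)} \Sigma^{(\ell-1)} c\) (ReLU-gated weight, with \(\Sigma^{(0)} = I\) since the input has no ReLU), so \((\delta_\sigma c)_{e_\ell} = c_\ell - W^{(\ell)} \Sigma^{(\ell-1)} c_{\ell-1}\). The biases are not part of the restriction maps; they enter as a right-hand side. Pinning the input vertex to \(c_0 = x\) and keeping only the free vertices \(1, \ldots, k+1\) leaves a square matrix \(\delta_\Omega\) that is block lower-triangular with identity blocks on the diagonal, so \(\det \delta_\Omega = 1\), in agreement with Lemma 3.2.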
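Written out for a network with two hidden layers, \([n_0, n_1, n_2, n_3]\) (so \(k = 2\)), the system that coboundary_omega and boundary_rhs assemble reads as follows. Rows are the edges \(e_1, e_2, e_3\), columns are the free vertices, and the pinned input \(x\) appears only in \(\tilde{b}\):

\[
\delta_\Omega =
\begin{pmatrix}
I_{n_1} & 0 & 0 \\
-W^{(2)}\Sigma^{(1)} & I_{n_2} & 0 \\
0 & -W^{(3)}\Sigma^{(2)} & I_{n_3}
\end{pmatrix},
\qquad
c_\Omega = \begin{pmatrix} z^{(1)} \\ z^{(2)} \\ z^{(3)} \end{pmatrix},
\qquad
\tilde{b} = \begin{pmatrix} W^{(1)}x + b^{(1)} \\ b^{(2)} \\ b^{(3)} \end{pmatrix},
\]
\[
\det \delta_\Omega = \det I_{n_1} \cdot \det I_{n_2} \cdot \det I_{n_3} = 1 .
\]

The determinant of a block lower-triangular matrix is the product of the determinants of its diagonal blocks, which is why the value 1 does not depend on the activation pattern.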

class NeuralSheaf:
    """
    Cellular sheaf on a path graph induced by a ReLU MLP.

    Parameters
    ----------
    dims : list[int]   layer widths [n0, n1, ..., n_{k+1}]
    weights : list of (n_{l}, n_{l-1}) arrays   weight matrices W^(1)...W^(k+1)
    biases  : list of (n_{l},) arrays           bias vectors b^(1)...b^(k+1)
    """
    def __init__(self, dims, weights, biases):
        self.dims    = dims
        self.weights = weights
        self.biases  = biases
        self.n_layers = len(dims) - 1   # number of weight layers

    def forward(self, x):
        """Run the network; return all pre-activations z^(0),...,z^(k+1)."""
        zs = [x]
        for ell in range(self.n_layers):
            W, b = self.weights[ell], self.biases[ell]
            a_prev = np.maximum(0.0, zs[-1]) if ell > 0 else zs[-1]
            zs.append(W @ a_prev + b)
        return zs   # z^(0) = x (no ReLU), z^(1)...z^(k+1) pre-activations

    def activation_pattern(self, zs):
        """Return list of diagonal Σ^(ℓ) matrices from pre-activations."""
        # Σ^(0) = I (input has no ReLU), Σ^(ℓ) for ℓ=1,...,k
        sigmas = [np.eye(self.dims[0])]
        for ell in range(1, self.n_layers):
            sigmas.append(np.diag((zs[ell] > 0).astype(float)))
        return sigmas

    def coboundary_omega(self, x):
        """
        Build δ_Ω for input x. Rows: edge stalks of e_1,...,e_{k+1};
        columns: free vertices z^(1),...,z^(k+1) (z^(0) = x is pinned).
        Returns a square (Σ_{ℓ=1}^{k+1} n_ℓ) × (Σ_{ℓ=1}^{k+1} n_ℓ) matrix.
        """
        zs     = self.forward(x)
        sigmas = self.activation_pattern(zs)
        dims_free = self.dims[1:]          # n1, ..., n_{k+1}
        N = sum(dims_free)
        delta = np.zeros((N, N))
        # offsets into the free-cochain vector
        offsets = np.concatenate([[0], np.cumsum(dims_free)])

        for ell in range(self.n_layers):   # ell = 0..k: edge e_{ell+1}
            n_head = self.dims[ell + 1]    # dimension of head vertex
            r0 = offsets[ell]              # row/col start for this edge's stalk
            c_head = offsets[ell]          # col start for head vertex z^(ell+1)
            # Head vertex ℓ+1: +I_{n_head}
            delta[r0:r0+n_head, c_head:c_head+n_head] += np.eye(n_head)
            # Tail vertex ℓ (free if ell >= 1):
            if ell >= 1:
                W   = self.weights[ell]    # shape (n_head, n_tail)
                Sig = sigmas[ell]          # shape (n_tail, n_tail)
                c_tail = offsets[ell - 1]
                n_tail = self.dims[ell]
                delta[r0:r0+n_head, c_tail:c_tail+n_tail] -= W @ Sig

        return delta

    def boundary_rhs(self, x):
        """
        Right-hand side vector b̃ for the harmonic extension equation
        δ_Ω c_Ω = b̃, where b̃ accounts for the pinned input vertex z^(0)=x.
        """
        dims_free = self.dims[1:]
        N = sum(dims_free)
        rhs = np.zeros(N)
        # Only edge e_1 sees z^(0)=x: contribution = W^(1) z^(0) + b^(1)
        n1 = self.dims[1]
        rhs[:n1] = self.weights[0] @ x + self.biases[0]
        # Remaining edges have bias-only boundary term (z^(0) does not appear)
        offset = n1
        for ell in range(1, self.n_layers):
            n_ell = self.dims[ell + 1]
            rhs[offset:offset + n_ell] = self.biases[ell]
            offset += n_ell
        return rhs


# ── Instantiate a [2, 4, 1] network from a PyTorch Sequential ────────────
net = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 1))

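linears = [m for m in net if isinstance(m, nn.Linear)]   # skip the ReLU modules
dims    = [linears[0].in_features] + [m.out_features for m in linears]   # [2, 4, 1]
weights = [m.weight.detach().numpy() for m in linears]
biases  = [m.bias.detach().numpy() for m in linears]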

sheaf = NeuralSheaf(dims, weights, biases)

x_test = np.array([0.5, -1.2])
zs     = sheaf.forward(x_test)
delta  = sheaf.coboundary_omega(x_test)
rhs    = sheaf.boundary_rhs(x_test)

print(f"Architecture: {dims}")
print(f"dim C^0 = {sum(dims)},  dim C^1 = {sum(dims[1:])} = dim C^0 - n_0 (Rem. 7.3)")
print(f"δ_Ω shape: {delta.shape}  (should be {sum(dims[1:])} × {sum(dims[1:])})")
print(f"det(δ_Ω) = {np.linalg.det(delta):.6f}  (Lemma 3.2: should be 1)")
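The prints above check a single activation pattern. Since the argument for Lemma 3.2 uses only the block structure, \(\delta_\Omega\) should be square and have determinant 1 for any 0/1 gates. A minimal self-contained check, assuming random Gaussian weights and random gates as stand-ins; delta_omega is a hypothetical standalone version of NeuralSheaf.coboundary_omega that takes the \(\Sigma^{(\ell)}\) directly instead of computing them from x.

```python
import numpy as np

def delta_omega(dims, weights, sigmas):
    """Assemble δ_Ω: row blocks = edges e_1..e_{k+1}, column blocks = free vertices 1..k+1."""
    free = dims[1:]
    off = np.concatenate([[0], np.cumsum(free)])
    D = np.zeros((off[-1], off[-1]))
    for l in range(len(free)):                    # row block of edge e_{l+1} = (l, l+1)
        r = slice(off[l], off[l + 1])
        D[r, r] = np.eye(free[l])                 # head vertex l+1: +I
        if l >= 1:                                # tail vertex l is free (vertex 0 is pinned)
            D[r, off[l - 1]:off[l]] = -weights[l] @ sigmas[l]
    return D

rng = np.random.default_rng(1)
results = []
for dims in ([2, 4, 1], [2, 8, 8, 1], [3, 16, 16, 2]):
    weights = [rng.standard_normal((dims[l + 1], dims[l])) for l in range(len(dims) - 1)]
    # Random 0/1 gates stand in for activation patterns; Σ^(0) = I
    sigmas = [np.eye(dims[0])] + [np.diag(rng.integers(0, 2, n).astype(float)) for n in dims[1:-1]]
    D = delta_omega(dims, weights, sigmas)
    is_lower = np.array_equal(D, np.tril(D))      # entrywise lower-triangular, unit diagonal
    det = np.linalg.det(D)
    results.append((dims, D.shape, is_lower, det))
    print(f"{dims}: shape {D.shape}, lower-triangular {is_lower}, det {det:.6f}")
```

np.linalg.det goes through a general LU factorisation, so it returns 1 only up to rounding; the exact statement is that the diagonal of D consists of ones.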

23.3 2. Verify a theorem / run an experiment

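We verify Remark 7.3 (\(\dim C^1 = \dim C^0 - n_0\)), Lemma 3.2 (\(\det \delta_\Omega = 1\)), and Proposition 3.4 (forward pass = harmonic extension) on a concrete input. The harmonic extension solves \(\delta_\Omega c_\Omega = \tilde{b}\); because \(\delta_\Omega\) is block lower-triangular this amounts to forward substitution, although the code simply calls np.linalg.solve. The output block of \(c_\Omega\) should match net(x) to single precision: PyTorch evaluates the network in float32, while the sheaf computation runs in float64. The path-graph visualisation annotates each vertex with its layer index and stalk dimension, and each edge with its name and stalk.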
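The block lower-triangular structure means the dense solve can be replaced by one affine update per edge. The following is a minimal sketch of that forward substitution, assuming the gates \(\Sigma^{(\ell)}\) are read from a separate forward pass and then held fixed; relu_forward and block_forward_substitution are hypothetical names, and random weights stand in for the PyTorch model.

```python
import numpy as np

def relu_forward(x, weights, biases):
    """Plain ReLU forward pass; returns the pre-activations z^(1), ..., z^(k+1)."""
    zs, a = [], x
    for W, b in zip(weights, biases):
        z = W @ a + b
        zs.append(z)
        a = np.maximum(0.0, z)      # ReLU input for the next layer (unused after the output layer)
    return zs

def block_forward_substitution(x, weights, biases, sigmas):
    """Solve δ_Ω c = b̃ one row block at a time, with sigmas[l] = Σ^(l) fixed (sigmas[0] = I)."""
    cs = [weights[0] @ x + biases[0]]           # row e_1: the pinned input sits in b̃_1
    for l in range(1, len(weights)):
        # row e_{l+1}: c_{l+1} - W^(l+1) Σ^(l) c_l = b^(l+1); identity diagonal block, so no division
        cs.append(weights[l] @ (sigmas[l] @ cs[-1]) + biases[l])
    return cs

rng = np.random.default_rng(2)
dims = [2, 8, 8, 1]
weights = [rng.standard_normal((dims[l + 1], dims[l])) for l in range(len(dims) - 1)]
biases = [rng.standard_normal(dims[l + 1]) for l in range(len(dims) - 1)]

max_err = 0.0
for _ in range(50):
    x = rng.standard_normal(dims[0])
    zs = relu_forward(x, weights, biases)
    # Σ^(0) = I, Σ^(l) = diag(z^(l) > 0) for the hidden layers l = 1..k
    sigmas = [np.eye(dims[0])] + [np.diag((z > 0).astype(float)) for z in zs[:-1]]
    cs = block_forward_substitution(x, weights, biases, sigmas)
    max_err = max(max_err, max(np.abs(c - z).max() for c, z in zip(cs, zs)))
print(f"max |c_l - z^(l)| over 50 inputs: {max_err:.1e}")
```

Each row block costs one matrix-vector product, so the solve scales like the forward pass itself rather than like a dense \(O(N^3)\) factorisation; the agreement with relu_forward is Proposition 3.4 in executable form.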

# ── Verify Prop. 3.4: harmonic extension = forward pass ─────────────────
c_harm = np.linalg.solve(delta, rhs)
z2_harm = c_harm[-dims[-1]:]                   # output stalk: last n_{k+1} entries of c_Ω

z2_torch = net(torch.tensor(x_test, dtype=torch.float32)).detach().numpy()
print(f"\nHarmonic extension output z₂: {z2_harm}")
print(f"PyTorch forward pass f(x):    {z2_torch}")
print(f"Max absolute error:           {abs(z2_harm - z2_torch).max():.2e}")

# ── Check across 100 random inputs ──────────────────────────────────────
errors = []
for _ in range(100):
    xr = rng.standard_normal(2)
    dr = sheaf.coboundary_omega(xr)
    br = sheaf.boundary_rhs(xr)
    cr = np.linalg.solve(dr, br)
    z_harm = cr[-dims[-1]:]                    # output stalk
    z_torch = net(torch.tensor(xr, dtype=torch.float32)).detach().numpy()
    errors.append(abs(z_harm - z_torch).max())

print(f"\nOver 100 random inputs: max error = {max(errors):.2e}")
# PyTorch evaluates in float32, so agreement is expected only to about 1e-7
print(f"Prop. 3.4 confirmed numerically: {max(errors) < 1e-5}")

# ── Visualise the path graph with stalk annotations ─────────────────────
G = nx.path_graph(len(dims))
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

pos = {v: (v, 0) for v in G.nodes()}
ax = axes[0]
nx.draw(G, pos=pos, ax=ax, with_labels=False, node_color='steelblue',
        node_size=800, edge_color='gray', width=2)
for v, (px, py) in pos.items():
    ax.text(px, py + 0.15, f'ℓ={v}', ha='center', va='bottom', fontsize=10, fontweight='bold')
    ax.text(px, py - 0.18, f'ℝ^{dims[v]}', ha='center', va='top', fontsize=9, color='white',
            bbox=dict(boxstyle='round', facecolor='steelblue', alpha=0.7))
for u, v in G.edges():
    mx, my = (u + v) / 2, 0
    ax.text(mx, my + 0.08, f'e_{u+1}\nℝ^{dims[v]}', ha='center', fontsize=8, color='darkred')
ax.set_title(f'Neural sheaf path graph  [{",".join(map(str,dims))}]')
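ax.set_ylim(-0.5, 0.5)   # all nodes sit at y = 0; widen the range so the text labels stay in the panel
ax.axis('off')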

# Discord per edge at the harmonic extension
ax2 = axes[1]
edge_discords = []
n_prev = 0
for ell in range(sheaf.n_layers):
    n_ell = dims[ell + 1]
    row_block = delta[n_prev:n_prev + n_ell, :]
    d_e = abs(row_block @ c_harm - rhs[n_prev:n_prev + n_ell])
    edge_discords.append(np.linalg.norm(d_e))
    n_prev += n_ell
ax2.bar([f'e_{i+1}' for i in range(sheaf.n_layers)], edge_discords, color='steelblue')
ax2.set_ylabel('‖discord‖'); ax2.set_title('Per-edge discord at harmonic extension\n(should be ≈ 0)')
plt.tight_layout(); plt.show()

23.4 Exercises

  1. Deeper network. Build a NeuralSheaf for a [2, 8, 8, 1] network. Verify that \(\det \delta_\Omega = 1\) and that Proposition 3.4 holds for 50 random inputs. How does the condition number of \(\delta_\Omega\) change with depth?

  2. Singular activation patterns. When a neuron is exactly at the switching boundary (\(z_j = 0\)), the activation pattern is ambiguous. Create a synthetic input that puts one neuron on the boundary and compare the two choices of \(\Sigma\) (0 and 1 for that neuron). Do both give valid harmonic extensions?

  3. Dimensional coincidence. Remark 7.3 states \(\dim C^1 = \dim C^0 - n_0\), which is what makes \(\delta_\Omega\) square. For a ResNet-style skip connection from layer 0 to layer 2, this dimensional coincidence fails. Construct such a network, show that \(\delta_\Omega\) is no longer square, and discuss what happens to the harmonic-extension interpretation.

  4. Sheaf from biases. In the harmonic extension formulation the bias \(b^{(\ell)}\) appears as a boundary-forcing term, not in \(\delta_\Omega\) itself. Reformulate the sheaf by introducing a “bias vertex” with stalk \(\mathbb{R}^1\) pinned to 1 and edges carrying \(b^{(\ell)}\) as restriction maps. Does this change \(\det \delta_\Omega\)?
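A possible starting point for Exercise 2 is to project an arbitrary point onto the switching hyperplane \(\{x : w_j \cdot x + b_j = 0\}\) of one first-layer neuron. This sketch assumes the hypothetical helper project_onto_switching_hyperplane and uses random W1, b1 as stand-ins for sheaf.weights[0] and sheaf.biases[0]; it only builds the input and the two candidate gates, leaving the comparison to the reader.

```python
import numpy as np

def project_onto_switching_hyperplane(x0, w, b):
    """Hypothetical helper: the point nearest x0 on the hyperplane {x : w·x + b = 0}."""
    return x0 - (w @ x0 + b) / (w @ w) * w

rng = np.random.default_rng(3)
W1 = rng.standard_normal((4, 2))     # stand-in for sheaf.weights[0]
b1 = rng.standard_normal(4)          # stand-in for sheaf.biases[0]
j = 0                                # neuron to place on the boundary

x_star = project_onto_switching_hyperplane(rng.standard_normal(2), W1[j], b1[j])
z1 = W1 @ x_star + b1
print(f"z_j = {z1[j]:.1e}")          # tiny but rarely exactly 0: rounding decides its sign

# Two candidate gates for layer 1, identical except for entry (j, j)
gate = (z1 > 0).astype(float)
sigma_off, sigma_on = np.diag(gate), np.diag(gate)
sigma_off[j, j], sigma_on[j, j] = 0.0, 1.0
```

Because the sign of the rounded \(z_j\) is arbitrary, the gate for neuron \(j\) is set by hand rather than read from z1 > 0.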