asi · topological-dataloader
Batched heterogeneous topological complex dataloaders for TopoModelX.
Install
source · Clone the upstream repo
git clone https://github.com/plurigrid/asi
Claude Code · Install into ~/.claude/skills/
T=$(mktemp -d) && git clone --depth=1 https://github.com/plurigrid/asi "$T" && mkdir -p ~/.claude/skills && cp -r "$T/skills/topological-dataloader" ~/.claude/skills/plurigrid-asi-topological-dataloader && rm -rf "$T"
manifest:
skills/topological-dataloader/SKILL.md
Topological Dataloader Skill
Addresses: TopoModelX #325 (batch heterogeneous complexes) and #243 (dataloader blocking production)
Problem Statement
TopoModelX currently lacks efficient dataloaders for:
- Heterogeneous complexes: varying sizes within a batch (see the failure sketch after this list)
- Mixed topologies: simplicial, cell, and hypergraph complexes in the same dataset
- Streaming large datasets: memory-efficient loading
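Why naive batching fails, as a minimal sketch: complexes of different sizes yield incidence matrices of different shapes, so they cannot be stacked along a new batch dimension and must instead be combined block-diagonally. The toy tensors below are illustrative, not from the upstream code.

```python
import torch

# Vertex-edge incidence matrices for two complexes of different sizes
B1_a = torch.randn(3, 3)   # complex A: 3 vertices, 3 edges
B1_b = torch.randn(4, 5)   # complex B: 4 vertices, 5 edges

try:
    torch.stack([B1_a, B1_b])          # dense batching fails on shape mismatch
except RuntimeError as e:
    print(e)

# Block-diagonal batching keeps both operators in a single matrix
B1_batched = torch.block_diag(B1_a, B1_b)
print(B1_batched.shape)                # torch.Size([7, 8])
```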
Architecture
```
┌─────────────────────────────────────────────────────────────┐
│                    TopologicalDataLoader                     │
├─────────────────────────────────────────────────────────────┤
│   ┌─────────────┐    ┌─────────────┐    ┌─────────────┐     │
│   │ Simplicial  │    │    Cell     │    │ Hypergraph  │     │
│   │  Collator   │    │  Collator   │    │  Collator   │     │
│   └──────┬──────┘    └──────┬──────┘    └──────┬──────┘     │
│          │                  │                  │            │
│          └──────────────────┼──────────────────┘            │
│                             ▼                               │
│               ┌───────────────────────┐                     │
│               │ UnifiedBatchCollator  │                     │
│               │  (sparse block diag)  │                     │
│               └───────────────────────┘                     │
└─────────────────────────────────────────────────────────────┘
```
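A minimal sketch of the dispatch layer implied by the diagram, assuming one collator per complex type. `collate_cell` and `collate_hypergraph` are hypothetical counterparts to the `collate_simplicial` defined in the next section.

```python
from collections import defaultdict
from typing import Callable, Dict, List, Type


def make_unified_collate(
    collators: Dict[Type, Callable],
) -> Callable[[List], Dict[Type, "TopologicalBatch"]]:
    """Route each complex to its type-specific collator, then batch per type."""
    def unified_collate(complexes: List) -> Dict[Type, "TopologicalBatch"]:
        by_type = defaultdict(list)
        for c in complexes:
            by_type[type(c)].append(c)
        return {t: collators[t](group) for t, group in by_type.items()}
    return unified_collate


# Hypothetical wiring; collate_cell / collate_hypergraph would mirror collate_simplicial:
# unified = make_unified_collate({
#     SimplicialComplex: collate_simplicial,
#     CellComplex: collate_cell,
#     CombinatorialComplex: collate_hypergraph,
# })
```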
Implementation
Core Collator
```python
import numpy as np
import scipy.sparse as sp
import torch
from toponetx import CellComplex, CombinatorialComplex, SimplicialComplex
from typing import Dict, List, Union

TopoComplex = Union[SimplicialComplex, CellComplex, CombinatorialComplex]


class TopologicalBatch:
    """Batched topological complexes with sparse adjacency."""

    def __init__(
        self,
        complexes: List[TopoComplex],
        features: Dict[int, torch.Tensor],        # rank -> features
        adjacencies: Dict[str, torch.Tensor],     # "B_1", "L_0", etc. (sparse COO)
        batch_indices: Dict[int, torch.Tensor],   # rank -> batch assignment
    ):
        self.complexes = complexes
        self.features = features
        self.adjacencies = adjacencies
        self.batch_indices = batch_indices
        self.batch_size = len(complexes)

    def to(self, device: torch.device) -> 'TopologicalBatch':
        return TopologicalBatch(
            complexes=self.complexes,
            features={k: v.to(device) for k, v in self.features.items()},
            adjacencies={k: v.to(device) for k, v in self.adjacencies.items()},
            batch_indices={k: v.to(device) for k, v in self.batch_indices.items()},
        )


def collate_simplicial(
    complexes: List[SimplicialComplex],
    max_rank: int = 2,
) -> TopologicalBatch:
    """Collate simplicial complexes into a sparse batched format."""
    features = {r: [] for r in range(max_rank + 1)}
    batch_idx = {r: [] for r in range(max_rank + 1)}

    # Per-complex boundary matrices and Hodge Laplacians
    boundaries = {f"B_{r}": [] for r in range(1, max_rank + 1)}
    laplacians = {f"L_{r}": [] for r in range(max_rank + 1)}

    for batch_i, sc in enumerate(complexes):
        for rank in range(max_rank + 1):
            cells = list(sc.skeleton(rank))
            n_cells = len(cells)

            # Features (placeholder - substitute real cell features if available)
            feat = torch.zeros(n_cells, 1)
            features[rank].append(feat)
            batch_idx[rank].extend([batch_i] * n_cells)

            # Boundary operators; block-diagonal stacking below handles the offsets
            if rank > 0:
                boundaries[f"B_{rank}"].append(
                    sc.incidence_matrix(rank=rank, signed=True)
                )

            # Hodge Laplacian
            laplacians[f"L_{rank}"].append(sc.hodge_laplacian_matrix(rank=rank))

    # Stack features and convert batch assignments to tensors
    stacked_features = {r: torch.cat(features[r], dim=0) for r in features}
    stacked_batch = {r: torch.tensor(batch_idx[r], dtype=torch.long) for r in batch_idx}

    # Block-diagonal sparse matrices
    stacked_adj = {}
    for key, mats in {**boundaries, **laplacians}.items():
        if mats:
            stacked_adj[key] = _block_diag_sparse(mats)

    return TopologicalBatch(
        complexes=complexes,
        features=stacked_features,
        adjacencies=stacked_adj,
        batch_indices=stacked_batch,
    )


def _block_diag_sparse(mats: List[sp.spmatrix]) -> torch.Tensor:
    """Create a block-diagonal sparse COO tensor from scipy sparse matrices."""
    block = sp.block_diag(mats).tocoo()
    indices = torch.from_numpy(np.vstack([block.row, block.col])).long()
    values = torch.from_numpy(block.data).float()
    return torch.sparse_coo_tensor(indices, values, block.shape)
```
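A quick usage sketch of `collate_simplicial` on two toy complexes of different sizes; the shapes in the comments follow from the cell counts and are illustrative, not upstream test output.

```python
from toponetx import SimplicialComplex

# Two complexes of different sizes (the heterogeneous case)
sc1 = SimplicialComplex([[0, 1, 2]])                # 3 vertices, 3 edges, 1 triangle
sc2 = SimplicialComplex([[0, 1, 2], [1, 2, 3]])     # 4 vertices, 5 edges, 2 triangles

batch = collate_simplicial([sc1, sc2], max_rank=2)

print(batch.batch_size)                # 2
print(batch.features[0].shape)         # torch.Size([7, 1]): 3 + 4 vertices
print(batch.adjacencies["B_1"].shape)  # torch.Size([7, 8]): block-diagonal vertex-edge incidence
print(batch.batch_indices[2])          # tensor([0, 1, 1]): triangle -> complex assignment
```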
PyTorch DataLoader Integration
```python
import torch
from torch.utils.data import DataLoader, Dataset


class SimplicialDataset(Dataset):
    """Dataset of simplicial complexes."""

    def __init__(self, complexes: List[SimplicialComplex], labels: List[int] = None):
        self.complexes = complexes
        self.labels = labels or [0] * len(complexes)

    def __len__(self):
        return len(self.complexes)

    def __getitem__(self, idx):
        return self.complexes[idx], self.labels[idx]


def simplicial_collate_fn(batch):
    """Collate function for DataLoader."""
    complexes, labels = zip(*batch)
    topo_batch = collate_simplicial(list(complexes))
    return topo_batch, torch.tensor(labels)


# Usage (complexes, labels, model, device are assumed to be defined by the caller)
dataset = SimplicialDataset(complexes, labels)
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=simplicial_collate_fn,
    num_workers=4,
)

for batch, labels in loader:
    batch = batch.to(device)
    # Forward pass through a TopoModelX model
    output = model(batch.features, batch.adjacencies)
```
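The problem statement also lists streaming of large datasets. A minimal sketch of one way to do it with an `IterableDataset`, assuming complexes are serialized one per file; the directory layout and the `(complex, label)` pickle payload are assumptions, not part of the upstream skill.

```python
import pickle
from pathlib import Path

from torch.utils.data import DataLoader, IterableDataset


class StreamingSimplicialDataset(IterableDataset):
    """Stream complexes from disk one at a time instead of holding them all in memory."""

    def __init__(self, root: str):
        # Hypothetical layout: one pickled (complex, label) pair per .pkl file
        self.paths = sorted(Path(root).glob("*.pkl"))

    def __iter__(self):
        for path in self.paths:
            with open(path, "rb") as f:
                sc, label = pickle.load(f)
            yield sc, label


stream_loader = DataLoader(
    StreamingSimplicialDataset("data/complexes/"),
    batch_size=32,
    collate_fn=simplicial_collate_fn,   # reuses the collate function defined above
)
```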
GF(3) Integration
```python
import hashlib
from typing import Dict, Tuple


def assign_trit_to_cells(complex: SimplicialComplex, seed: int) -> Dict[Tuple[int, int], int]:
    """Assign balanced ternary trits to cells for GF(3) conservation."""
    trits = {}
    for rank in range(complex.dim + 1):
        for i, cell in enumerate(complex.skeleton(rank)):
            h = int(hashlib.sha256(f"{seed}:{rank}:{cell}".encode()).hexdigest()[:8], 16)
            trits[(rank, i)] = (h % 3) - 1  # {-1, 0, +1}

    # Verify conservation: raises if the hash-derived trits do not sum to 0 (mod 3)
    total = sum(trits.values())
    assert total % 3 == 0, f"GF(3) violation: sum={total}"

    return trits
```
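The hash-derived assignment is not guaranteed to pass the conservation check. A minimal sketch of one possible repair, adjusting the last trit by the residue; this `balance_trits` helper is an assumption for illustration, not part of the upstream skill.

```python
def balance_trits(trits: Dict[Tuple[int, int], int]) -> Dict[Tuple[int, int], int]:
    """Adjust the final trit so the total is 0 mod 3 (one possible repair strategy)."""
    residue = sum(trits.values()) % 3
    if residue != 0 and trits:
        last_key = next(reversed(trits))              # dicts preserve insertion order
        # Subtract the residue, then map back into the balanced range {-1, 0, +1}
        adjusted = (trits[last_key] - residue + 1) % 3 - 1
        trits = {**trits, last_key: adjusted}
    assert sum(trits.values()) % 3 == 0
    return trits
```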
Path Complexes Extension
From TopoModelX #230:
```python
import networkx as nx
import scipy.sparse as sp
from typing import Dict, List


class PathComplex:
    """Path complex for directed graph analysis.

    A path complex is a complex whose cells are directed paths. It generalizes
    graphs by allowing longer-range dependencies than 1-hop edges.
    """

    def __init__(self, graph: nx.DiGraph, max_path_length: int = 3):
        self.graph = graph
        self.max_length = max_path_length
        self._paths = self._enumerate_paths()

    def _enumerate_paths(self) -> Dict[int, List[tuple]]:
        """Enumerate all simple paths up to max_length (rank = number of edges)."""
        paths = {k: [] for k in range(self.max_length + 1)}
        # Rank-0 paths are the single vertices
        paths[0] = [(v,) for v in self.graph.nodes()]
        for source in self.graph.nodes():
            for target in self.graph.nodes():
                if source == target:
                    continue
                for path in nx.all_simple_paths(
                    self.graph, source, target, cutoff=self.max_length
                ):
                    paths[len(path) - 1].append(tuple(path))
        return paths

    def boundary_matrix(self, rank: int) -> sp.spmatrix:
        """Compute the boundary operator for the path complex.

        The boundary of a path [v0, v1, ..., vn] is the alternating sum of the
        sub-paths obtained by deleting one vertex, kept only if still a path.
        """
        if rank == 0:
            return sp.csr_matrix((0, len(self._paths[0])))

        higher = self._paths[rank]
        lower = self._paths[rank - 1]
        lower_idx = {p: i for i, p in enumerate(lower)}

        rows, cols, data = [], [], []
        for i, path in enumerate(higher):
            for j in range(len(path)):
                subpath = path[:j] + path[j + 1:]
                if subpath in lower_idx:
                    rows.append(lower_idx[subpath])
                    cols.append(i)
                    data.append((-1) ** j)

        return sp.coo_matrix((data, (rows, cols)), shape=(len(lower), len(higher)))
```
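A short usage sketch on a toy directed chain, assuming networkx is installed; the printed values follow from the path enumeration above (internals are inspected only for illustration).

```python
import networkx as nx

G = nx.DiGraph([(0, 1), (1, 2), (2, 3)])   # directed chain 0 -> 1 -> 2 -> 3
pc = PathComplex(G, max_path_length=2)

print(pc._paths[1])          # 1-paths (edges): [(0, 1), (1, 2), (2, 3)]
print(pc._paths[2])          # 2-paths: [(0, 1, 2), (1, 2, 3)]

B2 = pc.boundary_matrix(2)   # maps 2-paths to 1-paths
print(B2.shape)              # (3, 2)
print(B2.toarray())          # [[1 0], [1 1], [0 1]]
```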
Commands
```
just topo-dataloader-demo   # Run dataloader demonstration
just topo-batch-test        # Test batched complex loading
just topo-path-complex      # Path complex example
```
GF(3) Category: MINUS (Verification) | Addresses production-blocking issues