TorchShapeFlow
TorchShapeFlow is a static, AST-based shape analyzer for PyTorch. It reads your Python source — no execution required — infers tensor shapes through your code, and reports mismatches as structured diagnostics.
from typing import Annotated

import torch
import torch.nn as nn
from torchshapeflow import Shape


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.linear = nn.Linear(8 * 32 * 32, 10)

    def forward(self, x: Annotated[torch.Tensor, Shape("B", 3, 32, 32)]):
        y = self.conv(x)  # inferred: [B, 8, 32, 32]
        z = y.flatten(1)  # inferred: [B, 8192]
        return self.linear(z)  # inferred: [B, 10]
$ tsf check mymodel.py
mymodel.py: ok
$ tsf check broken.py
broken.py:9:9 error TSF1004 Invalid reshape.
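The inferred shapes in the comments of the Net example can be reproduced by hand. The following is a minimal sketch of that arithmetic in plain Python, not TorchShapeFlow's implementation: the helper names (`conv2d_out`, `conv2d_shape`, `flatten_from`, `linear_shape`) are hypothetical, and the convolution size formula is the one given in the PyTorch `Conv2d` documentation. The batch dimension stays symbolic as the string "B" throughout.

```python
from math import prod
from typing import List, Union

Dim = Union[int, str]  # a concrete size, or a symbolic name such as "B"


def conv2d_out(size: int, kernel: int, stride: int = 1,
               padding: int = 0, dilation: int = 1) -> int:
    """Output size along one spatial axis, per the PyTorch Conv2d formula."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def conv2d_shape(shape: List[Dim], out_channels: int,
                 kernel: int, padding: int = 0) -> List[Dim]:
    """[N, C_in, H, W] -> [N, C_out, H_out, W_out]; H and W must be ints."""
    batch, _in_channels, height, width = shape
    return [batch, out_channels,
            conv2d_out(height, kernel, padding=padding),
            conv2d_out(width, kernel, padding=padding)]


def flatten_from(shape: List[Dim], start_dim: int) -> List[Dim]:
    """Collapse every dim from start_dim onward; those dims must be ints."""
    return list(shape[:start_dim]) + [prod(shape[start_dim:])]


def linear_shape(shape: List[Dim], in_features: int,
                 out_features: int) -> List[Dim]:
    """Replace the last dim, after checking it equals in_features."""
    if shape[-1] != in_features:
        raise ValueError(f"expected last dim {in_features}, got {shape[-1]}")
    return list(shape[:-1]) + [out_features]


x = ["B", 3, 32, 32]
y = conv2d_shape(x, out_channels=8, kernel=3, padding=1)
z = flatten_from(y, 1)
out = linear_shape(z, 8 * 32 * 32, 10)
print(y, z, out)
```

Because the symbolic "B" is only ever carried along and never multiplied, the sketch needs no symbolic algebra; a real analyzer would also have to handle products that include symbols.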
Philosophy
Like Pydantic for data validation, TorchShapeFlow is annotation-first:
you declare shape contracts on function parameters, and the analyzer verifies
consistency. Without annotations, there is nothing to check — and that is by
design. You opt in where it matters, starting with forward, and extend
coverage incrementally. Symbolic dimensions ("B", "T", "D") are the
primary mechanism; the analyzer verifies that operations are consistent without
needing concrete sizes.
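To illustrate how symbolic dimensions can be checked without concrete sizes, here is a toy sketch in plain Python. It is an assumption-laden illustration rather than the analyzer's actual logic: `ShapeMismatch` and `matmul_shape` are hypothetical names, and the sketch treats two different symbols as a mismatch, which is a simplification.

```python
from typing import List, Union

Dim = Union[int, str]  # a concrete size, or a symbolic name such as "B"


class ShapeMismatch(Exception):
    """Raised when two dims that must agree do not (hypothetical name)."""


def matmul_shape(lhs: List[Dim], rhs: List[Dim]) -> List[Dim]:
    """Result shape of lhs @ rhs for lhs [..., n, k] and a 2-D rhs [k, m]."""
    if len(lhs) < 2 or len(rhs) != 2:
        raise ShapeMismatch(f"unsupported ranks: {lhs} @ {rhs}")
    # Symbols are compared by name, ints by value; no concrete size is needed.
    if lhs[-1] != rhs[0]:
        raise ShapeMismatch(f"inner dims differ: {lhs[-1]!r} vs {rhs[0]!r}")
    return list(lhs[:-1]) + [rhs[1]]


tokens = ["B", "T", "D"]
print(matmul_shape(tokens, ["D", 64]))  # inner symbols agree

try:
    matmul_shape(tokens, ["E", 64])  # "D" vs "E": inconsistent contract
except ShapeMismatch as err:
    print("error:", err)
```

The check succeeds or fails purely on whether the names "D" and "E" agree, which is the sense in which operations can be verified as consistent before any tensor exists.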
What it does
- Reads Annotated[torch.Tensor, Shape(...)] annotations from function parameters
- Propagates symbolic shapes through supported PyTorch operations
- Emits diagnostics when shapes are incompatible
- Provides hover-style shape facts for editor integration
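The first step in the list above, reading annotations without running the code, can be sketched with Python's standard `ast` module. This is a minimal illustration and not TorchShapeFlow's parser: the function name `shape_annotations` is hypothetical, it assumes Python 3.9 or later (where a subscript's slice is the expression node itself), and it assumes every `Shape(...)` argument is a literal.

```python
import ast

SOURCE = '''
def forward(self, x: Annotated[torch.Tensor, Shape("B", 3, 32, 32)], scale: float):
    return x * scale
'''


def shape_annotations(source: str) -> dict:
    """Map parameter name -> list of dims for Annotated[..., Shape(...)] params."""
    found = {}
    for node in ast.walk(ast.parse(source)):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        for arg in node.args.args:
            ann = arg.annotation
            # Only consider annotations of the form Annotated[<type>, <metadata>...]
            if not (isinstance(ann, ast.Subscript)
                    and isinstance(ann.value, ast.Name)
                    and ann.value.id == "Annotated"
                    and isinstance(ann.slice, ast.Tuple)):
                continue
            for meta in ann.slice.elts[1:]:
                if (isinstance(meta, ast.Call)
                        and isinstance(meta.func, ast.Name)
                        and meta.func.id == "Shape"):
                    # Each argument is a literal: a str symbol or an int size.
                    found[arg.arg] = [ast.literal_eval(d) for d in meta.args]
    return found


print(shape_annotations(SOURCE))
```

Nothing here imports `torch` or executes `forward`; the source is only parsed, so unannotated parameters such as `scale` are simply skipped.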
Getting started
- Quickstart — install and run your first check
- Annotation syntax — how to annotate your tensors
- Supported operators — what is analyzed and what shapes are inferred
For contributors
- Architecture — module map, analysis pipeline, Dim type system
- Development — make targets, CI, how to add a new operator