TorchShapeFlow

TorchShapeFlow is a static, AST-based shape analyzer for PyTorch. It reads your Python source — no execution required — infers tensor shapes through your code, and reports mismatches as structured diagnostics.

from typing import Annotated

import torch
from torch import nn
from torchshapeflow import Shape

class Net(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.linear = nn.Linear(8 * 32 * 32, 10)

    def forward(self, x: Annotated[torch.Tensor, Shape("B", 3, 32, 32)]):
        y = self.conv(x)      # inferred: [B, 8, 32, 32]
        z = y.flatten(1)      # inferred: [B, 8192]
        return self.linear(z) # inferred: [B, 10]

$ tsf check mymodel.py
mymodel.py: ok

$ tsf check broken.py
broken.py:9:9 error TSF1004 Invalid reshape.
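The TSF1004 diagnostic above flags a reshape whose element counts cannot agree. One plausible rule for such a check is sketched below. It assumes that dimensions are ints or symbol strings, that every symbolic dimension must appear unchanged on both sides, and that at most one -1 is inferred. The function name reshape_shape is hypothetical and this is not TorchShapeFlow's implementation; being conservative, it also rejects some valid reshapes, such as folding "B" and "T" into a single -1.

```python
from collections import Counter
from math import prod
from typing import Tuple, Union

# Hypothetical types for this sketch: a dim is a concrete size or a symbol like "B".
Dim = Union[int, str]
Shape = Tuple[Dim, ...]


def reshape_shape(x: Shape, target: Shape) -> Shape:
    """Check x.reshape(*target) and return the resulting shape."""
    if sum(1 for d in target if d == -1) > 1:
        raise ValueError("only one dimension may be -1")

    def split(shape: Shape):
        # Separate the symbolic factors from the product of the concrete ones.
        symbols = Counter(d for d in shape if isinstance(d, str))
        number = prod(d for d in shape if isinstance(d, int) and d != -1)
        return symbols, number

    in_syms, in_num = split(x)
    out_syms, out_num = split(target)
    # Conservative rule: every symbol must survive the reshape unchanged.
    if in_syms != out_syms:
        raise ValueError(f"cannot prove {x} and {target} have equal element counts")
    if -1 in target:
        # Infer the -1 slot from the leftover concrete element count.
        if out_num == 0 or in_num % out_num != 0:
            raise ValueError(f"invalid reshape {x} -> {target}")
        return tuple(in_num // out_num if d == -1 else d for d in target)
    if in_num != out_num:
        raise ValueError(f"invalid reshape {x} -> {target}")
    return tuple(target)


print(reshape_shape(("B", 8, 32, 32), ("B", -1)))  # ('B', 8192)
```

With this rule, reshape_shape(("B", 8, 32, 32), ("B", 100)) raises ValueError because 8 * 32 * 32 = 8192 concrete elements cannot become 100.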

Philosophy

Like Pydantic for data validation, TorchShapeFlow is annotation-first: you declare shape contracts on function parameters, and the analyzer verifies that the code respects them. Without annotations there is nothing to check, and that is by design. You opt in where it matters, starting with a module's forward method, and extend coverage incrementally. Symbolic dimensions ("B", "T", "D") are the primary mechanism: the analyzer verifies that operations are consistent without needing concrete sizes.
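To illustrate what checking consistency without concrete sizes can mean, here is a minimal sketch of a shape rule for a matrix product. It assumes a dimension is either an int or a symbol name, and that two dimensions are compatible only when they are the same int or the same symbol. The names Dim, Shape, and matmul_shape are hypothetical and not part of TorchShapeFlow's API.

```python
from typing import Tuple, Union

# Hypothetical types for this sketch: a dim is a concrete size or a symbol name.
Dim = Union[int, str]
Shape = Tuple[Dim, ...]


def matmul_shape(a: Shape, b: Shape) -> Shape:
    """Infer the shape of a @ b for an (..., n, k) @ (k, m) product."""
    if len(a) < 2 or len(b) != 2:
        raise ValueError("this sketch only handles (..., n, k) @ (k, m)")
    # The contracted dimensions must be the same integer or the same symbol name.
    if a[-1] != b[0]:
        raise ValueError(f"inner dimensions differ: {a[-1]!r} vs {b[0]!r}")
    # Batch and row dims pass through unchanged; the column dim comes from b.
    return a[:-1] + (b[1],)


print(matmul_shape(("B", "T", "D"), ("D", 64)))  # ('B', 'T', 64)
```

The symbol "D" never needs a value: it only has to match itself. Treating distinct symbols such as "D" and "E" as incompatible is a conservative design choice for this sketch.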

What it does

  • Reads Annotated[torch.Tensor, Shape(...)] annotations from function parameters
  • Propagates symbolic shapes through supported PyTorch operations
  • Emits diagnostics when shapes are incompatible
  • Provides hover-style shape facts for editor integration
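Propagation can be pictured as a chain of per-operator shape rules. The sketch below replays the Net.forward example from the top of this page, assuming simplified rules for nn.Conv2d (square kernel, dilation 1, concrete height and width), Tensor.flatten, and nn.Linear. The helper names conv2d_shape, flatten_shape, and linear_shape are hypothetical and do not describe TorchShapeFlow's internals.

```python
from math import prod
from typing import Tuple, Union

# Hypothetical types: a dim is a concrete size (32) or a symbol ("B").
Dim = Union[int, str]
Shape = Tuple[Dim, ...]


def conv2d_shape(x: Shape, in_channels: int, out_channels: int, kernel: int,
                 stride: int = 1, padding: int = 0) -> Shape:
    """nn.Conv2d rule for (N, C, H, W) input; square kernel, dilation 1."""
    n, c, h, w = x
    if c != in_channels:
        raise ValueError(f"Conv2d expects {in_channels} channels, got {c!r}")
    if not (isinstance(h, int) and isinstance(w, int)):
        raise ValueError("this sketch needs concrete H and W")
    # Output-size formula with dilation 1: floor((size + 2p - k) / s) + 1.
    h_out = (h + 2 * padding - kernel) // stride + 1
    w_out = (w + 2 * padding - kernel) // stride + 1
    return (n, out_channels, h_out, w_out)


def flatten_shape(x: Shape, start_dim: int = 0) -> Shape:
    """Tensor.flatten(start_dim) folds every dim from start_dim onward into one."""
    tail = x[start_dim:]
    if all(isinstance(d, int) for d in tail):
        return x[:start_dim] + (prod(tail),)
    # With a symbolic dim in the tail, keep the product as a symbolic expression.
    return x[:start_dim] + ("*".join(str(d) for d in tail),)


def linear_shape(x: Shape, in_features: int, out_features: int) -> Shape:
    """nn.Linear rule: the last dim must equal in_features and becomes out_features."""
    if x[-1] != in_features:
        raise ValueError(f"Linear expects last dim {in_features}, got {x[-1]!r}")
    return x[:-1] + (out_features,)


# Replay Net.forward: the symbolic batch dim "B" passes through every step.
x = ("B", 3, 32, 32)
y = conv2d_shape(x, 3, 8, 3, padding=1)
z = flatten_shape(y, 1)
out = linear_shape(z, 8 * 32 * 32, 10)
print(y, z, out)  # ('B', 8, 32, 32) ('B', 8192) ('B', 10)
```

Each rule either returns a new shape or raises, which is roughly where a static analyzer would emit a diagnostic instead. For example, linear_shape(z, 4096, 10) fails because 8192 != 4096.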

Getting started

For contributors

  • Architecture — module map, analysis pipeline, Dim type system
  • Development — make targets, CI, how to add a new operator