
Quickstart

Install

pip install torchshapeflow

Or, from source:

git clone https://github.com/Davidxswang/torchshapeflow
cd torchshapeflow
make install   # uv sync --extra dev

Annotate your tensors

TorchShapeFlow reads Annotated[torch.Tensor, Shape(...)] parameter annotations. String dimensions are symbolic; integer dimensions are constant:

from typing import Annotated
import torch
from torchshapeflow import Shape

def attention(
    q: Annotated[torch.Tensor, Shape("B", "H", "T", "D")],
    k: Annotated[torch.Tensor, Shape("B", "H", "T", "D")],
) -> torch.Tensor:
    scores = q.matmul(k.transpose(-2, -1))  # [B, H, T, T]
    return scores
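The `# [B, H, T, T]` comment follows from two shape rules. `transpose(-2, -1)` swaps the last two dims of `k`, and matmul contracts the last dim of `q` against the second-to-last dim of the transposed `k`. The toy below models only those two rules in plain Python, using strings for symbolic dims and ints for constants. It is an illustration under stated assumptions, not TorchShapeFlow's implementation. `transpose`, `matmul`, `Dim` and `Dims` are hypothetical helpers, and unlike `torch.matmul` the toy does not broadcast batch dims.

```python
from typing import Tuple, Union

Dim = Union[int, str]    # int: constant size, str: symbolic name such as "T"
Dims = Tuple[Dim, ...]


def transpose(shape: Dims, d0: int, d1: int) -> Dims:
    """Shape of x.transpose(d0, d1): swap two positions (negative indices allowed)."""
    dims = list(shape)
    dims[d0], dims[d1] = dims[d1], dims[d0]
    return tuple(dims)


def matmul(a: Dims, b: Dims) -> Dims:
    """Toy batched matmul rule for equal-rank inputs; batch dims must match exactly."""
    if len(a) != len(b) or len(a) < 2:
        raise ValueError("toy rule needs equal-rank inputs of rank >= 2")
    if a[:-2] != b[:-2]:
        raise ValueError(f"batch dims differ: {a[:-2]} vs {b[:-2]}")
    if a[-1] != b[-2]:
        # e.g. "D" vs "T": two different symbols cannot be proven equal.
        raise ValueError(f"inner dims differ: {a[-1]!r} vs {b[-2]!r}")
    return a[:-2] + (a[-2], b[-1])


q = ("B", "H", "T", "D")
k = ("B", "H", "T", "D")
scores = matmul(q, transpose(k, -2, -1))
print(scores)  # ('B', 'H', 'T', 'T')
```

Calling `matmul(q, k)` without the transpose raises in this toy, because it would have to contract `"D"` against `"T"`, two symbols it cannot prove equal.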

Run the checker

tsf check path/to/mymodel.py

For machine-readable output (used by editor integrations):

tsf check path/to/mymodel.py --json

Check an entire directory:

tsf check src/

Example output

No errors:

All clean (1 file checked)

With errors:

broken.py:9:9 error TSF1004 Invalid reshape.
broken.py:17:12 warning TSF1006 Broadcasting incompatibility.
1 error, 1 warning in 1 file (1 file checked)
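The TSF1006 line above concerns broadcasting. PyTorch broadcasts NumPy-style: shapes are aligned from the right, missing leading dims act as size 1, and each aligned pair must be equal or contain a 1. A toy version of that rule over symbolic and constant dims is sketched below. `toy_broadcast` is a hypothetical helper, and the sketch does not model how TorchShapeFlow decides between an error and a warning. It simply refuses any pair it cannot prove compatible.

```python
from typing import List, Tuple, Union

Dim = Union[int, str]    # int: constant size, str: symbolic name
Dims = Tuple[Dim, ...]


def toy_broadcast(a: Dims, b: Dims) -> Dims:
    """Broadcast two shapes, raising ValueError when a dim pair is not provably compatible."""
    result: List[Dim] = []
    # Walk both shapes from the right; a missing leading dim behaves like size 1.
    for i in range(1, max(len(a), len(b)) + 1):
        x = a[-i] if i <= len(a) else 1
        y = b[-i] if i <= len(b) else 1
        if x == y:
            result.append(x)
        elif x == 1:
            result.append(y)
        elif y == 1:
            result.append(x)
        else:
            # Two different constants, or dims we cannot prove equal
            # (e.g. "T" vs "S", or "T" vs 4): report instead of guessing.
            raise ValueError(f"cannot broadcast {x!r} with {y!r} at dim -{i}")
    return tuple(reversed(result))


print(toy_broadcast(("B", "T", 1), ("T", "D")))  # ('B', 'T', 'D')
```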

The exit code is 0 when no errors are found and 1 otherwise.
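Because the exit code is the documented contract, a Python build script can gate on it without parsing any output. A minimal sketch, assuming `tsf` is on `PATH`; `tsf_passes` is a hypothetical name.

```python
import subprocess
from typing import Callable, Sequence


def tsf_passes(
    paths: Sequence[str],
    run: Callable[[list], subprocess.CompletedProcess] = subprocess.run,
) -> bool:
    """True if `tsf check` exits 0 (no errors) for the given files or directories."""
    # Relies only on the exit code: 0 = no errors, 1 otherwise.
    # `run` is injectable so the gate can be tested without tsf installed.
    completed = run(["tsf", "check", *paths])
    return completed.returncode == 0
```

Running `tsf check src/` directly as a CI step achieves the same thing. The wrapper only helps when the check is one step inside a larger Python script, for example `if not tsf_passes(["src/"]): sys.exit(1)`.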

Use --verbose (or -v) to see per-file status for clean files:

tsf check src/ --verbose
mymodel.py: ok
utils.py: ok
All clean (2 files checked)

Try the bundled examples

tsf check examples/simple_cnn.py
tsf check examples/transformer_block.py
tsf check examples/error_cases.py --json

Workflow

TorchShapeFlow is opt-in: it only checks functions whose parameters carry Annotated[torch.Tensor, Shape(...)] annotations. A practical adoption path:

  1. Start with forward — annotate the main entry point of your module. This is where input shapes are known and where most shape bugs surface.

  2. Run tsf check — the analyzer will propagate shapes through the operations in that function and report any mismatches.

  3. Follow the warnings — diagnostics point to operations where a shape could not be verified. These are often calls to helper functions that lack annotations. Add annotations there next.

  4. Define a shape vocabulary — once you have several annotated functions, extract common shapes into a shared shapes.py using TypeAlias (see Type alias pattern). This keeps annotations short and consistent across your codebase.

  5. Annotate helper functions — adding parameter and return annotations to helpers enables cross-function shape inference. The analyzer unifies symbolic dimensions at each call site, catching mismatches that span module boundaries.
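Call-site unification in step 5 can be pictured as binding the callee's symbolic dims to whatever the caller passes, then rewriting the callee's return shape with those bindings. The sketch below is a toy under that assumption, not TorchShapeFlow's algorithm. `unify`, `substitute` and the `pool` helper named in the comment are hypothetical.

```python
from typing import Dict, Tuple, Union

Dim = Union[int, str]    # int: constant size, str: symbolic name
Dims = Tuple[Dim, ...]


def unify(param: Dims, arg: Dims, env: Dict[str, Dim]) -> Dict[str, Dim]:
    """Bind the callee's symbolic dims to the caller's dims in `env`; raise on conflict."""
    if len(param) != len(arg):
        raise ValueError(f"rank mismatch: expected {len(param)} dims, got {len(arg)}")
    for p, a in zip(param, arg):
        if isinstance(p, int):
            # Constant dims must match exactly; a symbolic arg cannot be proven equal here.
            if p != a:
                raise ValueError(f"expected {p}, got {a!r}")
        elif p in env and env[p] != a:
            raise ValueError(f"{p!r} is already {env[p]!r}, got {a!r}")
        else:
            env[p] = a
    return env


def substitute(shape: Dims, env: Dict[str, Dim]) -> Dims:
    """Rewrite the callee's return shape in terms of the caller's dims."""
    return tuple(env.get(d, d) if isinstance(d, str) else d for d in shape)


# Hypothetical helper: pool(x: Shape("B", "T", "D")) -> Shape("B", "D")
env = unify(("B", "T", "D"), ("N", 16, 64), {})
print(substitute(("B", "D"), env))  # ('N', 64)
```

Unifying both arguments of a hypothetical `combine(a: Shape("B", "D"), b: Shape("B", "D"))` against `("N", 64)` and `("N", 32)` raises, because `"D"` would have to be both 64 and 32. That is the kind of mismatch across function boundaries that step 5 describes.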

Each step is incremental — you get value from the first annotation, and coverage grows as you add more.

Next steps