Architecture
Analysis pipeline
TorchShapeFlow analyzes a target file in one statement-by-statement walk, with optional project-local indexing to resolve imported aliases and annotated helper signatures:
- Parse: `ast.parse` converts source text into an AST module (`parser.parse_source`).
- Resolve aliases and function signatures: module-level type aliases are collected from `X = Annotated[...]`, `X: TypeAlias = Annotated[...]`, and, on Python 3.12+ runtimes, `type X = Annotated[...]`. If a `ProjectIndex` is present, project-local `from ... import ...` references are resolved first so imported aliases and annotated helper signatures can be used during analysis.
- Collect module specs: `_collect_class_specs` walks class `__init__` bodies to find `nn.Linear`, `nn.Conv2d`, `nn.Embedding`, `nn.MaxPool2d`, `nn.AvgPool2d`, `nn.Sequential`, `nn.MultiheadAttention`, and passthrough module assignments, recording their constructor arguments as spec values.
- Seed shape environment: for each function (or `forward` method), annotated parameters are parsed via `parser.parse_tensor_annotation` and added to the environment `env: dict[str, Value]`.
- Propagate shapes: `_analyze_statement` walks the function body statement by statement. For each assignment, `_eval_expr` evaluates the right-hand side, dispatching to the appropriate rule function. Results are stored back into `env`.
- Emit results: diagnostics and hover facts accumulate in a `ModuleContext` and are returned as a `FileReport`.
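
The parse, seed, and propagate steps can be sketched in a few lines of standard-library Python. This is a simplified illustration, not TorchShapeFlow's implementation: `parse_shape_annotation`, `eval_expr`, and `analyze` are hypothetical stand-ins, shapes are plain tuples rather than `TensorValue` objects, and only name-to-name assignments are propagated. It assumes Python 3.9+, where `ast.Subscript.slice` holds the tuple expression directly.

```python
import ast

SOURCE = '''
def forward(x: Annotated[Tensor, Shape("B", 3, 32, 32)]):
    y = x
    z = y
    return z
'''


def parse_shape_annotation(node):
    """Return the Shape(...) arguments of Annotated[Tensor, Shape(...)], or None."""
    if not (isinstance(node, ast.Subscript)
            and isinstance(node.value, ast.Name)
            and node.value.id == "Annotated"
            and isinstance(node.slice, ast.Tuple)
            and len(node.slice.elts) == 2):
        return None
    meta = node.slice.elts[1]
    if not (isinstance(meta, ast.Call)
            and isinstance(meta.func, ast.Name)
            and meta.func.id == "Shape"):
        return None
    if not all(isinstance(arg, ast.Constant) for arg in meta.args):
        return None
    return tuple(arg.value for arg in meta.args)


def eval_expr(node, env):
    """Evaluate a right-hand side; this sketch only follows plain names."""
    if isinstance(node, ast.Name):
        return env.get(node.id)
    return None


def analyze(source):
    """Return one shape environment per top-level function."""
    module = ast.parse(source)  # Parse
    envs = {}
    for func in module.body:
        if not isinstance(func, ast.FunctionDef):
            continue
        env = {}
        for arg in func.args.args:  # Seed shape environment
            if arg.annotation is not None:
                shape = parse_shape_annotation(arg.annotation)
                if shape is not None:
                    env[arg.arg] = shape
        for stmt in func.body:  # Propagate shapes, statement by statement
            if (isinstance(stmt, ast.Assign)
                    and len(stmt.targets) == 1
                    and isinstance(stmt.targets[0], ast.Name)):
                value = eval_expr(stmt.value, env)
                if value is not None:
                    env[stmt.targets[0].id] = value
        envs[func.name] = env
    return envs


envs = analyze(SOURCE)
print(envs["forward"]["z"])
```

Because the source is only parsed and never executed, the undefined names `Annotated`, `Tensor`, and `Shape` in `SOURCE` do not matter.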
Module map
| Module | Responsibility |
|---|---|
| `model.py` | All core data types (Dim variants, TensorShape, TensorValue, TensorTupleValue, LinearSpec, Conv2dSpec, PassthroughSpec, EmbeddingSpec, Pool2dSpec, SequentialSpec, MultiheadAttentionSpec, ModuleSpec, Value). Shape arithmetic: product_dim, quotient_dim, sum_dim, broadcast_shapes, batch_matmul_shape, normalize_index. |
| `annotations.py` | Public Shape class used in Annotated[Tensor, Shape(...)]. |
| `parser.py` | Parses Annotated[Tensor, Shape(...)] annotation AST nodes into TensorValue. Raises AnnotationParseError on malformed annotations. |
| `analyzer.py` | Main AST walker. Manages the shape environment, dispatches to rule functions, emits diagnostics via ModuleContext. |
| `index.py` | Project-local alias and annotated-function indexing (ProjectIndex, FuncSig, symbolic substitution for cross-file calls). |
| `diagnostics.py` | Diagnostic dataclass and Severity type alias ("error" \| "warning"). |
| `report.py` | FileReport (list of diagnostics + hover facts per file) and HoverFact (inferred shape at a source location). |
| `cli.py` | Typer CLI. tsf check runs the analyzer and formats output. tsf version prints the package version. |
| `rules/__init__.py` | Re-exports all public inference functions. |
| `rules/shape_ops.py` | Tensor/functional shape-operator inference. See Supported Operators for the canonical user-facing inventory. |
| `rules/broadcasting.py` | infer_binary_broadcast: wraps broadcast_shapes for element-wise ops. |
| `rules/linear.py` | infer_linear for nn.Linear. |
| `rules/conv2d.py` | infer_conv2d for nn.Conv2d. |
| `rules/embedding.py` | infer_embedding for nn.Embedding. |
| `rules/pool2d.py` | infer_pool2d for nn.MaxPool2d and nn.AvgPool2d. |
| `rules/indexing.py` | infer_subscript for tensor subscript and shape-tuple indexing. |
| `rules/common.py` | Shared AST helpers: int_from_ast, qualified_name, dim_from_value, tuple_index, spatial_output_dim. |
| `utils/paths.py` | collect_python_files: recursive .py file discovery. |
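
The convolution and pooling rules both need the size of a spatial axis after a sliding window is applied. The sketch below uses the output-size formula documented for PyTorch's `nn.Conv2d` and `nn.MaxPool2d` (default floor mode). `conv_output_size` is a hypothetical name chosen for illustration; the signature of the library's own `spatial_output_dim` helper is not shown on this page and may differ, and this version handles constant sizes only.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of one spatial axis for constant (integer) inputs."""
    # Effective window extent once dilation spreads the kernel taps apart.
    window = dilation * (kernel_size - 1) + 1
    # Floor division matches the default (non-ceil) rounding mode.
    return (size + 2 * padding - window) // stride + 1


# 3x3 convolution with padding 1 preserves the spatial size.
print(conv_output_size(32, kernel_size=3, stride=1, padding=1))
# 2x2 pooling with stride 2 halves it.
print(conv_output_size(32, kernel_size=2, stride=2))
```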
Dim type hierarchy
Dim (TypeAlias)
├── ConstantDim(value: int) — a fixed integer size, e.g. 32
├── SymbolicDim(name: str) — a named unknown size, e.g. "B"
├── ExpressionDim(expr: str) — a derived expression, e.g. "4*B" or "(B*C)/4"
└── UnknownDim(token: str) — explicitly unresolvable
Shape arithmetic returns ConstantDim when all operands are constant and ExpressionDim otherwise. Expressions are stored as strings and compared structurally.
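
A minimal sketch of the four variants and of one arithmetic helper follows. The class names and fields come from the hierarchy above, but the use of frozen dataclasses, the `render` helper, and the exact string format produced by `product_dim` are assumptions made for illustration. How the real helpers treat `UnknownDim` operands is not covered here.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class ConstantDim:
    value: int


@dataclass(frozen=True)
class SymbolicDim:
    name: str


@dataclass(frozen=True)
class ExpressionDim:
    expr: str


@dataclass(frozen=True)
class UnknownDim:
    token: str


Dim = Union[ConstantDim, SymbolicDim, ExpressionDim, UnknownDim]


def render(dim: Dim) -> str:
    """Text form of a dim for embedding in a larger expression (hypothetical helper)."""
    if isinstance(dim, ConstantDim):
        return str(dim.value)
    if isinstance(dim, SymbolicDim):
        return dim.name
    if isinstance(dim, ExpressionDim):
        return f"({dim.expr})"  # parenthesize nested expressions
    return dim.token


def product_dim(left: Dim, right: Dim) -> Dim:
    """Fold constants; otherwise build a string expression."""
    if isinstance(left, ConstantDim) and isinstance(right, ConstantDim):
        return ConstantDim(left.value * right.value)
    return ExpressionDim(f"{render(left)}*{render(right)}")


print(product_dim(ConstantDim(4), ConstantDim(8)))
print(product_dim(ConstantDim(4), SymbolicDim("B")))
# Structural comparison: equal values, different strings, so not equal.
print(ExpressionDim("4*B") == ExpressionDim("B*4"))
```

One consequence of storing expressions as strings is visible in the last line: two expressions that are algebraically equal but written differently compare as different dims.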
Shape environment
The environment env: dict[str, Value] maps variable names to their inferred Value:
Value (TypeAlias)
├── TensorValue(shape: TensorShape, origin: str | None)
├── ShapeTupleValue(dims: tuple[Dim, ...]) — result of x.shape or x.size()
├── IntegerValue(value: int | None) — result of x.ndim or x.size(i)
├── TensorTupleValue(tensors: tuple[TensorValue, ...]) — result of chunk/split/MHA
│
│ ModuleSpec (TypeAlias — stored in module_specs and env)
├── LinearSpec(in_features, out_features) — nn.Linear
├── Conv2dSpec(in_channels, out_channels, kernel_size, stride, padding, dilation) — nn.Conv2d
├── PassthroughSpec() — shape-preserving modules (BatchNorm, ReLU, …)
├── EmbeddingSpec(embedding_dim) — nn.Embedding
├── Pool2dSpec(kernel_size, stride, padding, dilation) — nn.MaxPool2d / nn.AvgPool2d
├── SequentialSpec(specs: tuple[ModuleSpec, ...]) — nn.Sequential
└── MultiheadAttentionSpec(embed_dim, num_heads, batch_first) — nn.MultiheadAttention
Spec values are stored in module_specs (keyed by attribute name) when their constructor is parsed from __init__. When self.linear(x) is called, the analyzer looks up "linear" in module_specs, retrieves the spec, and calls the appropriate inference function. Module aliases (m = self.linear; m(x)) are also supported: spec values stored in env are looked up before falling through to func_sigs.
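
The lookup order described above can be illustrated with a small sketch. Everything here is simplified and hypothetical: shapes are plain tuples (integers for constant dims, strings for symbolic ones), callees are passed as strings instead of AST nodes, `resolve_spec` is an invented name, and this `infer_linear` is a stand-in that returns `None` where the real rule would report TSF1007.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LinearSpec:
    in_features: int
    out_features: int


def resolve_spec(callee, module_specs, env):
    """Find the spec for 'self.<attr>' or for a local alias such as 'm'."""
    if callee.startswith("self."):
        return module_specs.get(callee[len("self."):])
    value = env.get(callee)  # alias: the spec itself is stored in env
    return value if isinstance(value, LinearSpec) else None


def infer_linear(spec, input_shape):
    """Replace the last dim with out_features, or return None on a mismatch."""
    if not input_shape:
        return None
    last = input_shape[-1]
    if isinstance(last, int) and last != spec.in_features:
        return None  # constant last dim conflicts with in_features
    return input_shape[:-1] + (spec.out_features,)


# self.linear = nn.Linear(128, 10) recorded while scanning __init__
module_specs = {"linear": LinearSpec(128, 10)}
env = {"x": ("B", 128)}

# y = self.linear(x)
env["y"] = infer_linear(resolve_spec("self.linear", module_specs, env), env["x"])

# m = self.linear; z = m(x)
env["m"] = module_specs["linear"]
env["z"] = infer_linear(resolve_spec("m", module_specs, env), env["x"])

print(env["y"], env["z"])
```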
When an annotated function call is resolved through func_sigs, symbolic
dimensions in the callee signature are unified with the caller argument shapes
and substituted into the declared return shape. Imported helper functions are
handled the same way when they can be resolved through ProjectIndex.
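
Unification and substitution can be sketched as follows. This is an illustrative approximation rather than the code in `index.py`: shapes are plain tuples with strings for symbolic dims and integers for constants, `unify` and `substitute` are hypothetical names, and rank mismatches are simply skipped.

```python
def unify(param_shapes, arg_shapes):
    """Bind callee symbols to caller dims; return (bindings, conflicting symbols)."""
    bindings = {}
    conflicts = []
    for param_shape, arg_shape in zip(param_shapes, arg_shapes):
        if len(param_shape) != len(arg_shape):
            continue  # rank mismatch is out of scope for this sketch
        for param_dim, arg_dim in zip(param_shape, arg_shape):
            if not isinstance(param_dim, str):
                continue  # constant dims in the signature bind nothing
            if param_dim in bindings and bindings[param_dim] != arg_dim:
                conflicts.append(param_dim)  # the situation TSF1010 describes
            else:
                bindings[param_dim] = arg_dim
    return bindings, conflicts


def substitute(shape, bindings):
    """Rewrite a declared return shape in terms of the caller's dims."""
    return tuple(bindings.get(d, d) if isinstance(d, str) else d for d in shape)


# Callee: def add(a: [B, D], b: [B, D]) -> [B, D]
params = [("B", "D"), ("B", "D")]
returns = ("B", "D")

# Caller passes two tensors of shape [N, 64].
bindings, conflicts = unify(params, [("N", 64), ("N", 64)])
print(substitute(returns, bindings), conflicts)

# Caller passes [N, 64] and [N, 32]: D is bound to two different values.
_, conflicts = unify(params, [("N", 64), ("N", 32)])
print(conflicts)
```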
Diagnostic codes
| Code | Severity | Trigger |
|---|---|---|
| `TSF1001` | error | Annotation parse error (malformed Annotated or Shape) |
| `TSF1002` | - | Reserved (not used) |
| `TSF1003` | error | Incompatible matmul / bmm shapes |
| `TSF1004` | error | Invalid reshape or flatten dimensions |
| `TSF1005` | error | Invalid cat or stack dimensions or mismatched shapes |
| `TSF1006` | error or warning | Broadcasting incompatibility (error when both dims are constant; warning when one or both are symbolic) |
| `TSF1007` | error | nn.Linear, nn.Conv2d, or nn.MaxPool2d/AvgPool2d input shape mismatch |
| `TSF1008` | error | Invalid permute, transpose, squeeze, unsqueeze, chunk, or movedim dimensions |
| `TSF1009` | error | Return shape does not match the declared return type annotation |
| `TSF1010` | error | Symbolic dim bound to conflicting values across call-site arguments |
| `TSF2001` | warning | Unsupported tensor method or unresolvable method arguments; shape inference is lost |
| `TSF2002` | warning | Call to unannotated function with tensor arg; shape inference is lost |
| `TSF2003` | warning | Unresolvable module self.xxx; no spec inferred |
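
The split severity of TSF1006 can be made concrete with a short sketch. The `Severity` alias mirrors the one described for `diagnostics.py`, but the fields of this `Diagnostic` class and the `check_broadcast_dim` helper are assumptions for illustration; the real dataclass presumably also carries source-location information. Dims are plain integers (constant) or strings (symbolic).

```python
from dataclasses import dataclass
from typing import Literal, Optional

Severity = Literal["error", "warning"]


@dataclass(frozen=True)
class Diagnostic:
    code: str
    severity: Severity
    message: str


def check_broadcast_dim(left, right) -> Optional[Diagnostic]:
    """Compare one aligned pair of dims under broadcasting rules."""
    if left == right or left == 1 or right == 1:
        return None  # compatible: equal sizes, or one side stretches
    both_constant = isinstance(left, int) and isinstance(right, int)
    # Two different constants can never match; a symbol still might at runtime.
    severity: Severity = "error" if both_constant else "warning"
    return Diagnostic("TSF1006", severity, f"cannot broadcast {left} with {right}")


print(check_broadcast_dim(3, 4))
print(check_broadcast_dim(3, "B"))
print(check_broadcast_dim(1, "B"))
```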
Adding a new operator
See Development → Adding a new operator. Operator behavior and support status should be documented only once in Supported Operators; this page describes the implementation structure, not the canonical support matrix.