Supported Operators
Shape notation: uppercase strings are symbolic dimensions (B, T, D), integers are constants.
Tensor shape methods
reshape / view
Input: (*dims)
Output: (*requested) — at most one element of shape may be -1 (automatically inferred).
For fully-constant shapes, the total element count must be preserved; mismatches are reported as TSF1004.
permute
Input: (*dims) (rank N)
Output: (dims[order[0]], dims[order[1]], ...) — reorders all N axes.
Negative indices supported.
transpose
Input: (*dims)
Output: (*dims) with dim0 and dim1 swapped. Negative indices supported.
flatten
Input: (*dims)
Output: (*dims[:start], product(dims[start:end+1]), *dims[end+1:])
Example: Shape("B", 8, 32, 32) with flatten(1) → [B, 8192]
squeeze
Returns None if the specified dim is not size 1.
unsqueeze
Input: (*dims) (rank N)
Output: (*dims) with a new size-1 axis inserted at position dim.
Valid range: [-N-1, N]. Negative indices count from the end of the output tensor (-1 appends).
matmul / bmm
Input: (*, M, K) and (*, K, N) (both rank ≥ 2)
Output: (*, M, N) — batch dimensions are broadcast.
Emits TSF1003 on inner-dimension mismatch.
mm
Input: (M, K) and (K, N) — both must be rank 2 (no batch dimensions).
Output: (M, N).
Emits TSF1003 if inner dimensions do not match.
movedim
Input: (*dims) (rank N)
Output: (*dims) with the specified axes moved to new positions.
source and destination may be a single integer or a tuple of integers; negative indices are supported.
size
x.size() # → ShapeTupleValue (all dims)
x.size(dim) # → IntegerValue for that axis (constant dims only; symbolic → None)
Reduction methods
x.sum() # → scalar (rank 0)
x.sum(dim=1) # → removes axis 1
x.sum(dim=1, keepdim=True) # → keeps axis 1 as size 1
x.mean(dim=(1, 2)) # → removes axes 1 and 2 (tuple dim)
torch.amax(x, dim=1) # functional form — same semantics
Supported methods: sum, mean, max, min, amax, amin, prod, all, any, argmax, argmin, nanmean, nansum.
Available as both tensor methods (x.sum(...)) and torch.* functions (torch.sum(x, ...)).
dim may be an integer or a tuple of integers; negative indices are supported.
When no dim is given, the result is a rank-0 scalar tensor.
Arithmetic and broadcasting
Element-wise ops
Input: any two tensors with broadcast-compatible shapes (NumPy/PyTorch rules).
Output: broadcast result shape. Incompatible shapes emit TSF1006.
Augmented assignment
Treated as x = x <op> y; the result shape follows broadcast rules and updates the variable in the environment.
Useful for residual connections (x += residual).
torch.einsum
Supports explicit-mode subscripts (those containing ->) with single-character labels.
Ellipsis and implicit mode (no ->) are not yet supported.
# Batched matmul: (B, T, D) @ (B, D, T) → (B, T, T)
out = torch.einsum("bik,bkj->bij", q, k)
# Matrix-vector product: (M, K) @ (K,) → (M,)
out = torch.einsum("ij,j->i", A, v)
# Outer product: (M,) ⊗ (N,) → (M, N)
out = torch.einsum("i,j->ij", u, v)
Emits TSF1003 if contracted dimensions have mismatched sizes (constant or symbolic).
Sequence operations
torch.cat
All tensors must have the same rank and matching sizes on all axes except dim.
Output: the concatenated axis size is the sum of input sizes on that axis.
torch.stack
All tensors must have identical shapes.
Output: a new axis of size len(tensors) is inserted at position dim. Result rank = input rank + 1.
chunk
Input: (*dims)
Output: a tuple of n tensors.
- Constant dim, evenly divisible: each chunk has
dims[dim] // non the split axis. - Constant dim, not evenly divisible: first
n-1chunks haveceil(dims[dim] / n), last chunk has the remainder. - Symbolic dim: each chunk has
dims[dim]//nas an expression.
Supports tuple-unpacking: a, b, c = x.chunk(3, dim=-1).
split
split_size may be:
- An
int— splits into equal-ish chunks of that size (requires the axis to be a constant). - A
list[int]— splits into exactly those sizes. When the axis is a constant, the section sizes must sum to it.
Output: a tuple of TensorValue objects, one per chunk.
Supports tuple-unpacking: q, k, v = x.split(64, dim=-1).
Tensor expansion
expand / expand_as
expand broadcasts singleton dimensions to the given sizes; -1 keeps the original dimension unchanged.
expand_as expands to match the shape of other.
Input: (*dims)
Output: (*sizes) — leading dimensions may be added.
repeat
Repeats the tensor along each dimension (copies data, unlike expand).
Input: (*dims) (rank N)
Output: each dimension i is multiplied by repeats[i]. If len(repeats) > N, leading dimensions of 1 are prepended first.
Shape-preserving tensor methods
The following methods preserve the input shape exactly (dtype/device casts, memory layout, gradient management, in-place fill):
x.contiguous()
x.float() x.half() x.double()
x.int() x.long() x.short() x.byte() x.bool()
x.to(dtype_or_device)
x.detach()
x.clone()
x.cpu() x.cuda()
x.type(dtype)
x.masked_fill(mask, value)
x.fill_(value) x.zero_() x.normal_() x.uniform_()
x.requires_grad_(...)
Functional passthrough
The following torch.* and torch.nn.functional.* calls return the same shape as their first argument:
Activations: relu, relu_, leaky_relu, gelu, silu, sigmoid, tanh, elu, selu, mish, hardswish
Normalisation: layer_norm, batch_norm, group_norm, instance_norm, normalize
Attention / masking: softmax, log_softmax, triu, tril
Regularisation: dropout, dropout2d, dropout3d
Element-wise predicates / unary: flip, isfinite, isinf, isnan, abs, neg, sign
y = F.relu(x) # [B, T, D] → [B, T, D]
y = torch.softmax(x, dim=-1)
y = F.layer_norm(x, x.shape[-1:])
y = torch.triu(x)
F.scaled_dot_product_attention
Output shape equals query's shape.
Tensor constructors
Size-based constructors
torch.zeros(B, T, D)
torch.zeros((B, T, D)) # single tuple arg
torch.ones(size=(B, T)) # keyword size=
torch.empty(B, T, D)
torch.randn(B, T, D)
torch.rand(B, T, D)
torch.full((B, T), fill_value)
Size arguments can be integer constants (→ ConstantDim) or variable names from the environment (→ UnknownDim).
*_like constructors
torch.zeros_like(x)
torch.ones_like(x)
torch.empty_like(x)
torch.randn_like(x)
torch.rand_like(x)
torch.full_like(x, fill_value)
Output shape equals x's shape.
torch.arange
Output: rank-1 tensor. When all arguments are integer constants, the exact length is computed. Otherwise the dimension is unknown.
F.one_hot
Input: (*dims) — integer index tensor of any rank.
Output: (*dims, N) — the num_classes size is appended as a new trailing axis.
Spatial interpolation
F.interpolate
Batch and channel dimensions (first two) are always preserved. Only spatial dimensions (all dims beyond the first two) are resized.
Input: (N, C, *spatial) — rank ≥ 3.
Output: (N, C, *new_spatial)
sizeas a tuple:(H_out, W_out)— each spatial dim is replaced by the given constant.sizeas a variable (e.g.labels.shape[-2:]): evaluated at analysis time when possible.scale_factoras a float or tuple: each constant spatial dim is multiplied. Symbolic dims with integer scale factors produce expressions (e.g.2*H); non-integer factors produce?.
When neither size nor scale_factor can be resolved, no hover is emitted (silent pass).
y = F.interpolate(x, size=(64, 64), mode="bilinear") # [B, C, H, W] → [B, C, 64, 64]
y = F.interpolate(x, scale_factor=2.0) # [B, C, 16, 16] → [B, C, 32, 32]
y = F.interpolate(x, size=labels.shape[-2:]) # size from another tensor's shape
Advanced indexing and selection
x.diagonal / torch.diagonal
Removes dim1 and dim2 from the shape and appends the diagonal length.
When both dimensions are constants, diagonal length = max(0, min(d1, d2) - |offset|).
When both dimensions are the same symbolic dim and offset=0, the diagonal length equals that dim.
Otherwise, the diagonal length is ?.
Input: (*dims) — rank ≥ 2.
Output: (*remaining, diag_len)
x.index_select / torch.index_select
Replaces dims[dim] with the number of elements in index (its first dimension).
When index is a known-shape 1-D tensor, the exact count is tracked; otherwise ?.
Input: (*dims), index: (K,)
Output: (*dims) with dims[dim] replaced by K.
torch.topk
Both the values and indices output tensors have the same shape: the selected dimension becomes k.
Accessing .values or .indices on the result is handled — the shape is preserved.
Input: (*dims)
Output: (*dims[:dim], k, *dims[dim+1:])
torch.bincount
Returns a 1-D tensor whose length depends on the maximum value in input — not statically determinable.
Shape is reported as [?].
Indexing and slicing
x[0] # integer index — removes that dimension
x[1:5] # slice with constant bounds — size tracked (5-1 = 4)
x[1:] # open-ended slice — original dimension preserved
x[None] # newaxis — inserts a size-1 dimension
x[...] # ellipsis — passes through all remaining dimensions
x[0, :, None] # combinations of the above
Tensor attributes
x.shape # → ShapeTupleValue of all dims
x.ndim # → IntegerValue(rank)
x.shape[i] # → the i-th Dim (supports negative indices)
nn modules
Shape-preserving modules
The following module types pass the input shape through unchanged:
BatchNorm1d, BatchNorm2d, BatchNorm3d, LayerNorm, Dropout, Dropout2d, Dropout3d,
ReLU, LeakyReLU, GELU, SiLU, Sigmoid, Tanh, ELU, SELU, PReLU, Mish,
Hardswish, Hardsigmoid, Identity, Softmax
class Net(nn.Module):
def __init__(self):
self.bn = nn.BatchNorm2d(64)
self.act = nn.GELU()
self.drop = nn.Dropout(0.1)
def forward(self, x: Annotated[torch.Tensor, Shape("B", 64, "H", "W")]):
y = self.bn(x) # [B, 64, H, W]
z = self.act(y) # [B, 64, H, W]
w = self.drop(z) # [B, 64, H, W]
Module aliases are fully supported — act = self.act; y = act(x) works identically.
nn.Embedding
Input: (*indices) — any rank.
Output: (*indices, embedding_dim) — the embedding dimension is appended as a new trailing axis.
class Net(nn.Module):
def __init__(self):
self.emb = nn.Embedding(10000, 512)
def forward(self, x: Annotated[torch.Tensor, Shape("B", "T")]):
y = self.emb(x) # [B, T, 512]
nn.Linear
Input: (..., in_features)
Output: (..., out_features) — all leading dimensions are preserved.
If the last input dimension is a ConstantDim, it is validated against in_features; a mismatch emits TSF1007. A symbolic last dim skips the check and still propagates out_features.
nn.Conv2d
Input: (N, C_in, H, W) — must be rank 4.
Output: (N, C_out, H_out, W_out)
Formula: H_out = floor((H + 2·padding − dilation·(kernel−1) − 1) / stride + 1)
If the channel dimension is a ConstantDim, it is validated against in_channels; a mismatch emits TSF1007. A symbolic channel dim skips the check and still propagates the output shape.
nn.MaxPool2d / nn.AvgPool2d
nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1)
nn.AvgPool2d(kernel_size, stride=None, padding=0)
Input: (N, C, H, W) — must be rank 4.
Output: (N, C, H_out, W_out) — N and C are preserved.
Formula: H_out = floor((H + 2·padding − dilation·(kernel−1) − 1) / stride + 1)
When stride is omitted, PyTorch defaults it to kernel_size (both layers).
nn.AvgPool2d has no dilation parameter; it is implicitly (1, 1).
class Net(nn.Module):
def __init__(self):
self.pool = nn.MaxPool2d(2) # stride defaults to 2
def forward(self, x: Annotated[torch.Tensor, Shape("B", "C", 32, 32)]):
y = self.pool(x) # [B, C, 16, 16]
nn.Sequential
Shape is propagated through each sub-module in order. Supported sub-module types are the same as those recognized directly (Linear, Conv2d, MaxPool2d, AvgPool2d, Embedding, passthrough activations/norms).
class Net(nn.Module):
def __init__(self):
self.net = nn.Sequential(
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 16),
)
def forward(self, x: Annotated[torch.Tensor, Shape("B", 128)]):
y = self.net(x) # [B, 16]
nn.MultiheadAttention
When called with (query, key, value), returns a tuple (output, attn_weights):
output: same shape asquery.attn_weights: shape is not tracked statically (returned as[?, ?, ?]).
Tuple unpacking is supported:
class Net(nn.Module):
def __init__(self):
self.attn = nn.MultiheadAttention(64, 8, batch_first=True)
self.proj = nn.Linear(64, 32)
def forward(self, x: Annotated[torch.Tensor, Shape("B", "T", 64)]):
out, _ = self.attn(x, x, x) # out: [B, T, 64]
y = self.proj(out) # [B, T, 32]
Return shape validation
When a function's return type is annotated with Shape(...), the inferred shape of the return expression is compared against the declared shape. A mismatch raises TSF1009.
Only definite mismatches are reported (rank difference, or a constant-vs-constant dimension pair that differs). Symbolic dimensions are never flagged.
def fn(x: Annotated[torch.Tensor, Shape("B", 128)]) -> Annotated[torch.Tensor, Shape("B", 64)]:
return x # TSF1009: Return shape [B, 128] does not match declared [B, 64]
Near-term roadmap
Operator additions are driven by gaps in real PyTorch model code. Next likely additions:
torch.einsumwith ellipsis (...) and implicit modenn.MultiheadAttentionattention weight shape trackingnn.SequentialwithOrderedDictargument form
Every new operator requires tests before it ships. See Development.