Import Guide

Convert T1C-IR graphs back to executable PyTorch modules using t1c.bridge.

Basic Import

```python
from t1c import ir, bridge

# Load from file
graph = ir.read('model.t1c')

# Convert to PyTorch module
executor = bridge.ir_to_torch(graph)

# Run inference (input_tensor must match the graph's input shape)
output = executor(input_tensor)
```

Stateful Execution

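For models with LIF neurons, pass return_state=True. The executor then takes the previous state dict as a second argument and returns (output, new_state):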

```python
executor = bridge.ir_to_torch(graph, return_state=True)

# Initialize state (empty before the first timestep)
state = {}

# Run timesteps, feeding each step's state into the next
for t in range(num_timesteps):
    output, state = executor(input_tensor, state)
    print(f"t={t}: output={output.sum().item():.2f}")
```

The state dictionary maps node names to their internal states:

```python
print(state.keys())  # dict_keys(['lif1', 'lif2'])
```
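A minimal, framework-free sketch of this contract follows, assuming scalar signals, v_leak = 0, r = 1 and beta = 0.8. `ToyExecutor` is a hypothetical stand-in for the imported executor, not a t1c.bridge class: it keeps one membrane value per LIF node name, treats a missing key as a resting membrane, and returns a fresh state dict each step.

```python
class ToyExecutor:
    """Hypothetical stand-in for a stateful executor: two chained LIF nodes on floats.

    Assumes v_leak = 0 and r = 1, so the drive is simply (1 - beta) * input.
    """

    def __init__(self, beta=0.8, v_threshold=1.0):
        self.beta = beta
        self.v_threshold = v_threshold
        self.lif_names = ["lif1", "lif2"]  # state keys are node names

    def __call__(self, x, state):
        new_state = dict(state)  # do not mutate the caller's dict
        signal = x
        for name in self.lif_names:
            mem = state.get(name, 0.0)  # missing key -> membrane at rest
            mem = self.beta * mem + (1 - self.beta) * signal
            spk = 1.0 if mem >= self.v_threshold else 0.0
            mem = mem * (1 - spk)  # reset to zero after a spike
            new_state[name] = mem
            signal = spk  # spikes feed the next LIF node
        return signal, new_state


executor = ToyExecutor()
state = {}
for _ in range(4):
    output, state = executor(2.0, state)
print(sorted(state.keys()))  # ['lif1', 'lif2']
```

After four steps with a constant input of 2.0, lif1 has fired once and been reset to 0, and its spike has pushed lif2 to 0.2; the state keys are the LIF node names.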

Using snnTorch Wrapper

If you use the snnTorch fork, its import wrapper calls t1c.ir and t1c.bridge under the hood and accepts a file path directly:

```python
from snntorch.import_t1cir import import_from_ir

executor = import_from_ir('model.t1c', return_state=True)
```

GraphExecutor

The imported module is a GraphExecutor instance:

```python
from t1c import bridge
from t1c.bridge import GraphExecutor

executor = bridge.ir_to_torch(graph, return_state=True)
print(type(executor))  # <class 't1c.bridge.executor.GraphExecutor'>
print(isinstance(executor, GraphExecutor))  # True

# Access submodules
for name, module in executor.named_children():
    print(f"{name}: {type(module).__name__}")
```

Output:

```text
input: Identity
fc1: Linear
lif1: LIFModule
fc2: Linear
lif2: LIFModule
output: Identity
```

LIFModule

The LIF neuron is imported as LIFModule:

```python
from t1c.bridge import LIFModule

# Access parameters
lif = executor.lif1
print(f"tau: {lif.tau}")
print(f"beta: {lif.beta}")  # Computed: 1 - 1/tau
print(f"threshold: {lif.v_threshold}")
```
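As a worked example of the relationship in the comment above (which assumes a unit timestep): tau = 5 gives beta = 1 - 1/5 = 0.8, so the membrane keeps 80% of its value from one step to the next. The helpers below are a plain-Python sketch, not t1c.bridge functions; the inverse, tau = 1 / (1 - beta), is handy when checking exported parameters.

```python
def beta_from_tau(tau: float) -> float:
    """Per-step decay for a unit-timestep LIF (sketch): beta = 1 - 1/tau."""
    if tau < 1.0:
        raise ValueError("tau must be >= 1 timestep so that beta stays in [0, 1)")
    return 1.0 - 1.0 / tau


def tau_from_beta(beta: float) -> float:
    """Inverse mapping (sketch): tau = 1 / (1 - beta)."""
    if not 0.0 <= beta < 1.0:
        raise ValueError("beta must be in [0, 1)")
    return 1.0 / (1.0 - beta)


for tau in (2.0, 5.0, 10.0):
    beta = beta_from_tau(tau)
    print(f"tau={tau}: beta={beta:.2f}, recovered tau={tau_from_beta(beta):.2f}")
```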

Forward Signature

```python
from typing import Optional

import torch

# LIFModule.forward (body omitted)
def forward(self, x: torch.Tensor, mem: Optional[torch.Tensor] = None) -> tuple:
    """
    Args:
        x: Input current
        mem: Previous membrane potential (optional)

    Returns:
        (spikes, membrane_potential)
    """
```

Dynamics

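LIFModule implements the NIR-compliant discretized LIF update. With a unit timestep, beta = 1 - 1/tau is the per-step decay factor; v_leak, r, and v_threshold are the LIF node's leak potential, resistance, and firing threshold:

```python
# Excerpt from LIFModule.forward (x and mem are tensors)
# Update membrane: leaky integration toward v_leak + r * x
mem = beta * mem + (1 - beta) * (v_leak + r * x)

# Generate spikes where the membrane reaches threshold
spk = (mem >= v_threshold).float()

# Reset: zero the membrane wherever a spike occurred
mem = mem * (1 - spk)

return spk, mem
```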
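To see the three steps interact, the sketch below runs the same update on Python floats with a constant input. It assumes scalars, tau = 5 (beta = 0.8), v_leak = 0, r = 1 and v_threshold = 1.0; `lif_step` is a hypothetical helper, and LIFModule itself works on tensors.

```python
def lif_step(x, mem, beta=0.8, v_leak=0.0, r=1.0, v_threshold=1.0):
    """One discretized LIF update on floats (hypothetical helper mirroring the excerpt)."""
    mem = beta * mem + (1 - beta) * (v_leak + r * x)  # leaky integration
    spk = 1.0 if mem >= v_threshold else 0.0          # threshold crossing
    mem = mem * (1 - spk)                             # reset to zero on spike
    return spk, mem


mem = 0.0
spikes = []
for _ in range(10):
    spk, mem = lif_step(2.0, mem)
    spikes.append(int(spk))
print(spikes)  # [0, 0, 0, 1, 0, 0, 0, 1, 0, 0]
```

The membrane approaches v_leak + r * x = 2.0 geometrically (0.4, 0.72, 0.976, 1.18), crosses the threshold on the fourth step, resets to zero, and repeats. Starting from rest, any constant input for which v_leak + r * x stays below v_threshold never produces a spike.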

Custom Node Mapping

Override default import for specific node types:

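```python
import torch.nn as nn
from t1c import ir, bridge

def my_lif_converter(node: ir.LIF) -> nn.Module:
    """Custom LIF that uses snnTorch Leaky instead."""
    import snntorch as snn
    # node.tau is array-like; using the first entry assumes one shared tau
    beta = 1.0 - 1.0 / node.tau[0]
    return snn.Leaky(beta=beta)

executor = bridge.ir_to_torch(
    graph,
    node_map=my_lif_converter,
    return_state=True,
)
```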

Low-Level Import

For more control, use the module-level functions:

```python
from t1c.bridge import from_ir, GraphExecutor

# Get dict of modules keyed by node name
modules = from_ir(graph)
print(modules.keys())  # dict_keys(['input', 'fc1', 'lif1', ...])

# Build executor manually
executor = GraphExecutor(
    modules=modules,
    edges=graph.edges,
    return_state=True,
)
```
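The internals of GraphExecutor are not covered here, but the modules-plus-edges shape suggests the general idea: visit nodes in a topological order and feed each one the outputs of its predecessors. The sketch below illustrates that idea with plain callables and Kahn's algorithm; `topological_order` and `run_graph` are hypothetical helpers, not the t1c.bridge implementation, and nodes with no incoming edge simply receive the graph input.

```python
from collections import defaultdict, deque


def topological_order(nodes, edges):
    """Kahn's algorithm: every edge (src, dst) puts src before dst."""
    indegree = {name: 0 for name in nodes}
    successors = defaultdict(list)
    for src, dst in edges:
        successors[src].append(dst)
        indegree[dst] += 1
    ready = deque(name for name in nodes if indegree[name] == 0)
    order = []
    while ready:
        name = ready.popleft()
        order.append(name)
        for nxt in successors[name]:
            indegree[nxt] -= 1
            if indegree[nxt] == 0:
                ready.append(nxt)
    if len(order) != len(nodes):
        raise ValueError("graph contains a cycle")
    return order


def run_graph(modules, edges, x):
    """Hypothetical executor: each node gets its predecessors' outputs as arguments."""
    predecessors = defaultdict(list)
    for src, dst in edges:
        predecessors[dst].append(src)
    outputs = {}
    for name in topological_order(list(modules), edges):
        # Source nodes (no incoming edges) receive the graph input x
        inputs = [outputs[p] for p in predecessors[name]] or [x]
        outputs[name] = modules[name](*inputs)
    return outputs


# Plain callables standing in for nn.Modules
modules = {
    "input": lambda v: v,
    "fc1": lambda v: 2 * v,
    "fc2": lambda v: v + 1,
    "output": lambda v: v,
}
edges = [("input", "fc1"), ("fc1", "fc2"), ("fc2", "output")]
print(topological_order(list(modules), edges))  # ['input', 'fc1', 'fc2', 'output']
print(run_graph(modules, edges, 3)["output"])   # 7
```

Because `run_graph` passes every predecessor output as a separate argument, a node with several incoming edges (such as a Skip node) receives all of them at once, which is where a merge rule would apply.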

Node-by-Node Conversion

Convert individual nodes:

```python
from t1c.bridge import node_to_module

# Convert single node
lif_node = graph.nodes['lif1']
lif_module = node_to_module(lif_node)
print(type(lif_module))  # <class 't1c.bridge.from_ir.LIFModule'>
```

Supported Nodes

| T1C-IR | PyTorch | Notes |
|---|---|---|
| Input | nn.Identity | Graph marker |
| Output | nn.Identity | Graph marker |
| Affine | nn.Linear | |
| SpikingAffine | nn.Linear | Quantization hints in metadata |
| Conv2d | nn.Conv2d | |
| Flatten | nn.Flatten | |
| MaxPool2d | nn.MaxPool2d | |
| AvgPool2d | nn.AvgPool2d | |
| Upsample | nn.Upsample | Spatial upsampling (FPN necks) |
| SepConv2d | SepConv2dModule | Preserves identity for round-trip |
| Skip | SkipModule | Preserves skip_type |
| LIF | LIFModule | Stateful neuron |

SepConv2dModule

A depthwise separable convolution that keeps its T1C-IR identity, so it can be exported back as a single SepConv2d node on round-trip:

```python
from t1c.bridge import SepConv2dModule

# Structure: depthwise (groups=in_ch) -> pointwise (1x1)
sepconv = executor.sepconv1
print(f"Depthwise: {sepconv.depthwise}")
print(f"Pointwise: {sepconv.pointwise}")
print(f"Stride: {sepconv.stride}")
```

The module stores all configuration needed for export back to T1C-IR:

  • depthwise: nn.Conv2d with groups=in_channels
  • pointwise: nn.Conv2d with 1x1 kernel
  • stride, padding, dilation: Applied to depthwise conv
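The point of the split is parameter cost. Ignoring biases, and assuming a depthwise channel multiplier of 1 (groups=in_channels, one k x k filter per input channel), a standard k x k convolution has in_ch * out_ch * k * k weights while the separable pair has in_ch * k * k + in_ch * out_ch. The arithmetic below is a sketch; the 64 -> 128 channel, 3x3 example sizes are arbitrary.

```python
def conv_weights(in_ch, out_ch, k):
    """Weight count of a standard k x k Conv2d (bias ignored)."""
    return in_ch * out_ch * k * k


def sepconv_weights(in_ch, out_ch, k):
    """Depthwise (assumed multiplier 1: one k x k filter per channel) + 1x1 pointwise."""
    depthwise = in_ch * k * k
    pointwise = in_ch * out_ch
    return depthwise + pointwise


in_ch, out_ch, k = 64, 128, 3  # arbitrary example sizes
standard = conv_weights(in_ch, out_ch, k)
separable = sepconv_weights(in_ch, out_ch, k)
print(f"standard={standard}, separable={separable}, ratio={standard / separable:.1f}")
# standard=73728, separable=8768, ratio=8.4
```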

SkipModule

The Skip primitive is imported as SkipModule, which preserves the skip_type for proper merge behavior:

```python
from t1c.bridge import SkipModule

skip = SkipModule(skip_type='residual')
print(skip.skip_type)  # 'residual'
```

Skip Types and Merge Behavior

When multiple edges feed into a Skip node, the GraphExecutor performs the merge based on skip_type:

| skip_type | Merge Operation | Use Case |
|---|---|---|
| passthrough | Identity (single input) | SPPF identity branch |
| residual | Element-wise add | ResNet blocks |
| concatenate | Channel concat (dim=1) | SPP, DenseNet |
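The three rules can be written out without any framework. In the sketch below each incoming activation is a nested list of shape [batch][channels], and inputs are combined in the order they are passed; that ordering, and the `merge` helper itself, are illustrative assumptions rather than GraphExecutor's actual code, which operates on tensors.

```python
def merge(skip_type, inputs):
    """Hypothetical merge for a Skip node; each input is a [batch][channels] list."""
    if skip_type == "passthrough":
        if len(inputs) != 1:
            raise ValueError("passthrough expects exactly one input")
        return inputs[0]
    if skip_type == "residual":
        # Element-wise sum; assumes all inputs share one shape (not checked here)
        return [[sum(vals) for vals in zip(*rows)] for rows in zip(*inputs)]
    if skip_type == "concatenate":
        # Channel concat (dim=1): per sample, join channel lists in input order
        return [sum(rows, []) for rows in zip(*inputs)]
    raise ValueError(f"unknown skip_type: {skip_type!r}")


main = [[1.0, 2.0], [3.0, 4.0]]       # batch of 2, 2 channels
shortcut = [[10.0, 20.0], [30.0, 40.0]]
print(merge("residual", [main, shortcut]))     # [[11.0, 22.0], [33.0, 44.0]]
print(merge("concatenate", [main, shortcut]))  # [[1.0, 2.0, 10.0, 20.0], [3.0, 4.0, 30.0, 40.0]]
```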

Example of residual merge:

```python
import torch.nn as nn
from t1c.bridge import GraphExecutor, SkipModule

# Graph: input -> fc -> skip, input -> skip (residual connection)
modules = {
    'input': nn.Identity(),
    'fc': nn.Linear(10, 10),
    'skip': SkipModule('residual'),  # Adds inputs together
    'output': nn.Identity(),
}
edges = [
    ('input', 'fc'),
    ('fc', 'skip'),     # main path
    ('input', 'skip'),  # residual path (added to main)
    ('skip', 'output'),
]

executor = GraphExecutor(modules, edges)
# Output: fc(x) + x
```

Example of concatenate merge:

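```python
import torch.nn as nn
from t1c.bridge import GraphExecutor, SkipModule

# Assumes a 64-channel input; max pooling keeps the channel count
modules = {
    'input': nn.Identity(),
    'pool5': nn.MaxPool2d(5, stride=1, padding=2),
    'pool9': nn.MaxPool2d(9, stride=1, padding=4),
    'concat': SkipModule('concatenate'),  # Concatenates along channels
    'output': nn.Identity(),
}
edges = [
    ('input', 'pool5'),
    ('input', 'pool9'),
    ('pool5', 'concat'),   # 64 channels
    ('pool9', 'concat'),   # 64 channels
    ('concat', 'output'),  # 128 channels (concatenated)
]

executor = GraphExecutor(modules, edges)
```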
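Concatenating along dim=1 requires both branches to keep the same spatial size. PyTorch documents the pooling output size (with ceil_mode=False) as floor((in + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1, so with stride 1 a padding of (kernel - 1) / 2 leaves the size unchanged. The sketch checks this for both pools; the input height of 20 and the 64-channel input are assumptions chosen for illustration.

```python
def pool_out_size(size, kernel, stride=1, padding=0, dilation=1):
    """Output length along one spatial axis, per the PyTorch pooling formula (ceil_mode=False)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


h = 20  # assumed input height for illustration
print(pool_out_size(h, kernel=5, padding=2))  # 20
print(pool_out_size(h, kernel=9, padding=4))  # 20

in_channels = 64  # assumed, matching the comments in the example
branches = 2      # pool5 and pool9
print(in_channels * branches)  # 128
```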

Round-Trip Verification

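Verify that an export/import round trip reproduces the original model's outputs to within floating-point tolerance. The example assumes SNN is a two-layer snnTorch model (lif1, lif2) whose forward takes and returns both membrane potentials, as the call below implies: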

```python
import torch
from t1c import bridge

# Original model
model = SNN()
model.eval()

# Export
sample = torch.randn(1, 784)
graph = bridge.to_ir(model, sample)

# Import
executor = bridge.ir_to_torch(graph, return_state=True)
executor.eval()

# Compare outputs for one timestep from a fresh state
x = torch.randn(1, 784)
state = {}

with torch.no_grad():
    # Original
    mem1 = model.lif1.init_leaky()
    mem2 = model.lif2.init_leaky()
    out_orig, _, _ = model(x, mem1, mem2)

    # Imported
    out_import, state = executor(x, state)

# Should agree to within floating-point tolerance
print(f"Max diff: {(out_orig - out_import).abs().max().item()}")
assert torch.allclose(out_orig, out_import, atol=1e-5)
```
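Note that torch.allclose(out_orig, out_import, atol=1e-5) also applies its default relative tolerance (rtol=1e-5): it passes only if |out_orig - out_import| <= atol + rtol * |out_import| holds element-wise. When the assertion fails, it helps to know which elements broke the bound. `allclose_report` is a hypothetical plain-Python re-implementation of that rule on flat lists, not a t1c or torch function.

```python
def allclose_report(a, b, rtol=1e-5, atol=1e-5):
    """Indices where |a - b| > atol + rtol * |b| (the torch.allclose rule), on flat lists.

    An empty list means every element is within tolerance. As in torch.allclose,
    the relative term uses the second argument.
    """
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return [
        i for i, (x, y) in enumerate(zip(a, b))
        if abs(x - y) > atol + rtol * abs(y)
    ]


orig = [0.0, 1.0, 0.5]
imported = [0.0, 1.000004, 0.50002]
print(allclose_report(orig, imported))  # [2]
```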

Troubleshooting

State Key Errors

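If the state dict has missing or unexpected keys, list the LIF node names the executor expects:

```python
from t1c.bridge import LIFModule

# Check expected keys
print([name for name, m in executor.named_children()
       if isinstance(m, LIFModule)])
```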
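To pinpoint the mismatch, diff the keys of your state dict against those names. `diff_state_keys` is a hypothetical helper that accepts any iterable of names, so it runs here on plain lists; with the executor you would pass the list printed above. An empty dict before the first timestep is normal, so missing keys matter only after the executor has run.

```python
def diff_state_keys(state, expected_names):
    """Hypothetical helper: return (missing, unexpected) key sets for a state dict."""
    expected = set(expected_names)
    present = set(state)
    return expected - present, present - expected


# Example: a state dict carrying a stale node name
state = {"lif1": 0.0, "lif_2": 0.0}
missing, unexpected = diff_state_keys(state, ["lif1", "lif2"])
print("missing:", sorted(missing))        # missing: ['lif2']
print("unexpected:", sorted(unexpected))  # unexpected: ['lif_2']
```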

Shape Mismatch

If input shapes don't match, check the shape recorded on the graph's input node:

```python
# Check graph input shape
input_node = graph.nodes['input']
expected_shape = input_node.input_type['input']
print(f"Expected: {expected_shape}")
```

Missing Node Type

If a node type isn't supported:

```python
from t1c.bridge import node_to_module

# Returns None for unsupported types
result = node_to_module(unknown_node)
if result is None:
    print(f"Unsupported: {type(unknown_node)}")
```
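To list every unsupported node in one pass instead of checking them individually, apply the same None test across the graph. The sketch takes the converter as a parameter so it runs with stub types; `find_unsupported` and the stub classes are hypothetical. With t1c you would call it as find_unsupported(graph.nodes, node_to_module), assuming graph.nodes is a name-to-node mapping with .items(), as graph.nodes['lif1'] above suggests.

```python
def find_unsupported(nodes, convert):
    """Hypothetical helper: map node name -> type name for nodes the converter rejects."""
    return {
        name: type(node).__name__
        for name, node in nodes.items()
        if convert(node) is None
    }


# Stub node types standing in for T1C-IR nodes
class Affine:
    pass


class LIF:
    pass


class Mystery:
    pass


def stub_convert(node):
    """Stub converter standing in for node_to_module: None means unsupported."""
    return None if isinstance(node, Mystery) else object()


nodes = {"fc1": Affine(), "lif1": LIF(), "odd": Mystery()}
print(find_unsupported(nodes, stub_convert))  # {'odd': 'Mystery'}
```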