Writing a PyTorch Backend
PyTorch sees SIDRA as a native device: model.to('sidra:0').
Prerequisites
What you'll learn here
- Use PyTorch's custom backend mechanism and the PrivateUse1 registration API
- Explain dispatch (how ATen ops map to SIDRA ops)
- Summarize SIDRA implementations of MVM, softmax, activation
- Describe performance profiling and benchmarking strategies
- Compare with other frameworks (TF, JAX)
Hook: model.to('sidra:0')
PyTorch accounts for over 70% of AI framework usage; supporting it is mandatory for SIDRA.
Goal: model.to("sidra:0") → the model runs on SIDRA automatically. Like a CUDA backend.
This chapter walks through the PyTorch custom backend API and how SIDRA integrates.
Intuition: Dispatcher + Op Registration
PyTorch tensor ops go through a dispatcher:
torch.matmul(a, b)
↓
ATen dispatcher
↓ (per device)
CPU: native C impl
CUDA: cuda_matmul
SIDRA: sidra_matmul ← new!
Per-op, device-specific implementations. SIDRA routes MVM to the crossbar.
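The dispatch idea can be modeled in a few lines of plain Python. This is a conceptual sketch only: PyTorch's real dispatcher is C++ and far more involved, and every name below is illustrative, not a PyTorch or SIDRA API.

```python
# Conceptual sketch of per-device op dispatch (illustrative names,
# not PyTorch internals).
KERNELS = {}

def register(op, device):
    """Decorator that registers an implementation for (op, device)."""
    def wrap(fn):
        KERNELS[(op, device)] = fn
        return fn
    return wrap

def dispatch(op, device, *args):
    """Pick the kernel registered for (op, device); fall back to CPU."""
    fn = KERNELS.get((op, device)) or KERNELS[(op, "cpu")]
    return fn(*args)

@register("matmul", "cpu")
def cpu_matmul(a, b):
    # Naive CPU reference: a is MxK, b is KxN (lists of lists).
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

@register("matmul", "sidra")
def sidra_matmul(a, b):
    # A real backend would issue a crossbar MVM here; the sketch
    # reuses the CPU kernel so the routing itself stays visible.
    return cpu_matmul(a, b)

print(dispatch("matmul", "sidra", [[1, 2]], [[3], [4]]))  # [[11]]
```

The table keyed by (op, device) is the whole trick: adding a device means adding rows, not touching callers.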
Formalism: Custom Backend Registration
PrivateUse1 backend (PyTorch 2.0+):
# SIDRA backend init
import torch
from sidra_pytorch import SidraDevice
# Register
torch.utils.rename_privateuse1_backend("sidra")
torch._register_device_module("sidra", SidraDevice())
# Use
device = torch.device("sidra:0")
tensor = torch.randn(10, 10, device=device)  # allocated in SIDRA memory
Dispatch override:
# Custom kernel for MVM
@torch.library.impl("aten::matmul", "PrivateUse1")
def sidra_matmul(a, b):
    # Crossbar MVM via the SIDRA SDK
    return sidra.crossbar_mvm(a, b)
When PyTorch calls aten::matmul on the SIDRA backend, this function runs.
Supported ops (priority order):
- matmul, linear, conv → SIDRA crossbar.
- relu, gelu, softmax → SIDRA compute engine.
- layernorm, batchnorm → compute engine.
- add, mul, reshape → usually CPU or fallback.
Not every op needs to live on SIDRA. Just the “hotspot” ops; the rest stays on CPU.
Graph mode:
PyTorch 2.0 torch.compile optimizes the model graph:
@torch.compile(backend="sidra")
def model_fn(x):
    return model(x)
Compile time:
- Extract model graph.
- Apply SIDRA backend passes (operator fusion, quantization).
- Crossbar mapping.
- Emit optimized code.
Runtime: direct crossbar inference.
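The compile-time steps above can be sketched as a pass pipeline. This is purely illustrative: real passes operate on torch.fx graphs, and none of the function names below correspond to actual SIDRA or TorchDynamo APIs.

```python
# Illustrative compile pipeline: each pass takes and returns a "graph",
# here just a list of op names. Real passes work on torch.fx graphs.
def fuse_ops(graph):
    # Fuse adjacent matmul + add into a single fused op.
    out, i = [], 0
    while i < len(graph):
        if graph[i:i + 2] == ["matmul", "add"]:
            out.append("fused_matmul_add")
            i += 2
        else:
            out.append(graph[i])
            i += 1
    return out

def quantize(graph):
    # Tag every op as INT8 for the crossbar.
    return [f"int8_{op}" for op in graph]

def map_to_crossbar(graph):
    # Assign each op a crossbar slot, round-robin over 4 tiles.
    return [(op, i % 4) for i, op in enumerate(graph)]

def compile_for_sidra(graph):
    for pass_ in (fuse_ops, quantize, map_to_crossbar):
        graph = pass_(graph)
    return graph

plan = compile_for_sidra(["matmul", "add", "relu"])
print(plan)  # [('int8_fused_matmul_add', 0), ('int8_relu', 1)]
```

The shape is the point: each pass is graph-in, graph-out, so fusion, quantization, and mapping compose in a fixed order.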
Quantization-aware:
# INT8 quantization
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
# Move to SIDRA
quantized_model.to("sidra:0")
SIDRA is natively INT8. FP32 → INT8 quantization typically costs ~0.5% accuracy.
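Symmetric per-tensor INT8 quantization, the scheme this flow relies on, is simple arithmetic. A minimal framework-free sketch:

```python
def quantize_int8(values):
    """Symmetric per-tensor INT8: q = round(x / scale), scale = max|x| / 127."""
    scale = max(abs(v) for v in values) / 127.0
    q = [max(-127, min(127, round(v / scale))) for v in values]
    return q, scale

def dequantize(q, scale):
    return [qi * scale for qi in q]

weights = [0.51, -0.23, 0.08, -0.94]
q, scale = quantize_int8(weights)
restored = dequantize(q, scale)

# Round-trip error is bounded by half a quantization step.
max_err = max(abs(a - b) for a, b in zip(weights, restored))
assert all(-127 <= qi <= 127 for qi in q)
assert max_err <= scale / 2 + 1e-9
```

The per-tensor scale is what the driver programs alongside the crossbar conductances; the ~0.5% accuracy cost comes from exactly this rounding error accumulating across layers.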
Fallback:
An unsupported (exotic) op falls back to the CPU; the driver moves data automatically.
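A Python-level sketch of this fallback routing (all names illustrative; the real mechanism lives in the dispatcher and driver):

```python
# Sketch of automatic CPU fallback for ops with no SIDRA kernel.
SIDRA_KERNELS = {"matmul": lambda a, b: a * b}  # stand-in kernel

cpu_ops = {"exotic_op": lambda x: x + 1}        # stand-in CPU impls

def run_op(name, *args):
    """Run on SIDRA if a kernel exists, otherwise fall back to CPU."""
    kernel = SIDRA_KERNELS.get(name)
    if kernel is not None:
        return ("sidra", kernel(*args))
    # Fallback path: the driver would copy tensors to host memory,
    # run the native CPU implementation, and copy results back.
    return ("cpu", cpu_ops[name](*args))

print(run_op("matmul", 3, 4))   # ('sidra', 12)
print(run_op("exotic_op", 10))  # ('cpu', 11)
```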
# Hybrid execution
def forward(x):
    x = layer1(x)     # SIDRA
    x = custom_op(x)  # CPU fallback (automatic)
    x = layer2(x)     # SIDRA
    return x
Benchmark:
import torch
import time
model = load_model("bert-base").to("sidra:0")
x = torch.randn(1, 512, 768).to("sidra:0")
# Warmup
for _ in range(10):
    y = model(x)
torch.sidra.synchronize()
# Benchmark
t0 = time.time()
for _ in range(1000):
    y = model(x)
torch.sidra.synchronize()
t1 = time.time()
print(f"Latency: {(t1-t0)/1000*1000:.2f} ms/inference")
Typical result: BERT-base on Y1 is ~5 ms/inference; an H100 is ~2 ms. SIDRA is 2.5× slower but ~50× more energy-efficient.
Profiler:
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.SIDRA]
) as prof:
    y = model(x)
print(prof.key_averages().table())
# Output: matmul 2.1 ms, softmax 0.3 ms, ...
Useful for finding hotspots.
Training (limited on Y1):
# Fine-tune the last layer (Y1 hybrid)
for param in model.parameters():
    param.requires_grad = False
for param in model.classifier.parameters():
    param.requires_grad = True
optimizer = torch.optim.Adam(model.classifier.parameters())
for epoch in range(5):
    for batch in loader:
        y = model(batch.x.to("sidra:0"))
        loss = criterion(y, batch.y)
        loss.backward()   # Gradients computed on CPU
        optimizer.step()  # SIDRA last-layer ISPP update
Y1: last-layer training only. Y10+: full training.
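Note the freezing pattern: requires_grad must be set on each parameter, not on the module object. A sketch with a stand-in Param class (not torch) makes the distinction concrete:

```python
# Stand-in classes to illustrate per-parameter freezing (not torch).
class Param:
    def __init__(self, name):
        self.name = name
        self.requires_grad = True

class Model:
    def __init__(self):
        self.backbone = [Param("w1"), Param("w2")]
        self.classifier = [Param("head")]
    def parameters(self):
        return self.backbone + self.classifier

m = Model()
for p in m.parameters():
    p.requires_grad = False   # freeze everything
for p in m.classifier:
    p.requires_grad = True    # unfreeze only the classifier head

trainable = [p.name for p in m.parameters() if p.requires_grad]
print(trainable)  # ['head']
```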
TensorFlow + JAX:
- TensorFlow: XLA backend. SIDRA plugin. Limited, in development.
- JAX: XLA-based. SIDRA JAX backend Y10+ target.
PyTorch is the priority because it’s 70% of the market.
Experiment: PyTorch → SIDRA in 30 Seconds
# 1. Import
import torch
import sidra_pytorch
# 2. Backend register
torch.utils.rename_privateuse1_backend("sidra")
# 3. Load model
from transformers import AutoModelForSequenceClassification, AutoTokenizer
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# 4. Move to SIDRA
model = model.to("sidra:0")
# 5. Inference
inputs = tokenizer("Hello SIDRA", return_tensors="pt").to("sidra:0")
outputs = model(**inputs)
print(outputs.logits)
Standard HuggingFace code; SIDRA only adds .to("sidra:0"). Transparent.
Quick Quiz
Lab Exercise
Typical workflow: HuggingFace transformer → SIDRA inference.
from transformers import pipeline
import sidra_pytorch
# Register the backend (once)
sidra_pytorch.register_backend()
# Pass the sidra device to the pipeline
translator = pipeline(
    "translation_en_to_tr",
    model="Helsinki-NLP/opus-mt-en-tr",
    device="sidra:0"
)
result = translator("Hello, how are you?")
print(result)  # [{'translation_text': 'Merhaba, nasılsın?'}]
Five lines of code: download the model, move it to SIDRA, use it. Underneath: INT8 quantization, crossbar programming, inference.
Time: model download 1 minute. First inference 2 seconds. Each subsequent one 10 ms.
Energy: ~30 mJ/sentence (100 tokens). GPU 1 J (30× more).
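A back-of-envelope check of the energy figures, using the numbers from the text (the exact ratio is ~33×, consistent with the rounded "30× more" above):

```python
# Energy per 100-token sentence, from the figures in the text.
sidra_mj_per_sentence = 30     # ~30 mJ on SIDRA
gpu_mj_per_sentence = 1000     # ~1 J on a GPU

ratio = gpu_mj_per_sentence / sidra_mj_per_sentence
print(f"GPU uses ~{ratio:.0f}x more energy per sentence")  # ~33x
```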
Cheat Sheet
- PyTorch 2.0+ PrivateUse1 backend.
- Dispatch: ATen ops map to SIDRA kernels.
- torch.compile: graph-level optimization.
- Quantization: INT8 auto.
- Fallback: CPU for unsupported ops.
- Benchmark: Y1 BERT ~5 ms.
- Training: Y1 last layer; full Y10+.
Vision: SIDRA + PyTorch Ecosystem
- Y1: PyTorch inference standard.
- Y3: HuggingFace ✓, Stable Diffusion ✓, Whisper ✓.
- Y10: Training + PyTorch Lightning.
- Y100: Native PyTorch primitives (attention, transformer blocks).
- Y1000: PyTorch alternative — SIDRA-native framework.
For Türkiye: Turkish models on HuggingFace with the SIDRA backend → Turkish NLP community as natural users.
Further Reading
- Next chapter: 6.7 — Compiler: Model → Analog Mapping
- Previous: 6.5 — SDK Layers
- PyTorch backend: pytorch.org PrivateUse1 guide.
- torch.compile: pytorch.org 2.0 release notes.
- HuggingFace: huggingface.co transformers.