"""
********************************************************************************
* Copyright (c) 2026 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************
"""
"""
Automatic detector identification for quantum error correction circuits.

This module provides the :func:`find_detectors` decorator, which leverages
`tqecd <https://github.com/tqec/tqecd>`_ to automatically identify
stim ``DETECTOR`` instructions at Jasp tracing time and emit the
corresponding ``parity()`` calls.

Usage
-----

.. code-block:: python

    from qrisp import *
    from qrisp.jasp.evaluation_tools.stim_extraction import extract_stim
    from qrisp.misc.stim_tools.find_detectors import find_detectors

    @find_detectors
    def syndrome_round(qa):
        reset(qa[1]); reset(qa[3])
        cx(qa[0], qa[1]); cx(qa[2], qa[1])
        cx(qa[2], qa[3]); cx(qa[4], qa[3])
        return measure(qa[1]), measure(qa[3])

    @extract_stim
    def main():
        qa = QuantumArray(qtype=QuantumBool(), shape=(5,))
        detectors, m1, m3 = syndrome_round(qa)
        return detectors, m1, m3
"""
import stim
import numpy as np
import jax
from jax.tree_util import tree_flatten
from qrisp import QuantumArray, QuantumBool
from qrisp.jasp import make_jaspr
from qrisp.jasp.primitives.parity_primitive import parity
from qrisp.circuit import QuantumCircuit as QC
# ───────────────────── stim instruction categories ─────────────────────
# Reset instructions; they form the leading reset moment of a tqecd fragment.
_RESET_OPS = frozenset({"R", "RX", "RY"})
# Single-qubit measurement (and measure-and-reset) instructions; they form the
# final measurement moment of a tqecd fragment.
_MEASUREMENT_OPS = frozenset({"M", "MX", "MY", "MZ", "MR", "MRX", "MRY"})
# Two-qubit gates. One stim instruction may carry several target pairs; when
# restructuring for tqecd, each pair of these gates gets its own moment.
_TWO_QUBIT_OPS = frozenset(
    {
        "CX",
        "CY",
        "CZ",
        "XCX",
        "XCY",
        "XCZ",
        "YCX",
        "YCY",
        "YCZ",
        "SWAP",
        "ISWAP",
        "ISWAP_DAG",
        "CXSWAP",
        "SWAPCX",
        "CZSWAP",
        "SQRT_XX",
        "SQRT_XX_DAG",
        "SQRT_YY",
        "SQRT_YY_DAG",
        "SQRT_ZZ",
        "SQRT_ZZ_DAG",
    }
)
# ───────────────────── internal helpers ────────────────────────────────
def _tvals(inst):
    """Collect the ``.value`` of every target of *inst*, preserving order.

    *inst* is a stim instruction; the result is a new plain list.
    """
    values = []
    for target in inst.targets_copy():
        values.append(target.value)
    return values
def _build_analysis_args(jaspr, n_qubits, quantum_arrays):
    """Build concrete arguments for calling ``to_qc`` on *jaspr*.

    A fresh analysis circuit with *n_qubits* qubits is created. Its qubits are
    handed out in consecutive, non-overlapping chunks, one chunk per entry of
    *quantum_arrays* (in order, ``size * qtype.size`` qubits each). Static
    parameters have been bound via closure before tracing, so only
    QuantumArray-derived invars remain. Each invar except the trailing
    QuantumCircuit one is mapped to a concrete value based on the string form
    of its abstract value:

    * ``"QubitArray"`` -> the next qubit chunk,
    * integer with a non-empty shape -> ``np.arange(n_qubits)``,
    * integer or boolean scalar -> ``1``.

    Raises
    ------
    ValueError
        If an invar's abstract value matches none of the cases above.
    """
    circuit_qubits = list(QC(n_qubits).qubits)

    # One contiguous chunk of analysis qubits per QuantumArray argument.
    chunks = []
    start = 0
    for qarray in quantum_arrays:
        stop = start + qarray.size * qarray.qtype.size
        chunks.append(circuit_qubits[start:stop])
        start = stop
    chunk_iter = iter(chunks)

    concrete = []
    for invar in jaspr.invars[:-1]:  # the last invar is the QuantumCircuit
        aval_text = str(invar.aval)
        is_int = "int" in aval_text
        if aval_text == "QubitArray":
            value = next(chunk_iter)
        elif is_int and getattr(invar.aval, "shape", ()):
            value = np.arange(n_qubits)  # ind_array for qubit indexing
        elif is_int or "bool" in aval_text:
            value = 1  # qtype_size (always 1 for QuantumBool)
        else:
            raise ValueError(f"Unexpected invar: {aval_text}")
        concrete.append(value)
    return concrete
def _prepare_for_tqecd(stim_circ):
    """Re-structure *stim_circ* into the fragment format expected by tqecd.

    tqecd requires each *fragment* to follow this layout:

    1. Zero or more moments of reset operations
    2. Zero or more moments of computation gates (no resets or measurements)
    3. Exactly one moment of measurement operations

    Moments are separated by ``TICK`` instructions. Multi-round circuits are
    split at measurement boundaries so each round becomes its own fragment.
    The first fragment receives implicit ``R`` (Z-basis reset) for every qubit
    that was not explicitly reset by the user's circuit.

    Within a round, all resets are moved in front of all gates. This is only
    sound when no unitary gate earlier in the same round acts on a qubit that
    is reset afterwards. Such circuits are rejected rather than analysed in a
    silently altered form. Non-unitary instructions (e.g. noise channels) do
    not take part in this check.

    Parameters
    ----------
    stim_circ : stim.Circuit
        The circuit to restructure (``REPEAT`` blocks are flattened).

    Returns
    -------
    stim.Circuit
        The restructured circuit, starting with one ``QUBIT_COORDS`` entry
        per qubit index used in *stim_circ*.

    Raises
    ------
    ValueError
        If a reset targets a qubit that a unitary gate already acted on
        earlier in the same round.
    """
    flat = stim_circ.flattened()

    # Every qubit index addressed anywhere (measurement-record targets excluded).
    all_qubits = {
        t.value
        for inst in flat
        for t in inst.targets_copy()
        if not t.is_measurement_record_target
    }

    # Split instructions into rounds. A round ends once measurements have been
    # collected and the next (non-TICK) instruction is not a measurement.
    rounds = []
    cur = {"R": [], "G": [], "M": []}
    touched = set()  # qubits acted on by a unitary gate in the current round
    for inst in flat:
        name = inst.name
        if name == "TICK":
            continue
        if cur["M"] and name not in _MEASUREMENT_OPS:
            rounds.append(cur)
            cur = {"R": [], "G": [], "M": []}
            touched = set()
        if name in _RESET_OPS:
            # The reset will be hoisted in front of this round's gates; that
            # reorders it past any earlier gate on the same qubit.
            clash = touched.intersection(_tvals(inst))
            if clash:
                raise ValueError(
                    "find_detectors: reset on qubit(s) "
                    f"{sorted(clash)} follows a gate acting on the same "
                    "qubit(s) within one measurement round; tqecd fragments "
                    "require all resets of a round to precede its gates."
                )
            cur["R"].append(inst)
        elif name in _MEASUREMENT_OPS:
            cur["M"].append(inst)
        else:
            if stim.gate_data(name).is_unitary:
                touched.update(
                    t.value
                    for t in inst.targets_copy()
                    if not t.is_measurement_record_target
                )
            cur["G"].append(inst)
    if cur["M"]:
        rounds.append(cur)

    out = stim.Circuit()
    for q in sorted(all_qubits):
        out.append("QUBIT_COORDS", [q], [float(q), 0.0])

    last_round = len(rounds) - 1
    for ri, rnd in enumerate(rounds):
        # Reset moment. The first round also resets every qubit that the user
        # did not reset explicitly (tqecd requires it).
        if ri == 0:
            explicitly_reset = {q for inst in rnd["R"] for q in _tvals(inst)}
            unreset = sorted(all_qubits - explicitly_reset)
            if unreset:
                out.append("R", unreset)
        for inst in rnd["R"]:
            out.append(inst.name, _tvals(inst), inst.gate_args_copy())
        out.append("TICK")

        # Gate moments: one per instruction, one per target pair for 2-qubit gates.
        for inst in rnd["G"]:
            tgts, gate_args = inst.targets_copy(), inst.gate_args_copy()
            if inst.name in _TWO_QUBIT_OPS:
                for j in range(0, len(tgts), 2):
                    out.append(inst.name, [tgts[j].value, tgts[j + 1].value], gate_args)
                    out.append("TICK")
            else:
                out.append(inst.name, _tvals(inst), gate_args)
                out.append("TICK")

        # Measurement moment; TICK separates it from the next fragment only.
        for inst in rnd["M"]:
            out.append(inst.name, _tvals(inst), inst.gate_args_copy())
        if ri < last_round:
            out.append("TICK")
    return out
def _count_measurements(inst):
    """Return how many measurement results the stim instruction *inst* records."""
    probe = stim.Circuit()
    probe.append(inst)
    return probe.num_measurements


def _extract_detectors(annotated_circ, total_meas, inv_meas_map):
    """Extract detector definitions from an annotated stim circuit.

    The ``rec[-k]`` targets of a ``DETECTOR`` are relative to the detector's
    own position in the circuit (stim semantics). DETECTOR instructions need
    not sit at the end of the circuit; for example, they may follow each
    fragment's measurements. The absolute measurement index of each target is
    therefore computed from the number of measurements recorded before that
    ``DETECTOR`` instruction.

    Parameters
    ----------
    annotated_circ : stim.Circuit
        Circuit containing ``DETECTOR`` instructions (``REPEAT`` blocks are
        flattened).
    total_meas : int
        Total number of measurements of the circuit. Retained for interface
        compatibility; absolute indices come from the running count instead.
    inv_meas_map : dict
        Maps an absolute measurement index to its :class:`~qrisp.circuit.Clbit`.

    Returns
    -------
    list of list of Clbit
        One list per ``DETECTOR`` instruction, in circuit order.
    """
    detectors = []
    n_before = 0  # measurements recorded before the current instruction
    for inst in annotated_circ.flattened():
        if inst.name == "DETECTOR":
            detectors.append(
                [
                    inv_meas_map[n_before + t.value]
                    for t in inst.targets_copy()
                    if t.is_measurement_record_target
                ]
            )
        else:
            n_before += _count_measurements(inst)
    return detectors
# ───────────────────── public API ──────────────────────────────────────
def find_detectors(func=None, *, return_circuits=False):
    r"""Decorator that automatically identifies stim detectors and returns them.

    ``find_detectors`` leverages `tqecd <https://github.com/tqec/tqecd>`_
    to automatically discover detector parity checks in quantum error correction
    circuits. The detected parities are returned as an additional value alongside
    the decorated function's original return values.

    .. note::

        This feature requires the optional ``tqecd`` package. Install it
        as described `here <https://tqec.github.io/tqecd/index.html>`_.

    Pitfalls & Limitations
    ^^^^^^^^^^^^^^^^^^^^^^

    The ``find_detectors`` feature comes with some constraints:

    **QuantumArray arguments must be QuantumBool**
        All ``QuantumArray`` arguments to the decorated function must have
        ``qtype=QuantumBool()``. Other quantum types will raise a
        ``TypeError``. Static parameters (integers, booleans, etc.) are
        allowed and will be passed through to the function unchanged.

    **Only Clifford gates are supported**
        tqecd analyses stabilizer flows, which are only well-defined for
        Clifford circuits. If the decorated function contains non-Clifford
        gates (e.g. T, Toffoli, arbitrary-angle rotations), stim will raise
        an error during circuit construction or tqecd will silently produce
        incorrect results.

    **No real-time classical computation inside the decorated function**
        The circuit is traced symbolically. Conditional logic that depends
        on measurement outcomes (real-time ``if``/``else``) cannot be
        expressed in the stim circuit and will break analysis.

    **Composite detectors from flow products are not discoverable**
        tqecd finds detectors by tracking individual stabilizer flows from
        resets through gates to measurements. It **cannot** discover
        composite detectors that arise as products of multiple stabilizer
        flows. For example, in a GHZ-state circuit the parity
        :math:`Z_3 Z_4` is a product of :math:`Z_0 Z_3` and
        :math:`Z_0 Z_4`, but tqecd will not identify it.

    **Multi-round circuits require explicit resets**
        For tqecd to detect round-to-round detectors (comparing syndrome
        measurements across rounds), ancilla qubits **must** be explicitly
        reset between rounds. Without resets, tqecd cannot establish the
        start of a new stabilizer flow and will miss cross-round detectors.
        This is critical: in a two-round repetition code without resets
        between rounds, tqecd will only find first-round detectors and
        miss the temporal detectors that compare round 1 vs round 2.

    **Detectors referencing non-returned measurements are discarded**
        If a detector involves a measurement whose result is *not* part of
        the decorated function's return value, the detector is silently
        dropped. Make sure all relevant measurements are returned.
        For example, if you measure both syndrome qubits but only return
        one measurement, detectors involving the unreturned measurement
        will be filtered out.

    **Multiple QuantumArray arguments receive disjoint qubit slices**
        When the decorated function takes more than one QuantumArray
        argument, it is assumed that each argument represents a distinct
        set of qubits. This cannot be verified at tracing time and must be
        ensured by the user. Each array is mapped to its own contiguous
        slice of qubits during analysis; the slices are allocated in
        positional order.

    **May discover more detectors than expected**
        In some circuits, ``find_detectors`` may identify *more* valid
        detectors than a minimal set. For example, in 2-round codes,
        tqecd can discover boundary detectors from data qubit measurements
        that are mathematically valid but redundant with syndrome-based
        detectors. All discovered detectors are correct (they will sample
        to zero in noiseless circuits); the extra ones represent additional
        stabilizer flows through the circuit.

    Parameters
    ----------
    func : callable, optional
        The function to decorate. When using keyword arguments (e.g.
        ``@find_detectors(return_circuits=True)``), *func* is ``None`` and
        the decorator returns a wrapper that accepts *func*.

        The decorated function must accept at least one ``QuantumArray``
        argument (with ``qtype=QuantumBool()``). Additional non-QuantumArray
        arguments (integers, booleans, strings, etc.) are treated as static
        parameters: they are captured by closure before JAX tracing so they
        behave as compile-time constants inside the function body.
    return_circuits : bool, default ``False``
        When ``True``, three additional items are appended to the return
        value (after the detector list and the original returns):

        * **raw_stim** — the stim circuit obtained directly from
          ``to_qc`` → ``to_stim``, before any restructuring.
        * **tqecd_input** — the circuit after ``_prepare_for_tqecd`` has
          re-arranged it into proper fragments (the input fed to tqecd).
        * **annotated** — the circuit that comes back from
          ``tqecd.annotate_detectors_automatically``, containing the
          ``DETECTOR`` instructions.

        This is useful for debugging or understanding why detectors were
        or were not discovered.

    Returns
    -------
    callable
        The decorated function returns::

            (detector_bools, *original_returns)

        or, when ``return_circuits=True``::

            (detector_bools, *original_returns, raw_stim, tqecd_input, annotated)

        where *detector_bools* is a ``list`` of Jasp-traced boolean values, one
        per discovered detector, each representing the ``parity()`` of the
        measurements that constitute that detector. Each ``parity()`` call
        receives as ``expectation`` the detector's parity in a noiseless
        reference execution of the analysed circuit. The original return
        values from the decorated function are unpacked and appended after
        the detector list.

    Raises
    ------
    ImportError
        When ``find_detectors`` is applied and ``tqecd`` is not installed.
    TypeError
        When the decorated function is called with a non-QuantumBool
        ``QuantumArray``, without any ``QuantumArray``, or when tracing fails
        (e.g. a JAX tracer leaked in via a non-QuantumArray argument).

    Examples
    --------

    **Example 1: Single-round syndrome extraction**

    Three data qubits with two ancilla parity checks::

        from qrisp import *
        from qrisp.jasp.evaluation_tools.stim_extraction import extract_stim
        from qrisp.misc.stim_tools import find_detectors

        @find_detectors
        def syndrome(data, ancilla):
            # Reset ancilla qubits
            reset(ancilla[0])
            reset(ancilla[1])
            # Syndrome extraction
            cx(data[0], ancilla[0])
            cx(data[1], ancilla[0])
            cx(data[1], ancilla[1])
            cx(data[2], ancilla[1])
            # Measure syndromes
            return measure(ancilla[0]), measure(ancilla[1])

        @extract_stim
        def main():
            data = QuantumArray(qtype=QuantumBool(), shape=(3,))
            ancilla = QuantumArray(qtype=QuantumBool(), shape=(2,))
            detectors, m0, m1 = syndrome(data, ancilla)
            return detectors, m0, m1

        # Run and extract stim circuit
        result = main()
        stim_circ = result[-1]
        print(f"Found {stim_circ.num_detectors} detectors")

    **Example 2: Two-round repetition code**

    Distance-3 repetition code with temporal detectors comparing rounds::

        @find_detectors
        def rep_code_2rounds(data, ancilla):
            measurements = []
            for round_num in range(2):
                # Reset ancillas
                reset(ancilla[0])
                reset(ancilla[1])
                # Parity checks
                cx(data[0], ancilla[0])
                cx(data[1], ancilla[0])
                cx(data[1], ancilla[1])
                cx(data[2], ancilla[1])
                # Measure syndromes
                measurements.append(measure(ancilla[0]))
                measurements.append(measure(ancilla[1]))
            # Final data readout
            for i in range(3):
                measurements.append(measure(data[i]))
            return tuple(measurements)

        @extract_stim
        def main():
            data = QuantumArray(qtype=QuantumBool(), shape=(3,))
            ancilla = QuantumArray(qtype=QuantumBool(), shape=(2,))
            detectors, *measurements = rep_code_2rounds(data, ancilla)
            return (detectors,) + tuple(measurements)

        result = main()
        stim_circ = result[-1]
        # Will find >= 6 detectors (may discover more valid boundary detectors):
        # - 2 detectors in round 1 (vs initial reset)
        # - 2 detectors in round 2 (vs round 1)
        # - 2+ data-boundary detectors (tqecd may find additional valid ones)
        print(f"Found {stim_circ.num_detectors} detectors")

    **Example 3: Debug mode with circuit inspection**

    Use ``return_circuits=True`` to inspect intermediate circuits::

        @find_detectors(return_circuits=True)
        def debug_syndrome(data, ancilla):
            reset(ancilla[0])
            cx(data[0], ancilla[0])
            cx(data[1], ancilla[0])
            return measure(ancilla[0])

        @extract_stim
        def main():
            data = QuantumArray(qtype=QuantumBool(), shape=(2,))
            ancilla = QuantumArray(qtype=QuantumBool(), shape=(1,))
            detectors, m, raw_stim, tqecd_input, annotated = debug_syndrome(data, ancilla)
            print("Raw stim circuit:")
            print(raw_stim)
            print("\nAfter tqecd restructuring:")
            print(tqecd_input)
            print("\nWith detectors annotated:")
            print(annotated)
            return detectors, m

        main()

    **Example 4: Multi-array arguments**

    Separate data and ancilla qubits::

        @find_detectors
        def surface_check(data_qubits, ancilla_qubits):
            # Reset ancilla
            reset(ancilla_qubits[0])
            # X-stabilizer (Hadamard + CNOTs)
            h(ancilla_qubits[0])
            cx(ancilla_qubits[0], data_qubits[0])
            cx(ancilla_qubits[0], data_qubits[1])
            cx(ancilla_qubits[0], data_qubits[2])
            cx(ancilla_qubits[0], data_qubits[3])
            h(ancilla_qubits[0])
            return measure(ancilla_qubits[0])

        @extract_stim
        def main():
            data = QuantumArray(qtype=QuantumBool(), shape=(4,))
            ancilla = QuantumArray(qtype=QuantumBool(), shape=(1,))
            detectors, m = surface_check(data, ancilla)
            return detectors, m

        result = main()
        # Note: This finds 0 detectors because the X-stabilizer measurement
        # on Z-basis |0⟩ initialized qubits is not deterministic.
        # For detectors, would need X-basis initialization (RX) on data qubits.

    **Example 5: Three-round code with explicit resets**

    Shows the importance of resets between rounds::

        @find_detectors
        def three_rounds(data, ancilla):
            results = []
            for _ in range(3):
                # CRITICAL: Reset between rounds
                reset(ancilla[0])
                # Syndrome extraction
                cx(data[0], ancilla[0])
                cx(data[1], ancilla[0])
                results.append(measure(ancilla[0]))
            return tuple(results)

        @extract_stim
        def main():
            data = QuantumArray(qtype=QuantumBool(), shape=(2,))
            ancilla = QuantumArray(qtype=QuantumBool(), shape=(1,))
            detectors, m1, m2, m3 = three_rounds(data, ancilla)
            return detectors, m1, m2, m3

        result = main()
        stim_circ = result[-1]
        # Will find >= 3 detectors (one per round plus temporal)
        print(f"Found {stim_circ.num_detectors} detectors")
        # Verify all detectors are valid (noiseless samples = 0)
        sampler = stim_circ.compile_detector_sampler()
        samples = sampler.sample(shots=1000)
        assert samples.sum() == 0, "All detectors should be deterministic!"

    **Example 6: Partial measurement return**

    Only detectors involving returned measurements are kept::

        @find_detectors
        def partial_return(data, ancilla):
            reset(ancilla[0])
            reset(ancilla[1])
            cx(data[0], ancilla[0])
            cx(data[1], ancilla[0])
            cx(data[1], ancilla[1])
            cx(data[2], ancilla[1])
            m0 = measure(ancilla[0])
            m1 = measure(ancilla[1])
            # Only return m0, discard m1
            return m0

        @extract_stim
        def main():
            data = QuantumArray(qtype=QuantumBool(), shape=(3,))
            ancilla = QuantumArray(qtype=QuantumBool(), shape=(2,))
            detectors, m = partial_return(data, ancilla)
            return detectors, m

        result = main()
        stim_circ = result[-1]
        # Only detectors involving m0 (the returned measurement) are kept.
        # Detectors that depend on m1 (unreturned) are filtered out.
        # Typically finds 1 detector from the m0 measurement vs initial reset.
        print(f"Found {stim_circ.num_detectors} detectors")

    **Example 7: Static parameters (integers, booleans)**

    The decorated function can accept non-QuantumArray arguments such as
    integers or booleans. These are bound via closure before tracing so
    that JAX treats them as compile-time constants::

        @find_detectors
        def configurable_syndrome(data, ancilla, num_rounds, use_extra_gate):
            results = []
            for _ in range(num_rounds):
                reset(ancilla[0])
                cx(data[0], ancilla[0])
                cx(data[1], ancilla[0])
                if use_extra_gate:
                    cx(data[1], ancilla[0])
                    cx(data[1], ancilla[0])  # cancels out
                results.append(measure(ancilla[0]))
            return tuple(results)

        @extract_stim
        def main():
            data = QuantumArray(qtype=QuantumBool(), shape=(2,))
            ancilla = QuantumArray(qtype=QuantumBool(), shape=(1,))
            detectors, *ms = configurable_syndrome(data, ancilla, 3, True)
            return (detectors,) + tuple(ms)

        result = main()
        stim_circ = result[-1]
        print(f"Found {stim_circ.num_detectors} detectors")
    """
    try:
        import tqecd
    except ImportError:
        raise ImportError(
            "find_detectors requires the 'tqecd' package. "
            "Install it with: pip install tqecd"
        )
    import functools

    def _decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # --- Validate & classify arguments ---
            quantum_arrays = []
            for a in args:
                if isinstance(a, QuantumArray):
                    if not isinstance(a.qtype, QuantumBool):
                        raise TypeError(
                            f"find_detectors: QuantumArray args must have "
                            f"qtype QuantumBool, got {a.qtype}"
                        )
                    quantum_arrays.append(a)
            if not quantum_arrays:
                raise TypeError(
                    "find_detectors: at least one QuantumArray argument is required"
                )
            n_qubits = sum(a.size * a.qtype.size for a in quantum_arrays)

            # --- 1. Trace inner function to get sub-jaspr ---
            # Bind non-QuantumArray args via closure so JAX only traces the
            # quantum arguments. Static values (ints, bools, ...) become
            # constants inside the jaspr.
            qa_idx = [i for i, a in enumerate(args) if isinstance(a, QuantumArray)]
            static = {
                i: a for i, a in enumerate(args) if not isinstance(a, QuantumArray)
            }

            def fn_qa_only(*qa_args, _kw=kwargs):
                full = [None] * len(args)
                for i, q in zip(qa_idx, qa_args):
                    full[i] = q
                for i, v in static.items():
                    full[i] = v
                return fn(*full, **_kw)

            qa_args = tuple(args[i] for i in qa_idx)

            _TRACER_MSG = (
                "find_detectors: a JAX tracer was leaked into the "
                "decorated function via a non-QuantumArray argument. "
                "All non-QuantumArray parameters (including values "
                "nested inside dicts, lists, etc.) must be concrete "
                "Python values, not traced quantities."
            )
            try:
                jaspr = make_jaspr(fn_qa_only)(*qa_args)
            except (
                jax.errors.UnexpectedTracerError,
                jax.errors.TracerIntegerConversionError,
                TypeError,
            ) as exc:
                # NOTE(review): TypeError is caught broadly, so unrelated type
                # errors in the user's function also surface with this message;
                # the original error is kept as __cause__.
                raise TypeError(_TRACER_MSG) from exc

            # --- 2. Analysis: to_qc -> stim -> tqecd ---
            result = jaspr.to_qc(*_build_analysis_args(jaspr, n_qubits, quantum_arrays))
            # Flatten analysis Clbits (consistent ordering via tree_flatten)
            analysis_clbits, _ = tree_flatten(result[:-1])
            raw_stim, meas_map = result[-1].to_stim(return_measurement_map=True)
            inv_map = {v: k for k, v in meas_map.items()}

            tqecd_circ = _prepare_for_tqecd(raw_stim)
            annotated = tqecd.annotate_detectors_automatically(tqecd_circ)
            all_dets = _extract_detectors(
                annotated, tqecd_circ.num_measurements, inv_map
            )

            # Keep only detectors whose Clbits all appear in returned measurements
            returned = set(analysis_clbits)
            relevant = [d for d in all_dets if all(cb in returned for cb in d)]

            # --- 2b. Noiseless parity of each detector ---
            # A detector *sampler* reports detection events, i.e. parities
            # relative to the noiseless reference, which are always 0 without
            # noise. The expected parity itself is taken from the reference
            # measurement sample instead.
            reference = annotated.reference_sample()

            def _expected_parity(det):
                return sum(bool(reference[meas_map[cb]]) for cb in det) % 2

            # --- 3. Call real function & map analysis Clbits to traced booleans ---
            real_returns = fn(*args, **kwargs)
            if not isinstance(real_returns, tuple):
                real_returns = (real_returns,)
            traced_flat, _ = tree_flatten(real_returns)
            cb_to_traced = dict(zip(analysis_clbits, traced_flat))

            # --- 4. Emit parity calls for each detector ---
            det_results = [
                parity(
                    *[cb_to_traced[cb] for cb in det],
                    expectation=_expected_parity(det),
                )
                for det in relevant
            ]
            out = (det_results,) + real_returns
            if return_circuits:
                out += (raw_stim, tqecd_circ, annotated)
            return out

        return wrapper

    # Support both @find_detectors and @find_detectors(return_circuits=True)
    if func is not None:
        return _decorator(func)
    return _decorator