# Source code for qrisp.jasp.evaluation_tools.profiler

"""
********************************************************************************
* Copyright (c) 2026 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************
"""

"""
This file implements the tools to perform quantum resource estimation using Jasp
infrastructure. The idea here is to transform the quantum instructions within a
given Jaspr into "counting instructions". That means instead of performing some
quantum gate, we increment an index in an array, which keeps track of how many
instructions of each type have been performed.

To do this, we implement the 

qrisp.jasp.interpreter_tools.interpreters.profiling_interpreter.py

Which handles the transformation logic of the Jaspr.
This file implements the interfaces to evaluating the transformed Jaspr.

"""

from functools import wraps
from typing import Any, Callable, NamedTuple, Tuple

from jax.tree_util import tree_flatten

from qrisp.jasp.evaluation_tools.jaspification import simulate_jaspr
from qrisp.jasp.interpreter_tools.interpreters.count_ops_metric import (
    extract_count_ops,
    get_count_ops_profiler,
)
from qrisp.jasp.interpreter_tools.interpreters.depth_metric import (
    extract_depth,
    get_depth_profiler,
    simulate_depth,
)
from qrisp.jasp.interpreter_tools.interpreters.num_qubits_metric import (
    extract_num_qubits,
    get_num_qubits_profiler,
    simulate_num_qubits,
)
from qrisp.jasp.interpreter_tools.interpreters.utilities import (
    always_one,
    always_zero,
    simulation,
)
from qrisp.jasp.jasp_expression import Jaspr


class MetricSpec(NamedTuple):
    """Specification of a metric to be computed via profiling."""

    build_profiler: Callable[[Jaspr, Callable], Tuple[Callable, Any]]
    extract_metric: Callable[[Tuple, Jaspr, Any], Any]
    simulate_fallback: Callable[[Jaspr, Any], Any]


# Registry of the profiling modes understood by ``profile_jaspr``. Each key is a
# public mode name; each value bundles the profiler builder, the metric
# extractor and the simulation fallback used for that mode.
METRIC_DISPATCH: dict[str, MetricSpec] = dict(
    count_ops=MetricSpec(
        build_profiler=get_count_ops_profiler,
        extract_metric=extract_count_ops,
        simulate_fallback=simulate_jaspr,
    ),
    depth=MetricSpec(
        build_profiler=get_depth_profiler,
        extract_metric=extract_depth,
        simulate_fallback=simulate_depth,
    ),
    num_qubits=MetricSpec(
        build_profiler=get_num_qubits_profiler,
        extract_metric=extract_num_qubits,
        simulate_fallback=simulate_num_qubits,
    ),
)


def _normalize_meas_behavior(meas_behavior) -> Callable:
    """Normalize the measurement behavior into a callable."""

    if isinstance(meas_behavior, str):
        if meas_behavior == "0":
            return always_zero
        if meas_behavior == "1":
            return always_one
        if meas_behavior == "sim":
            return simulation
        raise ValueError(
            f"Don't know how to compute required resources via method {meas_behavior}"
        )

    if callable(meas_behavior):
        return meas_behavior

    raise TypeError("meas_behavior must be a str or callable")


# TODO: Move each metric implementation into its dedicated module (already present).
# Keeping them here for now to avoid circular imports.


[docs] def count_ops(meas_behavior: str | Callable) -> Callable: """ Decorator to determine resources of large scale quantum computations. This decorator compiles the given Jasp-compatible function into a classical function computing the amount of each gates required. The decorated function will return a dictionary containing the operation counts. For many algorithms including classical feedback, the result of the measurements can heavily influence the required resources. To reflect this, users can specify the behavior of measurements during the computation of resources. The following strategies are available: * ``"0"`` - computes the resource as if measurements always return 0 * ``"1"`` - computes the resource as if measurements always return 1 * *callable* - allows the user to specify a random number generator (see examples) For more details on how the *callable* option can be used, consult the examples section. Finally it is also possible to call the Qrisp simulator to determine measurement behavior by providing ``"sim"``. This is of course much less scalable but in particular for algorithms involving repeat-until-success components, a necessary evil. Note that the ``"sim"`` option might return non-deterministic results, while the other methods do. .. warning:: It is currently not possible to estimate programs, which include a :ref:`kernelized <quantum_kernel>` function. Parameters ---------- meas_behavior : str or callable A string or callable indicating the behavior of the resource computation when measurements are performed. Available strings are ``"0"``, ``"1"``, and ``"sim"``. Returns ------- resource_estimation decorator : Callable A decorator, producing a function to computed the required resources. Examples -------- We compute the resources required to perform a large scale integer multiplication. 
:: from qrisp import count_ops, QuantumFloat, measure @count_ops(meas_behavior = "0") def main(i): a = QuantumFloat(i) b = QuantumFloat(i) c = a*b return measure(c) print(main(5)) # {'s': 45, 'x': 22, 't_dg': 98, 'cx': 510, 't': 96, 'h': 139, 'measure': 55} print(main(5000)) # {'t': 751506, 'h': 1127254, 'x': 2002, 's': 375750, 't_dg': 751508, 'cx': 4629255, 'measure': 752500} Note that even though the second computation contains more than 800 million gates, determining the resources takes less than 200ms, highlighting the scalability features of the Jasp infrastructure. **Modifying the measurement behavior via a random number generator** To specify the behavior, we specify an RNG function (for more details on what that means please check the `Jax documentation <https://docs.jax.dev/en/latest/jax.random.html>`_. This RNG takes as input a "key" and returns a boolean value. In this case, the return value will be uniformly distributed among True and False. :: from jax import random import jax.numpy as jnp from qrisp import QuantumFloat, measure, control, count_ops, x # Returns a uniformly distributed boolean def meas_behavior(key): return jnp.bool(random.randint(key, (1,), 0,1)[0]) @count_ops(meas_behavior = meas_behavior) def main(i): qv = QuantumFloat(2) meas_res = measure(qv) with control(meas_res == i): x(qv) return measure(qv) This script executes two measurements and based on the measurement outcome executes two X gates. We can now execute this resource computation with different values of ``i`` to see, which measurements return ``True`` with our given random-number generator (recall that this way of specifying the measurement behavior is fully deterministic). :: print(main(0)) # Yields: {'measure': 4, 'x': 2} print(main(1)) # Yields: {'measure': 4} print(main(2)) # Yields: {'measure': 4} print(main(3)) # Yields: {'measure': 4} From this we conclude that our RNG returned 0 for both of the initial measurements. 
For some algorithms (such as :ref:`RUS`) sampling the measurement result from a simple distribution won't cut it because the required ressource can be heavily influenced by measurement outcomes. For this matter it is also possible to perform a full simulation. Note that this simulation is no longer deterministic. :: @count_ops(meas_behavior = "sim") def main(i): qv = QuantumFloat(2) meas_res = measure(qv) with control(meas_res == i): x(qv) return measure(qv) print(main(0)) {'measure': 4, 'x': 2} print(main(1)) {'measure': 4} """ def count_ops_decorator(function): def ops_counter(*args): from qrisp.jasp import make_jaspr if not hasattr(function, "jaspr_dict"): function.jaspr_dict = {} signature = tuple(type(arg) for arg in args) shape_signature = tuple( arg.shape for arg in tree_flatten(args)[0] if hasattr(arg, "shape") ) hash_key = (signature, shape_signature, hash(meas_behavior)) if hash_key not in function.jaspr_dict: function.jaspr_dict[hash_key] = make_jaspr(function)(*args) return function.jaspr_dict[hash_key].count_ops( *args, meas_behavior=meas_behavior ) return ops_counter return count_ops_decorator
[docs] def depth(meas_behavior: str | Callable, max_qubits: int = 1024) -> Callable: """ Decorator to determine the depth of large scale quantum computations. This decorator compiles the given Jasp-compatible function into a classical function computing the circuit depth required. The decorated function returns an integer indicating the depth of the quantum computation. The depth is computed by tracking, for each qubit, the time at which it becomes available again after an operation. Multi-qubit gates increase the depth of all qubits they act on to the same value. Parameters ---------- meas_behavior : str or callable A string or callable indicating the behavior of the resource computation when measurements are performed. Available strings are ``"0"`` and ``"1"``. A callable must take a JAX PRNG key as input and return a boolean. max_qubits : int, optional The maximum number of qubits supported for depth computation. Default is 1024. Returns ------- depth decorator : Callable A decorator producing a function that computes the depth required. Examples -------- Let's consider a simple circuit: :: from qrisp import * @depth(meas_behavior="0") def circuit(n): qv = QuantumFloat(n) h(qv[0]) h(qv[1]) cx(qv[0], qv[1]) h(qv[0]) print(circuit(2)) # Output: 3 The first two Hadamards run in parallel (depth 1), the CNOT increases depth to 2, and the final Hadamard gives depth 3. Now, consider a circuit with measurement and classical control: :: @depth(meas_behavior="0") def circuit(n): qv = QuantumFloat(n) m = measure(qv[0]) with control(m == 0): h(qv[0]) x(qv[1]) h(qv[0]) with control(m == 1): cx(qv[0], qv[1]) h(qv[0]) x(qv[0]) print(circuit(2)) # Output: 2 The same circuit with ``meas_behavior="1"`` yields a depth of 3, because a different branch of the computation is taken. 
**Macro-gates and gate definitions** If a gate has a ``definition`` (for example a Toffoli gate implemented as a sequence of simpler gates), the `transpile` method is applied to the definition to determine the depth of the macro-gate. .. note:: Computing depth requires tracking qubit dependencies. As a result, compilation time for the depth metric can be noticeably slower for large circuits compared to ``count_ops``. This will be improved in future versions. However, the scalability offered by Jasp after the initial compilation is not affected. .. note:: The ``max_qubits`` parameter sets an upper limit on the number of qubits that can be handled for depth computation. This is necessary as JAX requires static shapes for JIT compilation. The default value of 1024 can be adjusted based on the expected number of qubits in the circuits to be analyzed. .. warning:: It is currently not possible to estimate programs, which include a :ref:`kernelized <quantum_kernel>` function. .. warning:: The depth metric an experimental feature and may not behave as expected in certain edge cases. - The memory management operations ``reset`` and ``delete`` are currently ignored. Qubits freed by these calls still count toward the ``max_qubits`` limit. - This metric can currently handle the slice operation correctly only when the lower bound of the slice is strictly smaller than the upper bound. 
""" def depth_decorator(function): def depth_counter(*args): from qrisp.jasp import make_jaspr if not hasattr(function, "jaspr_dict"): function.jaspr_dict = {} signature = tuple(type(arg) for arg in args) shape_signature = tuple( arg.shape for arg in tree_flatten(args)[0] if hasattr(arg, "shape") ) hash_key = (signature, shape_signature, hash(meas_behavior)) if hash_key not in function.jaspr_dict: function.jaspr_dict[hash_key] = make_jaspr(function)(*args) return function.jaspr_dict[hash_key].depth( *args, meas_behavior=meas_behavior, max_qubits=max_qubits ) return depth_counter return depth_decorator
[docs] def num_qubits(meas_behavior: str | Callable, max_allocations: int = 1000) -> Callable: """ Decorator to track qubit allocation and deallocation events during a quantum computation. This decorator compiles a Jasp-compatible quantum function into a resource-analysis function that tracks qubit allocation and deallocation events throughout the computation. An internal allocation counter is updated as follows: - increased whenever qubits are allocated (e.g., via ``QuantumVariable`` creation), - decreased whenever qubits are explicitly deleted (e.g., via ``qv.delete()``), The decorated function returns a dictionary containing information about all allocation and deallocation events. These are: - ``total_allocated``: the total number of qubits allocated during the computation. - ``total_deallocated``: the total number of qubits deallocated during the computation. - ``peak_allocations``: the maximum number of qubits allocated at any point during the computation. - ``finally_allocated``: the number of qubits still allocated at the end of the computation. See the examples below for more details on how to interpret these values. Parameters ---------- meas_behavior : str or callable A string or callable indicating the behavior of the resource computation when measurements are performed. Available strings are ``"0"`` and ``"1"``. A callable must take a JAX PRNG key as input and return a boolean. max_allocations : int, optional The maximum number of allocation/deallocation events supported for tracking. Default is 1000. This is necessary as JAX requires static shapes for JIT compilation. Returns ------- Callable A decorator producing a function that returns a dictionary containing aggregated statistics about allocation and deallocation events during the computation. 
Examples -------- Let's consider a simple circuit in which the number of allocated qubits depends on the measurement outcome: :: from qrisp import * @num_qubits(meas_behavior="0") def circuit(n1, n2, n3): qv = QuantumFloat(n1) m = measure(qv[0]) with control(m == 0): qv2 = QuantumFloat(n2) h(qv2[0]) with control(m == 1): qv3 = QuantumFloat(n3) h(qv3[0]) print(circuit(2, 3, 4)) # Output: # {'total_allocated': 5, 'total_deallocated': 0, # 'peak_allocations': 5, 'finally_allocated': 5} Here, the measurement of the first qubit determines whether we allocate 3 or 4 additional qubits. The output dictionary contains information about the total number of allocated qubits (5), the total number of deallocated qubits (0), the peak number of allocated qubits at any point during the computation (5), and the number of qubits still allocated at the end of the computation (5). If we change the measurement behavior to ``"1"``, we get a different output. Note that deallocation affects the final count: :: @num_qubits(meas_behavior="0") def circuit(n): qv = QuantumFloat(2 * n) h(qv[0]) qv.delete() qv = QuantumFloat(n) h(qv[0]) print(circuit(4)) # Output: # {'total_allocated': 12, 'total_deallocated': 8, # 'peak_allocations': 8, 'finally_allocated': 4} Here, we first allocate 8 qubits, then deallocate them, and finally allocate 4 more qubits. 
Let's see a final example with branching and deallocation: :: from qrisp import * @num_qubits(meas_behavior="1") def circuit(num_qubits_input): list_of_qvs = [] for i in range(2): qv = QuantumFloat(num_qubits_input) h(qv[i]) list_of_qvs.append(qv) qv_2 = QuantumFloat(1) h(qv_2[0]) m = measure(qv_2[0]) qv_2.delete() with control(m == 1): qv4 = QuantumFloat(10) h(qv4[0]) qv4.delete() for i in range(2): list_of_qvs[i].delete() print(circuit(8)) # Output: # {'total_allocated': 27, 'total_deallocated': 27, # 'peak_allocations': 26, 'finally_allocated': 0} In this example, the peak number of allocated qubits is different from the total allocated because ``qv_2`` is deleted before the subsequent allocation of ``qv4``. The final number of allocated qubits is 0 because all allocated qubits are eventually deallocated. .. warning:: Programs that include a :ref:`kernelized <quantum_kernel>` function cannot currently be analyzed. """ def num_qubits_decorator(function): def qubits_counter(*args): from qrisp.jasp import make_jaspr if not hasattr(function, "jaspr_dict"): function.jaspr_dict = {} signature = tuple(type(arg) for arg in args) shape_signature = tuple( arg.shape for arg in tree_flatten(args)[0] if hasattr(arg, "shape") ) hash_key = (signature, shape_signature, hash(meas_behavior)) if hash_key not in function.jaspr_dict: function.jaspr_dict[hash_key] = make_jaspr(function)(*args) return function.jaspr_dict[hash_key].num_qubits( *args, meas_behavior=meas_behavior, max_allocations=max_allocations ) return qubits_counter return num_qubits_decorator
def profile_jaspr(
    jaspr: Jaspr, mode: str, meas_behavior: str | Callable = "0", **kwargs: Any
) -> Callable:
    """
    Profile a Jaspr according to a given metric mode.

    Parameters
    ----------
    jaspr : Jaspr
        The Jaspr to be profiled.
    mode : str
        The profiling mode to be used. Currently supported modes are "depth",
        "count_ops", and "num_qubits".
    meas_behavior : str or callable, optional
        The measurement behavior to be used during profiling. Default is "0".
    **kwargs : Any
        Additional keyword arguments to be passed to the profiler builder.
        For example, `max_qubits` for depth profiling, or `max_allocations`
        for num_qubits profiling.

    Returns
    -------
    Callable
        A function that computes the specified metric when called with the
        same arguments as the original Jaspr.

    Raises
    ------
    ValueError
        If ``meas_behavior`` is a string naming no known strategy.
    TypeError
        If ``meas_behavior`` is neither a string nor a callable.
    KeyError
        If ``mode`` is not one of the supported modes.
    """

    meas_behavior_callable = _normalize_meas_behavior(meas_behavior)

    try:
        metric_spec = METRIC_DISPATCH[mode]
    except KeyError:
        raise KeyError(
            f"Unknown profiling mode {mode!r}; supported modes are "
            f"{', '.join(map(repr, METRIC_DISPATCH))}"
        ) from None

    # Detect the simulation strategy by identity. Comparing ``__name__`` would
    # raise AttributeError for callables without one (e.g. functools.partial
    # objects) and would misfire for user callables named "simulation".
    if (
        meas_behavior_callable is simulation
        and metric_spec.simulate_fallback is not None
    ):
        simulate_fallback = metric_spec.simulate_fallback

        @wraps(simulate_fallback)
        def simulation_wrapper(*args):
            return simulate_fallback(jaspr, *args, return_gate_counts=True)

        return simulation_wrapper

    # `profiler` is a function that computes the metric we are interested in.
    # `aux` is any auxiliary data that might be needed to reconstruct the metric
    # (for example the profiling dictionary for count_ops).
    profiler, aux = metric_spec.build_profiler(jaspr, meas_behavior_callable, **kwargs)

    @wraps(profiler)
    def profiler_wrapper(*args):
        # The profiler is called with the flattened pytree leaves of the
        # arguments rather than the original (possibly nested) arguments.
        flat_args = tree_flatten(args)[0]
        res = profiler(*flat_args)
        return metric_spec.extract_metric(res, jaspr, aux)

    return profiler_wrapper