Source code for qrisp.jasp.program_control.rus

"""
********************************************************************************
* Copyright (c) 2025 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************
"""

import inspect

from jax.lax import while_loop, cond
import jax
import jax.numpy as jnp

from qrisp.circuit import XGate
from qrisp.jasp import (
    TracingQuantumSession,
    AbstractQubitArray,
    DynamicQubitArray,
    qache,
)
from qrisp.jasp.primitives import (
    Measurement_p,
    OperationPrimitive,
    get_qubit_p,
    get_size_p,
    delete_qubits_p,
    reset_p,
)


[docs] def RUS(*trial_function, **jit_kwargs): r""" Decorator to deploy repeat-until-success (RUS) components. At the core, RUS repeats a given quantum subroutine followed by a qubit measurement until the measurement returns the value ``1``. This step is prevalent in many important algorithms, among them the `HHL algorithm <https://arxiv.org/abs/0811.3171>`_ or the `LCU procedure <https://arxiv.org/abs/1202.5822>`_. Within Jasp, RUS steps can be realized by providing the quantum subroutine as a "trial function", which returns a boolean value (the repetition condition) and possibly other return values. It is important to note that the trial function can not receive quantum arguments. This is because after each trial, a new copy of these arguments would be required to perform the next iteration, which is prohibited by the no-clone theorem. It is however legal to provide classical arguments. Parameters ---------- trial_function : callable A function returning a boolean value as the first return value. More return values are possible. static_argnums : int or list[int], optional A list of integers specifying which arguments are considered static in the sense of `jax.jit <https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html>`_. The first argument is indicated by 1, the second by 2, etc. The default is ``[]``. static_argnames : str or list[str], optional A list of strings specifying which arguments are considered static in the sense of `jax.jit <https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html>`_. The default is ``[]``. Returns ------- callable A function that performs the RUS protocol with the trial function. The return values of this function are the return values of the trial function WITHOUT the boolean value. Examples -------- To demonstrate the RUS behavior, we initialize a GHZ state .. math:: \ket{\psi} = \frac{\ket{00000} + \ket{11111}}{\sqrt{2}} and measure the first qubit into a boolean value. This will be the value to cancel the repetition. This will collapse the GHZ state into either $\ket{00000}$ (which will cause a new repetition) or $\ket{11111}$, which cancels the loop. After the repetition is canceled we are therefore guaranteed to have the latter state. :: from qrisp.jasp import RUS, make_jaspr from qrisp import QuantumFloat, h, cx, measure @RUS def rus_trial_function(): qf = QuantumFloat(5) h(qf[0]) for i in range(1, 5): cx(qf[0], qf[i]) cancelation_bool = measure(qf[0]) return cancelation_bool, qf def call_RUS_example(): qf = rus_trial_function() return measure(qf) Create the ``jaspr`` and simulate: :: jaspr = make_jaspr(call_RUS_example)() print(jaspr()) # Yields, 31 which is the decimal version of 11111 **Static arguments** To demonstrate the specification of static arguments, we will realize implement a simple `linear combination of unitaries <https://arxiv.org/abs/1202.5822>`_. Our implementation initializes a state of the form .. math:: \left( \sum_{i = 0}^N c_i U_i \right) \ket{0}. We achieve this by specifying a set of unitaries $U_i$ in the form of a tuple of functions, each processing a :ref:`QuantumFloat`. The coefficients $c_i$ are specified through a function preparing the state .. math:: \ket{\psi} = \sum_{i = 0}^N c_i \ket{i} For the state preparation function we specify two options to experiment with. A two qubit uniform superposition and a function that brings only the first qubit into superpostion. :: def state_prep_full(qv): h(qv[0]) h(qv[1]) def state_prep_half(qv): h(qv[0]) For the first one we have $c_0 = c_1 = c_2 = c_3 = \sqrt{0.25}$. The second one gives $c_0 = c_1 = \sqrt{0.5}$ and $c_2 = c_3 = 0$. The next step is to define the unitaries $U_i$ in the form of a tuple of functions. :: from qrisp.jasp import * from qrisp import * def case_function_0(x): x += 3 def case_function_1(x): x += 4 def case_function_2(x): x += 5 def case_function_3(x): x += 6 case_functions = (case_function_0, case_function_1, case_function_2, case_function_3) These functions each represent the unitary: .. math:: U_i \ket{0} = \ket{i+3} Executing a linear combination of unitaries therefore gives .. math:: \left( \sum_{i = 0}^N c_i U_i \right) \ket{0} = \sum_{i = 0}^N c_i \ket{i+3} Now we implement the LCU procedure. :: # Specify the corresponding arguments of the block encoding as "static", # i.e. compile time constants. @RUS(static_argnums = [2,3]) def block_encoding(return_size, state_preparation, case_functions): # This QuantumFloat will be returned qf = QuantumFloat(return_size) # Specify the QuantumVariable that indicates, which # case to execute n = int(np.ceil(np.log2(len(case_functions)))) case_indicator = QuantumFloat(n) # Turn into a list of qubits case_indicator_qubits = [case_indicator[i] for i in range(n)] # Perform the LCU protocoll with conjugate(state_preparation)(case_indicator): for i in range(len(case_functions)): with control(case_indicator_qubits, ctrl_state = i): case_functions[i](qf) # Compute the success condition success_bool = (measure(case_indicator) == 0) return success_bool, qf Finally, evaluate via the :ref:`terminal_sampling <terminal_sampling>` feature: :: @terminal_sampling def main(): return block_encoding(4, state_prep_full, case_functions) print(main()) # Yields: {3.0: 0.25, 4.0: 0.25, 5.0: 0.25, 6.0: 0.25} Evaluate the other state preparation function :: @terminal_sampling def main(): return block_encoding(4, state_prep_half, case_functions) print(main()) # Yields: {3.0: 0.5, 4.0: 0.5} As expected, the full state preparation function yields a state proportional to .. math:: \ket{3} + \ket{4} + \ket{5} + \ket{6}. The second state preparation gives us .. math:: \ket{3} + \ket{4}. """ if len(trial_function) == 0: return lambda x: RUS(x, **jit_kwargs) else: trial_function = trial_function[0] # The idea for implementing this feature is to execute the function once qached # to collect the output QuantumVariable object. # From the infered output signature the q_while_loop is constructed def return_function(*trial_args): from qrisp.jasp import q_while_loop, q_cond from qrisp.core import recursive_qv_search, reset abs_qs = TracingQuantumSession.get_instance() initial_gc_mode = abs_qs.gc_mode # Set the garbage collection mode to temporarily auto to collect # any ancillas that have not been deleted. abs_qs.gc_mode = "auto" # Execute the function qached_function = qache(trial_function, **jit_kwargs) first_iter_res = qached_function(*trial_args) abs_qs.gc_mode = initial_gc_mode # Filter out the static arguments if "static_argnums" in jit_kwargs: static_argnums = jit_kwargs["static_argnums"] if isinstance(static_argnums, int): static_argnums = [static_argnums] else: static_argnums = [] if "static_argnames" in jit_kwargs: argname_list = inspect.getfullargspec(trial_function) for i in range(len(argname_list)): if argname_list[i] in jit_kwargs["static_argnames"]: static_argnums.append(i) dynamic_args = [] for i in range(len(trial_args)): if i not in static_argnums: dynamic_args.append(trial_args[i]) n_arg_vals = len(dynamic_args) # Next we construct the body of the loop # The q_while_loop receives a tuple of arguments and # also returns a tuple with the same signature. # We therefore combine the results of the first iteration with # the arguments to execute the loop. combined_args = tuple(list(dynamic_args) + list(first_iter_res)) # This is the body function of the while loop def body_fun(args): # The first step is to reset and delete the results # from the previous iteration qv_results = recursive_qv_search(args[n_arg_vals:]) abs_qs = TracingQuantumSession.get_instance() for qv in qv_results: abs_qs.register_qv(qv, None) reset(qv) qv.delete() # We now construct the arguments for the function call of # the current iteration. # For this, we combine the dynamic arguments with the static # arguments. dynamic_args = list(args[:n_arg_vals]) new_trial_args = [] for i in range(len(trial_args)): if i not in static_argnums: new_trial_args.append(dynamic_args.pop(0)) else: new_trial_args.append(trial_args[i]) # Set the garbage collection mode to auto and call the function abs_qs.gc_mode = "auto" trial_res = qached_function(*new_trial_args) abs_qs.gc_mode = initial_gc_mode # Update the tuple with initial args and the new results combined_args = tuple(list(args[:n_arg_vals]) + list(trial_res)) return combined_args # This is the loop cancelation condition def cond_fun(val): # The loop cancelation index is located at the second position of the # return value tuple return ~val[n_arg_vals] # We now evaluate the loop # If the first iteration was already successful, we simply return the results # To realize this behavior we use a q_cond primitive def true_fun(combined_args): return combined_args # If the first iteration was not successfull, we start the loop def false_fun(combined_args): # Here is the while_loop return q_while_loop(cond_fun, body_fun, init_val=combined_args) # Evaluate everything combined_res = q_cond(first_iter_res[0], true_fun, false_fun, combined_args) # Return the results if len(first_iter_res) == 2: return combined_res[n_arg_vals+1] else: return combined_res[n_arg_vals+1:] return return_function