Repeat-Until-Success
- RUS(*trial_function, **jit_kwargs)
Decorator to deploy repeat-until-success (RUS) components. At its core, RUS repeats a given quantum subroutine followed by a qubit measurement until the measurement returns the value 1. This step is prevalent in many important algorithms, among them the HHL algorithm and the LCU procedure. Within Jasp, RUS steps can be realized by providing the quantum subroutine as a “trial function”, which returns a boolean value as its first return value (the loop terminates once it is True), possibly followed by other return values.
It is important to note that the trial function cannot receive quantum arguments. This is because after each trial, a new copy of these arguments would be required to perform the next iteration, which is prohibited by the no-cloning theorem. It is, however, legal to provide classical arguments.
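The control flow described above can be pictured with a classical stand-in. The following is a minimal sketch, not Qrisp's implementation: the decorator repeat_until_success and the trial function coin_trial are hypothetical, and a pseudo-random bit replaces the qubit measurement. Like RUS, it keeps calling the trial function until the first return value is True and then returns the remaining values without the boolean. The random number generator is passed in as an ordinary classical argument, which is the kind of argument a trial function may receive.

```python
import random


def repeat_until_success(trial_function):
    """Hypothetical classical analogue of RUS: retry until the flag is True."""

    def wrapper(*args):
        while True:
            # The first return value is the success flag; the rest is payload.
            success, *payload = trial_function(*args)
            if success:
                # Strip the boolean, mirroring the documented return convention.
                return payload[0] if len(payload) == 1 else tuple(payload)

    return wrapper


@repeat_until_success
def coin_trial(rng):
    # A random bit stands in for measuring a qubit in equal superposition.
    bit = rng.randint(0, 1)
    return bit == 1, bit


rng = random.Random(0)
print([coin_trial(rng) for _ in range(5)])  # [1, 1, 1, 1, 1]
```

Every call returns 1, because the wrapper only returns after a successful trial; what varies from call to call is the number of hidden repetitions.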
- Parameters:
- trial_function : callable
A function returning a boolean value as the first return value. Further return values may follow.
- static_argnums : int or list[int], optional
A list of integers specifying which arguments are considered static in the sense of jax.jit. The first argument is indicated by 1, the second by 2, etc. (unlike jax.jit itself, which counts from 0). The default is [].
- static_argnames : str or list[str], optional
A list of strings specifying which arguments are considered static in the sense of jax.jit. The default is [].
- Returns:
- callable
A function that performs the RUS protocol with the trial function. Its return values are those of the trial function, without the leading boolean value.
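Because static_argnums counts from 1 here, it can help to check which parameters a given list selects. The sketch below does this with the standard-library inspect module; the helper static_param_names is hypothetical and only restates the indexing rule from the parameter description, and the stub reuses the signature of the block_encoding function from the LCU example.

```python
import inspect


def static_param_names(func, static_argnums):
    """Map 1-based static_argnums (as described for RUS) to parameter names."""
    params = list(inspect.signature(func).parameters)
    # Shift by one: argument 1 is params[0], argument 2 is params[1], ...
    return [params[i - 1] for i in static_argnums]


def block_encoding(return_size, state_preparation, case_functions):
    """Stub with the same signature as the LCU example; body omitted."""


print(static_param_names(block_encoding, [2, 3]))
# ['state_preparation', 'case_functions']
```

So static_argnums = [2,3] marks the state preparation function and the tuple of case functions as compile-time constants, while return_size stays a traced argument.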
Examples
To demonstrate the RUS behavior, we initialize a GHZ state
\[\ket{\psi} = \frac{\ket{00000} + \ket{11111}}{\sqrt{2}}\]
and measure the first qubit into a boolean value, which serves as the cancellation condition. The measurement collapses the GHZ state into either \(\ket{00000}\), which triggers another repetition, or \(\ket{11111}\), which ends the loop. Once the loop has ended, we are therefore guaranteed to hold the latter state.
```python
from qrisp.jasp import RUS, make_jaspr
from qrisp import QuantumFloat, h, cx, measure

@RUS
def rus_trial_function():
    # Prepare a 5-qubit GHZ state
    qf = QuantumFloat(5)
    h(qf[0])

    for i in range(1, 5):
        cx(qf[0], qf[i])

    # Measuring the first qubit decides whether the loop ends
    cancelation_bool = measure(qf[0])
    return cancelation_bool, qf

def call_RUS_example():
    qf = rus_trial_function()
    return measure(qf)
```
Create the jaspr and simulate:

```python
jaspr = make_jaspr(call_RUS_example)()
print(jaspr())
# Yields 31, the decimal value of the bit string 11111
```
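To see why the result is always 31, the sketch below simulates the example classically with a sparse state vector (a dict from basis-state integers to amplitudes). It is an illustration under simple modelling assumptions, not Qrisp's simulator: the helpers ghz_state, measure_qubit0 and rus_ghz are hypothetical, and "qubit 0" is taken to be bit 0 of the basis index.

```python
import math
import random


def ghz_state(n):
    """Amplitudes of (|0...0> + |1...1>)/sqrt(2) on n qubits, keyed by integer."""
    return {0: 1 / math.sqrt(2), 2 ** n - 1: 1 / math.sqrt(2)}


def measure_qubit0(state, rng):
    """Projectively measure bit 0; return (outcome, collapsed state)."""
    p_one = sum(abs(a) ** 2 for k, a in state.items() if k & 1)
    outcome = rng.random() < p_one
    kept = {k: a for k, a in state.items() if bool(k & 1) == outcome}
    norm = math.sqrt(sum(abs(a) ** 2 for a in kept.values()))
    return outcome, {k: a / norm for k, a in kept.items()}


def rus_ghz(n, rng):
    """Repeat the GHZ trial until bit 0 reads 1; return (value, attempts)."""
    attempts = 0
    while True:
        attempts += 1
        outcome, state = measure_qubit0(ghz_state(n), rng)
        if outcome:
            # Only |1...1> survives, so measuring the register is deterministic.
            (value,) = state.keys()
            return value, attempts


rng = random.Random(1)
print(rus_ghz(5, rng)[0])  # 31
```

Each trial succeeds with probability 1/2, so the number of attempts is geometrically distributed with a mean of two, while the returned value is always 31.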
Static arguments
To demonstrate the specification of static arguments, we implement a simple linear combination of unitaries (LCU).
Our implementation initializes a state of the form
\[\left( \sum_{i = 0}^N c_i U_i \right) \ket{0}.\]
We achieve this by specifying a set of unitaries \(U_i\) in the form of a tuple of functions, each processing a QuantumFloat.
The coefficients \(c_i\) are specified through a function preparing the state
\[\ket{\psi} = \sum_{i = 0}^N c_i \ket{i}.\]
We provide two state preparation functions to experiment with: a uniform superposition over two qubits, and a function that brings only the first qubit into superposition.
```python
def state_prep_full(qv):
    # Uniform superposition over both qubits
    h(qv[0])
    h(qv[1])

def state_prep_half(qv):
    # Superposition on the first qubit only
    h(qv[0])
```
For the first one we have \(c_0 = c_1 = c_2 = c_3 = \sqrt{0.25}\). The second one gives \(c_0 = c_1 = \sqrt{0.5}\) and \(c_2 = c_3 = 0\).
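These coefficients can be checked by applying the Hadamard gates to a two-qubit register by hand. The following sketch is a small pure-Python state-vector calculation; hadamard_on and prepared_amplitudes are hypothetical helpers, and qubit k is assumed to be bit k of the basis index (so qv[0] is the least significant bit, which is what makes state_prep_half populate \(\ket{0}\) and \(\ket{1}\)).

```python
import math


def hadamard_on(amps, k):
    """Apply a Hadamard gate to qubit k (bit k of the index) of a state vector."""
    out = [0.0] * len(amps)
    s = 1 / math.sqrt(2)
    for idx, a in enumerate(amps):
        partner = idx ^ (1 << k)  # same basis state with bit k flipped
        if idx & (1 << k):
            # H|1> = (|0> - |1>)/sqrt(2)
            out[partner] += s * a
            out[idx] -= s * a
        else:
            # H|0> = (|0> + |1>)/sqrt(2)
            out[idx] += s * a
            out[partner] += s * a
    return out


def prepared_amplitudes(qubits, n=2):
    """Start in |0...0> on n qubits and apply H to each listed qubit."""
    amps = [1.0] + [0.0] * (2 ** n - 1)
    for k in qubits:
        amps = hadamard_on(amps, k)
    return amps


print([round(a, 3) for a in prepared_amplitudes([0, 1])])  # [0.5, 0.5, 0.5, 0.5]
print([round(a, 3) for a in prepared_amplitudes([0])])     # [0.707, 0.707, 0.0, 0.0]
```

The printed amplitudes are \(\sqrt{0.25} = 0.5\) and \(\sqrt{0.5} \approx 0.707\), matching the coefficients stated above.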
The next step is to define the unitaries \(U_i\) in the form of a tuple of functions.
```python
from qrisp.jasp import *
from qrisp import *

def case_function_0(x):
    x += 3

def case_function_1(x):
    x += 4

def case_function_2(x):
    x += 5

def case_function_3(x):
    x += 6

case_functions = (case_function_0, case_function_1, case_function_2, case_function_3)
```
Each of these functions represents a unitary with
\[U_i \ket{0} = \ket{i+3}.\]
Executing a linear combination of unitaries therefore gives
\[\left( \sum_{i = 0}^N c_i U_i \right) \ket{0} = \sum_{i = 0}^N c_i \ket{i+3}.\]
Now we implement the LCU procedure.
```python
import numpy as np

# Specify the corresponding arguments of the block encoding as "static",
# i.e. compile-time constants.
@RUS(static_argnums = [2,3])
def block_encoding(return_size, state_preparation, case_functions):

    # This QuantumFloat will be returned
    qf = QuantumFloat(return_size)

    # Specify the QuantumVariable that indicates which
    # case to execute
    n = int(np.ceil(np.log2(len(case_functions))))
    case_indicator = QuantumFloat(n)

    # Turn into a list of qubits
    case_indicator_qubits = [case_indicator[i] for i in range(n)]

    # Perform the LCU protocol: prepare the coefficients, apply each
    # case function controlled on its index, then uncompute the preparation
    with conjugate(state_preparation)(case_indicator):
        for i in range(len(case_functions)):
            with control(case_indicator_qubits, ctrl_state = i):
                case_functions[i](qf)

    # Compute the success condition: the indicator must return to 0
    success_bool = (measure(case_indicator) == 0)
    return success_bool, qf
```
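The circuit inside block_encoding can be cross-checked with a small state-vector calculation. This is a minimal sketch under stated assumptions, not Qrisp's simulator: the state is a dict keyed by (indicator value, register value); the Hadamard-only state preparations are their own inverses, so the same routine stands in for both state_preparation and the inverse that conjugate applies afterwards; and the helpers apply_hadamards, select and lcu_postselect are hypothetical. Keeping only indicator value 0 mirrors success_bool, and the weight of that branch is the per-trial success probability of the RUS loop. Working through the amplitudes, the branch that survives is proportional to \(\sum_i |c_i|^2 U_i \ket{0}\), which for the equal coefficients used here has the same normalized form as the LCU formula above.

```python
import math
from collections import defaultdict


def apply_hadamards(state, qubits):
    """Apply H to each listed indicator qubit; state maps (ind, reg) -> amplitude."""
    s = 1 / math.sqrt(2)
    for k in qubits:
        new = defaultdict(float)
        for (ind, reg), amp in state.items():
            sign = -1.0 if (ind >> k) & 1 else 1.0
            new[(ind & ~(1 << k), reg)] += s * amp        # |0> component
            new[(ind | (1 << k), reg)] += sign * s * amp  # |1> component
        state = dict(new)
    return state


def select(state, shifts, size):
    """Controlled on indicator value i, add shifts[i] to the register (mod 2**size)."""
    return {(ind, (reg + shifts[ind]) % 2 ** size): amp
            for (ind, reg), amp in state.items()}


def lcu_postselect(prep_qubits, shifts=(3, 4, 5, 6), size=4):
    """Return (success probability, outcome distribution given success)."""
    state = {(0, 0): 1.0}                        # indicator and register in |0>
    state = apply_hadamards(state, prep_qubits)  # state_preparation
    state = select(state, shifts, size)          # controlled case functions
    state = apply_hadamards(state, prep_qubits)  # inverse applied by conjugate
    kept = {reg: amp for (ind, reg), amp in state.items()
            if ind == 0 and abs(amp) > 1e-12}    # success: indicator == 0
    p_success = sum(amp ** 2 for amp in kept.values())
    return p_success, {reg: amp ** 2 / p_success for reg, amp in kept.items()}


for name, qubits in [("full", [0, 1]), ("half", [0])]:
    p, dist = lcu_postselect(qubits)
    print(name, round(p, 3), {reg: round(q, 3) for reg, q in sorted(dist.items())})
# full 0.25 {3: 0.25, 4: 0.25, 5: 0.25, 6: 0.25}
# half 0.5 {3: 0.5, 4: 0.5}
```

Under these assumptions the success probabilities are 0.25 and 0.5, so the RUS loop needs on average four and two trials respectively, and the conditional distributions match the sampled results shown next.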
Finally, evaluate via the terminal_sampling feature:
```python
@terminal_sampling
def main():
    return block_encoding(4, state_prep_full, case_functions)

print(main())
# Yields: {3.0: 0.25, 4.0: 0.25, 5.0: 0.25, 6.0: 0.25}
```
Evaluate the other state preparation function:

```python
@terminal_sampling
def main():
    return block_encoding(4, state_prep_half, case_functions)

print(main())
# Yields: {3.0: 0.5, 4.0: 0.5}
```
As expected, the full state preparation function yields a state proportional to
\[\ket{3} + \ket{4} + \ket{5} + \ket{6}.\]
The second state preparation gives a state proportional to
\[\ket{3} + \ket{4}.\]
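Normalizing these states reproduces the frequencies printed by terminal_sampling, since the sampling probability of each value is the squared magnitude of its normalized amplitude:
\[\frac{1}{2}\left(\ket{3} + \ket{4} + \ket{5} + \ket{6}\right) \;\Rightarrow\; P(3) = P(4) = P(5) = P(6) = \left(\tfrac{1}{2}\right)^2 = 0.25,\]
\[\frac{1}{\sqrt{2}}\left(\ket{3} + \ket{4}\right) \;\Rightarrow\; P(3) = P(4) = \left(\tfrac{1}{\sqrt{2}}\right)^2 = 0.5.\]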