Prefix Control
The following functions expose program control features from Jax as Jasp-compatible functions. While Loops and Classical Control work fine in many situations, their syntactically convenient form rules out some cases that might be relevant to your application. In particular, it is impossible to use “carry values” (i.e. classical values that are computed during a loop or conditional) outside of the control structure. To realize this behavior, you might be tempted to use the Jax-exposed functions fori_loop, while_loop and cond. However, these do not properly track the Jasp-internal quantum state, which is why we expose Jasp-compatible implementations of these functions.
Note
If your code is kernelized, the quantum state doesn’t need to be tracked, so you can use the Jax versions.
- q_fori_loop(lower, upper, body_fun, init_val)
Jasp-compatible version of jax.lax.fori_loop. The parameters and semantics are the same as for the Jax version.
In particular, the following loop is performed:
```python
def q_fori_loop(lower, upper, body_fun, init_val):
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val
```
- Parameters:
- lower : int or jax.core.Tracer
An integer representing the loop index lower bound (inclusive).
- upper : int or jax.core.Tracer
An integer representing the loop index upper bound (exclusive).
- body_fun : callable
The function describing the loop body. It receives the loop index and the current value and returns the updated value.
- init_val : object
Some object to initialize the loop with.
- Returns:
- val : object
The return value of body_fun after the final iteration.
Examples
We write a dynamic loop that collects the measurement values of the qubits of a QuantumFloat into an accumulator:
```python
from qrisp import *
from qrisp.jasp import *

@jaspify
def main(k):
    qf = QuantumFloat(6)

    def body_fun(i, val):
        acc, qf = val
        x(qf[i])               # flip the i-th qubit
        acc += measure(qf[i])  # accumulate its measurement result
        return acc, qf

    acc, qf = q_fori_loop(0, k, body_fun, (0, qf))
    return acc, measure(qf)

print(main(5))
# Yields:
# (Array(5, dtype=int64), Array(31., dtype=float64))
```
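Since every iteration flips a qubit that starts in the zero state, each measurement in the example is deterministic, so the printed result can be reproduced with a purely classical model. The sketch below is not Qrisp code: it assumes a plain Python list of bits as a stand-in for the QuantumFloat (least significant bit first, which the output 31 implies) and a hypothetical helper `fori_loop_reference` that mirrors the loop semantics given above.

```python
def fori_loop_reference(lower, upper, body_fun, init_val):
    # Hypothetical helper with the same loop semantics as q_fori_loop / jax.lax.fori_loop
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


def body_fun(i, val):
    acc, bits = val
    bits = list(bits)   # copy, so the carry is treated as an immutable value
    bits[i] ^= 1        # classical stand-in for x(qf[i])
    acc += bits[i]      # measuring a basis state simply returns its bit
    return acc, bits


def main(k):
    bits = [0] * 6      # stand-in for a fresh QuantumFloat(6)
    acc, bits = fori_loop_reference(0, k, body_fun, (0, bits))
    value = sum(b << j for j, b in enumerate(bits))  # little-endian integer value
    return acc, value


print(main(5))  # (5, 31)
```

The accumulator `acc` is exactly the kind of carry value that is returned from the loop and used afterwards.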
- q_while_loop(cond_fun, body_fun, init_val)
Jasp-compatible version of jax.lax.while_loop. The parameters and semantics are the same as for the Jax version.
In particular, the following loop is performed:
```python
def q_while_loop(cond_fun, body_fun, init_val):
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val
```
- Parameters:
- cond_fun : callable
A function that evaluates the condition of the while loop. Must not contain any quantum operations.
- body_fun : callable
A function describing the body of the loop.
- init_val : object
An object to initialize the loop.
- Returns:
- val
The result of body_fun after the last iteration.
- Raises:
- Exception
Tried to modify quantum state during while condition evaluation.
Examples
We write a dynamic loop that collects the measurement values of the qubits of a QuantumFloat into an accumulator. Note that the accumulator variable is a carry value, so this loop could not be implemented using Loops.
```python
from qrisp import *
from qrisp.jasp import *

@jaspify
def main(k):
    qf = QuantumFloat(6)

    def body_fun(val):
        i, acc, qf = val
        x(qf[i])
        acc += measure(qf[i])
        i += 1
        return i, acc, qf

    def cond_fun(val):
        # Only reads the classical counter, no quantum operations
        return val[0] < 5

    i, acc, qf = q_while_loop(cond_fun, body_fun, (0, 0, qf))
    return acc, measure(qf)

print(main(6))
# Yields
# (Array(5, dtype=int64), Array(31., dtype=float64))
```
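The example carries the loop counter `i` explicitly, which is the usual way to express a for loop with while-loop semantics. A minimal pure-Python sketch of that translation follows; `while_loop_reference` and `fori_via_while` are hypothetical helpers that mirror the semantics shown above and are not part of Qrisp. As required for cond_fun, the condition only reads the counter.

```python
def while_loop_reference(cond_fun, body_fun, init_val):
    # Hypothetical helper with the same semantics as q_while_loop / jax.lax.while_loop
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


def fori_via_while(lower, upper, body_fun, init_val):
    # Express a fori loop as a while loop by carrying the index next to the value
    def cond(carry):
        i, _ = carry
        return i < upper          # side-effect free: reads the counter only

    def body(carry):
        i, val = carry
        return i + 1, body_fun(i, val)

    _, val = while_loop_reference(cond, body, (lower, init_val))
    return val


# Sum of squares 0**2 + 1**2 + ... + 4**2 computed through the while-loop form
print(fori_via_while(0, 5, lambda i, acc: acc + i * i, 0))  # 30
```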
- q_cond(pred, true_fun, false_fun, *operands)
Jasp-compatible version of jax.lax.cond. The parameters and semantics are the same as for the Jax version.
It has the following semantics:
```python
def q_cond(pred, true_fun, false_fun, *operands):
    if pred:
        return true_fun(*operands)
    else:
        return false_fun(*operands)
```
- Parameters:
- pred : bool or jax.core.Tracer
A boolean value, deciding which function gets executed.
- true_fun : callable
The function that is executed when pred is True.
- false_fun : callable
The function that is executed when pred is False.
- *operands : tuple
The input values for both functions.
- Returns:
- object
The return value of the respective function.
Examples
We write a script that brings a QuantumBool into superposition and subsequently measures it. If the measurement result is False, we flip it such that in the end, the bool will always be in the \(\ket{\text{True}}\) state.
```python
from qrisp import *
from qrisp.jasp import *

@jaspify
def main():
    def false_fun(qbl):
        qbl.flip()
        return qbl

    def true_fun(qbl):
        return qbl

    qbl = QuantumBool()
    h(qbl)
    pred = measure(qbl)
    qbl = q_cond(pred, true_fun, false_fun, qbl)
    return measure(qbl)

print(main())
# Yields:
# True
```
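Why the example always prints True can be checked by enumerating both measurement outcomes. The sketch below is a classical model, not Qrisp code: a Python bool stands in for the collapsed QuantumBool, and `cond_reference` is a hypothetical helper with the semantics given above.

```python
def cond_reference(pred, true_fun, false_fun, *operands):
    # Hypothetical helper with the same branching semantics as q_cond / jax.lax.cond
    if pred:
        return true_fun(*operands)
    else:
        return false_fun(*operands)


def false_fun(qbl):
    return not qbl      # classical stand-in for qbl.flip()


def true_fun(qbl):
    return qbl          # leave the bool untouched


# After h(qbl) the measurement yields False or True with equal probability,
# and the collapsed bool equals the measured value pred.
for pred in (False, True):
    result = cond_reference(pred, true_fun, false_fun, pred)
    print(pred, "->", result)  # both lines end with "-> True"
```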
- q_switch(index, branches, *operands, branch_amount=None, method='auto')
Classical index
Jasp-compatible version of jax.lax.switch. The parameters and semantics are the same as for the Jax version.
It has the following semantics:
```python
def q_switch(index, branches, *operands):
    return branches[index](*operands)
```
Quantum index
Executes a quantum switch-case statement that distinguishes between the given in-place functions.
Implements the operation
\[\text{SELECT} = \sum_i \ket{i}\bra{i} \otimes U_i\]
for unitaries (branches) \(U_i\), applying the \(i\)-th unitary conditioned on the index variable being in state \(\ket{i}\).
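As a concrete instance of this formula, take a single index qubit and two branches \(U_0 = I\) and \(U_1 = X\) acting on one target qubit. In the textbook convention where the index is the first tensor factor, SELECT is block diagonal with blocks \(U_0\) and \(U_1\), which is the CNOT gate. The pure-Python sketch below builds the matrix directly from the sum of Kronecker products; the helpers `kron`, `madd` and `projector` are written here for illustration and are not Qrisp functions, and Qrisp's internal qubit ordering may differ from this convention.

```python
def kron(a, b):
    # Kronecker product of two matrices given as nested lists
    return [
        [a[i][j] * b[k][l] for j in range(len(a[0])) for l in range(len(b[0]))]
        for i in range(len(a))
        for k in range(len(b))
    ]


def madd(a, b):
    # Elementwise sum of two equally sized matrices
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def projector(i, n):
    # |i><i| on an n-dimensional index register
    return [[1 if r == c == i else 0 for c in range(n)] for r in range(n)]


I = [[1, 0], [0, 1]]
X = [[0, 1], [1, 0]]
branches = [I, X]  # U_0, U_1

# SELECT = sum_i |i><i| (x) U_i
select = [[0] * 4 for _ in range(4)]
for i, U in enumerate(branches):
    select = madd(select, kron(projector(i, len(branches)), U))

for row in select:
    print(row)
# [1, 0, 0, 0]
# [0, 1, 0, 0]
# [0, 0, 0, 1]
# [0, 0, 1, 0]
```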
- Parameters:
- index : int or jax.core.Tracer or QuantumVariable or list[Qubit]
The index deciding which function gets executed. A classical integer (or tracer) selects a single branch; a QuantumVariable or list of qubits triggers the quantum switch described above.
- branches : list[callable] or callable
List of functions to be executed based on index, or a single function that takes the index as first argument.
- *operands : tuple
The input values for whichever function is applied.
- branch_amount : int, optional
The number of branches. Only needed if index is a QuantumVariable and branches is a function. It is automatically inferred from the length of branches if branches is a list.
- method : str, optional
Only needed if index is a QuantumVariable. The method used to implement the quantum switch. Can be "auto", "sequential", "parallel", or "tree". Default is "auto". Method "tree" uses balanced binary trees. Method "parallel" is exponentially faster but requires more qubits.
- Returns:
- object
The return value of the respective function.
Examples
Classical index
We write a script that brings a QuantumFloat into superposition and subsequently measures it. If the measurement result is k, we add 3-k such that in the end, the float will always be in the \(\ket{\text{3}}\) state.
```python
from qrisp import *
from qrisp.jasp import *
import jax.numpy as jnp

@jaspify
def main():
    def f0(x):
        x += 3

    def f1(x):
        x += 2

    def f2(x):
        x += 1

    def f3(x):
        pass

    branches = [f0, f1, f2, f3]

    operand = QuantumFloat(2)
    h(operand)
    index = jnp.int32(measure(operand))  # classical index from the measurement

    q_switch(index, branches, operand)
    return measure(operand)

print(main())
# 3.0
```
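Branch k adds 3 - k to an operand that was just measured as k, so every branch ends at 3. The classical model below checks this for all four outcomes, once with a list of branches and once with a single function of the index, the two forms listed for the branches parameter. `switch_reference` is a hypothetical helper that mirrors the semantics given above; plain integers stand in for the QuantumFloat, and the branches return the new value instead of modifying it in place.

```python
def switch_reference(index, branches, *operands):
    # Hypothetical helper: list form calls branches[index](*operands),
    # callable form passes the index as first argument
    if callable(branches):
        return branches(index, *operands)
    return branches[index](*operands)


# Classical stand-ins for f0..f3
branches = [lambda x: x + 3, lambda x: x + 2, lambda x: x + 1, lambda x: x]


def single_branch(k, x):
    # The same four branches folded into one function of the index
    return x + 3 - k


for k in range(4):  # every possible measurement result of QuantumFloat(2)
    print(k, switch_reference(k, branches, k), switch_reference(k, single_branch, k))
# 0 3 3
# 1 3 3
# 2 3 3
# 3 3 3
```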
Quantum index
We write a script that uses a QuantumFloat as index to select different operations on another operand QuantumFloat. The index variable is put into superposition such that all branches are executed in superposition.
```python
from qrisp import *
from qrisp.jasp import *

@terminal_sampling
def main():
    def f0(x):
        x += 1

    def f1(x):
        x += 2

    def f2(x):
        pass

    def f3(x):
        h(x[1])

    branches = [f0, f1, f2, f3]

    operand = QuantumFloat(4)
    operand[:] = 1
    index = QuantumFloat(2)
    h(index)

    q_switch(index, branches, operand)
    return index, operand

print(main())
# {(0.0, 2.0): 0.25000000372529035, (1.0, 3.0): 0.25000000372529035,
#  (2.0, 1.0): 0.25000000372529035, (3.0, 1.0): 0.12499999441206447,
#  (3.0, 3.0): 0.12499999441206447}
```
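The printed probabilities follow directly from the branches. The index is in a uniform superposition over 0 to 3, and each branch acts only on the operand while the index stays in a fixed basis state, so the four branches never interfere. Branches 0, 1 and 2 map the operand 1 to 2, 3 and 1, while branch 3 applies a Hadamard to operand bit 1, splitting 1 evenly between 1 and 3. The pure-Python sketch below reproduces this with a small state-vector dictionary. It is a classical model under stated assumptions (a 4-qubit operand with little-endian bits, as the output above implies), not the terminal_sampling implementation, and it rounds to six decimals, so it prints clean values where the output above carries small numerical deviations.

```python
import math
from collections import defaultdict


def add_const(c):
    # Branch adding a constant to the 4-qubit operand basis state (mod 2**4)
    return lambda x: {(x + c) % 16: 1.0}


def hadamard_on_bit(bit):
    # Branch applying H to one operand qubit of a basis state:
    # |0> -> (|0> + |1>)/sqrt(2),  |1> -> (|0> - |1>)/sqrt(2)
    def apply(x):
        s = 1 / math.sqrt(2)
        sign = -1.0 if x & (1 << bit) else 1.0
        return {x & ~(1 << bit): s, x | (1 << bit): sign * s}
    return apply


# Stand-ins for f0..f3
branches = [add_const(1), add_const(2), lambda x: {x: 1.0}, hadamard_on_bit(1)]

# After operand[:] = 1 and h(index): amplitude 1/2 on each (index, operand=1)
state = {(i, 1): 0.5 for i in range(4)}

# SELECT: apply branch i to the operand of the term whose index is i
new_state = defaultdict(float)
for (i, x), amp in state.items():
    for y, a in branches[i](x).items():
        new_state[(i, y)] += amp * a

probs = {key: round(a * a, 6) for key, a in sorted(new_state.items()) if abs(a) > 1e-12}
print(probs)
# {(0, 2): 0.25, (1, 3): 0.25, (2, 1): 0.25, (3, 1): 0.125, (3, 3): 0.125}
```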