Prefix Control
The following functions expose program control features from Jax as Jasp-compatible functions. The constructs described in Loops and Classical Control work well in many situations, but their syntactically convenient form rules out some cases that might be relevant to your application. In particular, it is impossible to use “carry values” (i.e., classical values that are computed during a loop or conditional) outside of the control structure. To realize this behavior you might be tempted to use the Jax-exposed functions jax.lax.fori_loop, jax.lax.while_loop and jax.lax.cond. However, these do not properly track the Jasp-internal quantum state, which is why we expose Jasp-compatible implementations of these functions.
Note
If your code is kernelized, the quantum state doesn’t need to be tracked, so you can use the Jax versions instead.
- q_fori_loop(lower, upper, body_fun, init_val)
Jasp-compatible version of jax.lax.fori_loop. The parameters and semantics are the same as for the Jax version.
In particular, the following loop is performed:
    def q_fori_loop(lower, upper, body_fun, init_val):
        val = init_val
        for i in range(lower, upper):
            val = body_fun(i, val)
        return val
- Parameters:
- lower : int or jax.core.Tracer
An integer representing the loop index lower bound (inclusive).
- upper : int or jax.core.Tracer
An integer representing the loop index upper bound (exclusive).
- body_fun : callable
The function describing the loop body.
- init_val : object
Some object to initialize the loop with.
- Returns:
- val : object
The return value of body_fun after the final iteration.
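Because the reference loop iterates over range(lower, upper), body_fun is never called when upper <= lower, and init_val is returned unchanged. The pure-Python sketch below mirrors those reference semantics (it is not the Jasp implementation; fori_loop_reference is a hypothetical name) and shows both a tuple-valued carry and this empty-range case:

```python
def fori_loop_reference(lower, upper, body_fun, init_val):
    # Pure-Python reference semantics: no tracing, no quantum state
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


def body_fun(i, val):
    # The carry is a tuple: (running sum of indices, number of iterations)
    total, count = val
    return total + i, count + 1


print(fori_loop_reference(2, 5, body_fun, (0, 0)))  # (9, 3)
print(fori_loop_reference(5, 2, body_fun, (0, 0)))  # (0, 0): empty range
```

The carry has the same structure on every iteration, which is also what the Jax version expects from body_fun.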
Examples
We write a dynamic loop that collects the measurement results of a QuantumFloat's qubits into an accumulator:
    from qrisp import *
    from qrisp.jasp import *

    @jaspify
    def main(k):
        qf = QuantumFloat(6)

        def body_fun(i, val):
            acc, qf = val
            x(qf[i])
            acc += measure(qf[i])
            return acc, qf

        # The accumulator and the QuantumFloat are both part of the carry
        acc, qf = q_fori_loop(0, k, body_fun, (0, qf))
        return acc, measure(qf)

    print(main(5))
    # Yields:
    # (Array(5, dtype=int64), Array(31., dtype=float64))
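To see where the printed values come from, the following classical sketch replays the example for k = 5 with a plain list of bits standing in for the six qubits. It assumes, as in the example, that every qubit starts in |0⟩, so each x gate deterministically sets a bit to 1 and every measurement is deterministic. It also assumes qubit 0 is the least significant bit of the QuantumFloat, which is consistent with the printed value 31. The function simulate_fori_example is hypothetical and not part of Qrisp.

```python
def simulate_fori_example(k, n_qubits=6):
    # Hypothetical classical stand-in: bits[j] holds the state of qubit j,
    # and all qubits start in |0>, so every step below is deterministic
    def body_fun(i, val):
        acc, bits = val
        bits = bits.copy()
        bits[i] ^= 1      # models x(qf[i]): flips qubit i from 0 to 1
        acc += bits[i]    # models acc += measure(qf[i]): always adds 1
        return acc, bits

    val = (0, [0] * n_qubits)   # same shape as the carry (0, qf)
    for i in range(0, k):       # same semantics as q_fori_loop(0, k, ...)
        val = body_fun(i, val)
    acc, bits = val

    # Models measure(qf): read the register as an unsigned integer,
    # assuming qubit 0 is the least significant bit
    value = sum(bit << j for j, bit in enumerate(bits))
    return acc, value


print(simulate_fori_example(5))  # (5, 31)
```

Five flipped qubits contribute 1 each to the accumulator, and the register reads 0b011111 = 31.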
- q_while_loop(cond_fun, body_fun, init_val)
Jasp-compatible version of jax.lax.while_loop. The parameters and semantics are the same as for the Jax version.
In particular, the following loop is performed:
    def q_while_loop(cond_fun, body_fun, init_val):
        val = init_val
        while cond_fun(val):
            val = body_fun(val)
        return val
- Parameters:
- cond_fun : callable
A function that evaluates the condition of the while loop. Must not contain any quantum operations.
- body_fun : callable
A function describing the body of the loop.
- init_val : object
An object to initialize the loop.
- Returns:
- val
The result of body_fun after the last iteration.
- Raises:
- Exception
Tried to modify quantum state during while condition evaluation.
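The restriction on cond_fun can be illustrated with a pure-Python mock. The sketch below is hypothetical and does not reproduce how Jasp actually detects the violation: it models the quantum state as a mutable list stored as the last element of the carry, snapshots it before every condition evaluation, and raises if the condition changed it. The names checked_while_loop, good_cond and bad_cond are illustrative only.

```python
def checked_while_loop(cond_fun, body_fun, init_val):
    # Hypothetical mock: val[-1] plays the role of the quantum state
    val = init_val
    while True:
        snapshot = list(val[-1])
        keep_going = cond_fun(val)
        if list(val[-1]) != snapshot:
            raise Exception(
                "Tried to modify quantum state during while condition evaluation."
            )
        if not keep_going:
            return val
        val = body_fun(val)


def body_fun(val):
    i, state = val
    state[i] = 1  # modifying the "quantum" state is allowed in the body
    return i + 1, state


def good_cond(val):
    return val[0] < 3  # only reads classical carry values


def bad_cond(val):
    val[1][0] ^= 1  # modifies the state during condition evaluation
    return val[0] < 3


print(checked_while_loop(good_cond, body_fun, (0, [0, 0, 0, 0])))  # (3, [1, 1, 1, 0])
```

Calling checked_while_loop(bad_cond, body_fun, (0, [0, 0, 0, 0])) raises on the very first condition evaluation.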
Examples
We write a dynamic loop that collects the measurement results of a QuantumFloat's qubits into an accumulator. Note that the accumulator variable is a carry value, which implies that this loop could not be implemented using the constructs described in Loops.
    from qrisp import *
    from qrisp.jasp import *

    @jaspify
    def main(k):
        qf = QuantumFloat(6)

        def body_fun(val):
            i, acc, qf = val
            x(qf[i])
            acc += measure(qf[i])
            i += 1
            return i, acc, qf

        def cond_fun(val):
            return val[0] < 5

        i, acc, qf = q_while_loop(cond_fun, body_fun, (0, 0, qf))
        return acc, measure(qf)

    print(main(6))
    # Yields:
    # (Array(5, dtype=int64), Array(31., dtype=float64))
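The example above carries the loop counter i explicitly, which is how a counted loop can be expressed with a while loop. The pure-Python sketch below uses plain reference semantics (not the Jasp implementation; while_loop_reference and fori_via_while are hypothetical names) to build fori-loop behavior on top of while-loop behavior by prepending the index to the carry:

```python
def while_loop_reference(cond_fun, body_fun, init_val):
    # Pure-Python reference semantics of a while loop over a carry value
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


def fori_via_while(lower, upper, body_fun, init_val):
    # The carry becomes (i, val): the index is just another carry value
    def cond(carry):
        i, _ = carry
        return i < upper

    def body(carry):
        i, val = carry
        return i + 1, body_fun(i, val)

    _, val = while_loop_reference(cond, body, (lower, init_val))
    return val


# Sum of squares of 0..4, i.e. 0 + 1 + 4 + 9 + 16
print(fori_via_while(0, 5, lambda i, acc: acc + i * i, 0))  # 30
```

Keeping the index inside the carry is what lets cond_fun stay purely classical: it only reads i and never touches the rest of the state.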
- q_cond(pred, true_fun, false_fun, *operands)
Jasp-compatible version of jax.lax.cond. The parameters and semantics are the same as for the Jax version.
In particular, it implements the following semantics:
    def q_cond(pred, true_fun, false_fun, *operands):
        if pred:
            return true_fun(*operands)
        else:
            return false_fun(*operands)
- Parameters:
- pred : bool or jax.core.Tracer
A boolean value, deciding which function gets executed.
- true_fun : callable
The function that is executed when pred is True.
- false_fun : callable
The function that is executed when pred is False.
- *operands : tuple
The input values for both functions.
- Returns:
- object
The return value of the respective function.
Examples
We write a script that brings a QuantumBool into superposition and subsequently measures it. If the measurement result is False, we flip it such that in the end, the bool will always be in the \(\ket{\text{True}}\) state.

    from qrisp import *
    from qrisp.jasp import *

    @jaspify
    def main():

        def false_fun(qbl):
            qbl.flip()
            return qbl

        def true_fun(qbl):
            return qbl

        qbl = QuantumBool()
        h(qbl)
        pred = measure(qbl)

        qbl = q_cond(pred, true_fun, false_fun, qbl)

        return measure(qbl)

    print(main())
    # Yields:
    # True
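Since the first measurement outcome is random, the claim that the final result is always True can be checked classically by enumerating both possible outcomes. In this hypothetical sketch a Python bool stands in for the QuantumBool after it has collapsed under measurement, and qbl.flip() is modeled as logical negation; cond_reference is an illustrative name for the reference semantics above, not a Qrisp function.

```python
def cond_reference(pred, true_fun, false_fun, *operands):
    # Reference semantics of q_cond / jax.lax.cond
    if pred:
        return true_fun(*operands)
    else:
        return false_fun(*operands)


def false_fun(qbl):
    return not qbl  # models qbl.flip() on a measured (collapsed) bool


def true_fun(qbl):
    return qbl


for outcome in (False, True):
    # After measurement the QuantumBool is in the state given by `outcome`
    final = cond_reference(outcome, true_fun, false_fun, outcome)
    print(outcome, "->", final)
```

Both iterations print True as the final value: a False outcome is flipped, and a True outcome is passed through unchanged.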