Prefix Control

The following functions expose program control features from Jax as Jasp-compatible functions. While Loops and Classical Control work fine in many situations, their syntactically convenient form prohibits some cases that might be relevant to your application. In particular, it is impossible to use “carry values” (i.e. classical values that are computed during a loop or conditional) outside of the control structure. To realize this behavior you might be tempted to use the Jax functions jax.lax.fori_loop, jax.lax.while_loop and jax.lax.cond. However, these do not properly track the Jasp-internal quantum state, which is why we expose Jasp-compatible implementations of these functions.
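To make the idea of a carry value concrete, here is a minimal pure-Python sketch. The helpers `fori_loop_ref` and `cond_ref` are hypothetical stand-ins that only mirror the semantics stated for q_fori_loop and q_cond in this section; they trace nothing and are not part of Qrisp. The point is the data flow: everything the loop body or a branch computes has to be returned explicitly, and only returned values are visible afterwards.

```python
def fori_loop_ref(lower, upper, body_fun, init_val):
    # Hypothetical stand-in mirroring the stated semantics of q_fori_loop
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


def cond_ref(pred, true_fun, false_fun, *operands):
    # Hypothetical stand-in mirroring the stated semantics of q_cond
    if pred:
        return true_fun(*operands)
    return false_fun(*operands)


def classify(bits):
    # The carry is the running count; the body returns the updated count
    def body_fun(i, count):
        return count + bits[i]

    count = fori_loop_ref(0, len(bits), body_fun, 0)

    # `count` was computed inside the loop and is used outside of it,
    # here as the predicate and the operand of a conditional
    return cond_ref(count % 2 == 1,
                    lambda c: ("odd", c),
                    lambda c: ("even", c),
                    count)


print(classify([1, 0, 1, 1]))  # ('odd', 3)
```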

Note

If your code is kernelized, the quantum state doesn’t need to be tracked, so you can use the Jax versions.

q_fori_loop(lower, upper, body_fun, init_val)

Jasp-compatible version of jax.lax.fori_loop. The parameters and semantics are the same as for the Jax version.

In particular, the following loop is performed:

def q_fori_loop(lower, upper, body_fun, init_val):
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val
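Two consequences of these semantics are easy to overlook: the upper bound is exclusive, so if upper <= lower the body never runs and init_val is returned unchanged; and the loop value may be a nested structure such as a tuple, with body_fun returning the same structure it receives (the Jax version additionally requires matching types). The plain-Python check below runs a local copy of the semantic definition above, not the Jasp implementation.

```python
def fori_loop_semantics(lower, upper, body_fun, init_val):
    # Plain-Python copy of the semantic definition of q_fori_loop
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val


def body_fun(i, val):
    # The carry is a tuple (sum of indices, number of iterations);
    # the body returns a tuple of the same shape
    total, steps = val
    return total + i, steps + 1


# Indices 2, 3 and 4 are visited because the upper bound is exclusive
print(fori_loop_semantics(2, 5, body_fun, (0, 0)))  # (9, 3)

# Empty ranges: the body never runs and init_val comes back unchanged
print(fori_loop_semantics(5, 5, body_fun, (0, 0)))  # (0, 0)
print(fori_loop_semantics(7, 3, body_fun, (0, 0)))  # (0, 0)
```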
Parameters:
lower : int or jax.core.Tracer

An integer representing the loop index lower bound (inclusive).

upper : int or jax.core.Tracer

An integer representing the loop index upper bound (exclusive).

body_fun : callable

The function describing the loop body. It is called as body_fun(i, val) and must return the updated loop value.

init_val : object

An object to initialize the loop with.

Returns:
val : object

The return value of body_fun after the final iteration (init_val if lower >= upper).

Examples

We write a dynamic loop that collects the measurement results of the qubits of a QuantumFloat into an accumulator:

from qrisp import *
from qrisp.jasp import *

@jaspify
def main(k):

    qf = QuantumFloat(6)

    def body_fun(i, val):
        acc, qf = val
        x(qf[i])
        acc += measure(qf[i])
        return acc, qf

    acc, qf = q_fori_loop(0, k, body_fun, (0, qf))

    return acc, measure(qf)

print(main(5))
# Yields:
# (Array(5, dtype=int64), Array(31., dtype=float64))
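To see where the printed values come from, the classical sketch below replays the example for k = 5 (the value consistent with the printed output) using a plain list of bits instead of the QuantumFloat. It assumes, as the output 31 implies, that the register starts in the all-zero state and that qf[0] is the least significant qubit. Each iteration flips one qubit and measures it as 1, so the accumulator grows by 1 per iteration and the final register holds 2**k - 1.

```python
def replay(k, size=6):
    # Classical stand-in for the QuantumFloat: bits[0] is the least significant bit
    bits = [0] * size
    acc = 0
    for i in range(k):
        bits[i] ^= 1      # corresponds to x(qf[i])
        acc += bits[i]    # corresponds to acc += measure(qf[i])
    # Corresponds to measure(qf): interpret the bits as an unsigned integer
    value = sum(bit << i for i, bit in enumerate(bits))
    return acc, value


print(replay(5))  # (5, 31)
```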
q_while_loop(cond_fun, body_fun, init_val)

Jasp-compatible version of jax.lax.while_loop. The parameters and semantics are the same as for the Jax version.

In particular, the following loop is performed:

def q_while_loop(cond_fun, body_fun, init_val):
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val
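In contrast to q_fori_loop, the number of iterations is not given by an index range; cond_fun decides it from the current loop value. The plain-Python sketch below runs a local copy of the semantic definition above on a classical example, counting how often an integer can be halved before it reaches 1, so the iteration count depends on the input. `while_loop_semantics` and `halvings` are illustrative names, not part of Qrisp.

```python
def while_loop_semantics(cond_fun, body_fun, init_val):
    # Plain-Python copy of the semantic definition of q_while_loop
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


def halvings(n):
    # Carry: (current value, number of halvings so far)
    def cond_fun(val):
        m, _ = val
        return m > 1

    def body_fun(val):
        m, count = val
        return m // 2, count + 1

    _, count = while_loop_semantics(cond_fun, body_fun, (n, 0))
    return count


print([halvings(n) for n in (1, 2, 8, 9, 1000)])  # [0, 1, 3, 3, 9]
```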
Parameters:
cond_fun : callable

A function that evaluates the condition of the while loop from the current loop value. Must not contain any quantum operations.

body_fun : callable

A function describing the body of the loop. It receives the current loop value and returns the updated one.

init_val : object

An object to initialize the loop with.

Returns:
val : object

The loop value after the last iteration (init_val if cond_fun(init_val) is False).

Raises:
Exception

Raised if cond_fun tries to modify the quantum state during the evaluation of the while condition.
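Because cond_fun must not contain quantum operations, a loop whose termination depends on a measurement has to perform the measurement inside body_fun and store the classical outcome in the loop value; cond_fun then only reads that stored value. The sketch below shows this pattern classically: the hypothetical `simulated_measurement` (a seeded pseudo-random bit) stands in for measure, and the loop repeats until the outcome is 1. It illustrates the data flow only and does not use Jasp.

```python
import random


def while_loop_semantics(cond_fun, body_fun, init_val):
    # Plain-Python copy of the semantic definition of q_while_loop
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


def run_until_one(seed):
    rng = random.Random(seed)

    def simulated_measurement():
        # Hypothetical stand-in for a quantum measurement: a fair random bit
        return rng.randint(0, 1)

    def body_fun(val):
        attempts, _ = val
        # The "measurement" happens here, inside the body ...
        outcome = simulated_measurement()
        return attempts + 1, outcome

    def cond_fun(val):
        # ... and the condition only inspects the stored classical outcome
        _, outcome = val
        return outcome == 0

    # Start with outcome 0 so that the body runs at least once
    return while_loop_semantics(cond_fun, body_fun, (0, 0))


attempts, outcome = run_until_one(seed=1)
print(attempts, outcome)
```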

Examples

We write a dynamic loop that collects the measurement results of the qubits of a QuantumFloat into an accumulator. Note that the accumulator variable is a carry value, which implies that this loop could not be implemented using Loops.

from qrisp import *
from qrisp.jasp import *

@jaspify
def main(k):

    qf = QuantumFloat(6)

    def body_fun(val):
        i, acc, qf = val
        x(qf[i])
        acc += measure(qf[i])
        i += 1
        return i, acc, qf

    def cond_fun(val):
        return val[0] < 5

    i, acc, qf = q_while_loop(cond_fun, body_fun, (0, 0, qf))

    return acc, measure(qf)

print(main(6))
# Yields:
# (Array(5, dtype=int64), Array(31., dtype=float64))
q_cond(pred, true_fun, false_fun, *operands)

Jasp-compatible version of jax.lax.cond. The parameters and semantics are the same as for the Jax version.

In particular, it implements the following semantics:

def q_cond(pred, true_fun, false_fun, *operands):
    if pred:
        return true_fun(*operands)
    else:
        return false_fun(*operands)
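In jax.lax.cond both branches are traced, so true_fun and false_fun must accept the same operands and return values with the same structure and types, even though only one of them takes effect. Assuming q_cond inherits this requirement from the Jax version (the section states that the semantics are the same), the plain-Python sketch below models it by evaluating both branches only to compare the structure of their results, then returning the selected one. `cond_with_structure_check` is a hypothetical helper for illustration; it is not how Jax or Jasp implement this check.

```python
def structure(x):
    # Coarse description of a value's structure: nested tuples keep their
    # shape, leaves are described by their Python type
    if isinstance(x, tuple):
        return tuple(structure(item) for item in x)
    return type(x)


def cond_with_structure_check(pred, true_fun, false_fun, *operands):
    # Evaluate both branches (a tracer sees both), compare the structure
    # of the results, then return only the selected one
    true_out = true_fun(*operands)
    false_out = false_fun(*operands)
    if structure(true_out) != structure(false_out):
        raise TypeError(
            f"branch outputs differ: {structure(true_out)} vs {structure(false_out)}"
        )
    return true_out if pred else false_out


print(cond_with_structure_check(True,
                                lambda a, b: (a + b, 1),
                                lambda a, b: (a - b, 0),
                                5, 3))  # (8, 1)
```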
Parameters:
pred : bool or jax.core.Tracer

A boolean value deciding which function gets executed.

true_fun : callable

The function that is executed when pred is True.

false_fun : callable

The function that is executed when pred is False.

*operands : tuple

The input values passed to whichever function is executed.

Returns:
object

The return value of the respective function.

Examples

We write a script that brings a QuantumBool into superposition and subsequently measures it. If the measurement result is False, we flip it so that in the end the QuantumBool is always in the \(\ket{\text{True}}\) state.

from qrisp import *
from qrisp.jasp import *

@jaspify
def main():

    def false_fun(qbl):
        qbl.flip()
        return qbl

    def true_fun(qbl):
        return qbl

    qbl = QuantumBool()
    h(qbl)
    pred = measure(qbl)

    qbl = q_cond(pred, 
                 true_fun, 
                 false_fun, 
                 qbl)

    return measure(qbl)

print(main())
# Yields:
# True
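The result is deterministic even though the measurement is random: whichever outcome is observed, the measured QuantumBool is left in the state given by that outcome, and false_fun flips it exactly when the outcome is False. The classical sketch below enumerates both outcomes, representing the measured QuantumBool by a plain Python bool; `select_branch` and `flip` are hypothetical stand-ins for q_cond and qbl.flip(), not Qrisp functions.

```python
def select_branch(pred, true_fun, false_fun, *operands):
    # Hypothetical stand-in with the stated semantics of q_cond
    return true_fun(*operands) if pred else false_fun(*operands)


def flip(bit):
    # Classical counterpart of qbl.flip() on an already measured QuantumBool
    return not bit


def final_value(outcome):
    # After the measurement the QuantumBool is in the state given by `outcome`
    qbl = outcome
    qbl = select_branch(outcome, lambda b: b, flip, qbl)
    return qbl


print({outcome: final_value(outcome) for outcome in (False, True)})
# {False: True, True: True}
```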