Prefix Control

The following functions expose program control features from Jax as Jasp-compatible functions. While Loops and Classical Control work fine for many situations, their syntactically convenient form prohibits some cases that might be relevant to your application. In particular, it is impossible to use “carry values” (i.e. classical values that are computed during a loop or conditional) outside of the control structure. To realize this behavior, you might be tempted to use the Jax-exposed functions fori_loop, while_loop and cond. However, these do not properly track the Jasp-internal quantum state, which is why we expose Jasp-compatible implementations of these functions.

Note

If you kernelized your code, the quantum state doesn’t need to be tracked, so you can use the Jax versions directly.

q_fori_loop(lower, upper, body_fun, init_val)

Jasp compatible version of jax.lax.fori_loop. The parameters and semantics are the same as for the Jax version.

In particular, it performs the following loop:

def q_fori_loop(lower, upper, body_fun, init_val):
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val
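Because q_fori_loop is documented to share its semantics with jax.lax.fori_loop, that contract can be illustrated on a purely classical body using the Jax version itself. The sketch below involves no quantum operations (so no Jasp state tracking is needed) and compares jax.lax.fori_loop with the reference loop above; the tuple carry shows how several values travel through the loop and remain usable afterwards. The names reference_fori_loop and body_fun here are illustrative and not part of Qrisp.

```python
import jax

def reference_fori_loop(lower, upper, body_fun, init_val):
    # Plain-Python version of the documented semantics
    val = init_val
    for i in range(lower, upper):
        val = body_fun(i, val)
    return val

def body_fun(i, val):
    # The carry is a tuple: (sum of the indices, product of i + 1)
    total, prod = val
    return total + i, prod * (i + 1)

init_val = (0, 1)

jax_result = jax.lax.fori_loop(2, 6, body_fun, init_val)
ref_result = reference_fori_loop(2, 6, body_fun, init_val)

# Both carry values remain available after the loop has finished
print([int(v) for v in jax_result])  # [14, 360]
print([int(v) for v in ref_result])  # [14, 360]
```

For indices 2 to 5, the sum is 2 + 3 + 4 + 5 = 14 and the product is 3 · 4 · 5 · 6 = 360.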
Parameters:
lower : int or jax.core.Tracer

An integer representing the loop index lower bound (inclusive).

upper : int or jax.core.Tracer

An integer representing the loop index upper bound (exclusive).

body_fun : callable

The function describing the loop body. It receives the loop index and the current carry value and returns the updated carry value.

init_val : object

The initial carry value of the loop.

Returns:
val : object

The return value of body_fun after the final iteration.

Examples

We write a dynamic loop that collects the measurement results of the qubits of a QuantumFloat into an accumulator:

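from qrisp import *
from qrisp.jasp import *

@jaspify
def main(k):

    qf = QuantumFloat(6)

    def body_fun(i, val):
        acc, qf = val
        # Flip the i-th qubit and add its measurement result to the accumulator
        x(qf[i])
        acc += measure(qf[i])
        return acc, qf

    # acc is a carry value and remains available after the loop
    acc, qf = q_fori_loop(0, k, body_fun, (0, qf))

    return acc, measure(qf)

print(main(5))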
# Yields:
# (Array(5, dtype=int64), Array(31., dtype=float64))
q_while_loop(cond_fun, body_fun, init_val)

Jasp compatible version of jax.lax.while_loop. The parameters and semantics are the same as for the Jax version.

In particular, it performs the following loop:

def q_while_loop(cond_fun, body_fun, init_val):
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val
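A while loop is needed when the number of iterations depends on values computed inside the loop. As a minimal classical sketch of the documented jax.lax.while_loop semantics (again without quantum operations, and with illustrative helper names that are not part of Qrisp), the following counts the Collatz steps from 6 down to 1. The condition only reads the carry and has no side effects, in the spirit of the restriction that cond_fun must not contain quantum operations.

```python
import jax
import jax.numpy as jnp

def reference_while_loop(cond_fun, body_fun, init_val):
    # Plain-Python version of the documented semantics
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val

def cond_fun(val):
    # Only reads the carry; performs no side effects
    n, steps = val
    return n != 1

def body_fun(val):
    # One Collatz step: halve even n, map odd n to 3n + 1
    n, steps = val
    n = jnp.where(n % 2 == 0, n // 2, 3 * n + 1)
    return n, steps + 1

# The number of iterations is only known at run time
n_final, steps_final = jax.lax.while_loop(cond_fun, body_fun, (6, 0))
ref_n, ref_steps = reference_while_loop(cond_fun, body_fun, (6, 0))

print(int(n_final), int(steps_final))  # 1 8
print(int(ref_n), int(ref_steps))      # 1 8
```

The sequence is 6, 3, 10, 5, 16, 8, 4, 2, 1, i.e. eight steps.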
Parameters:
cond_fun : callable

A function that evaluates the condition of the while loop. Must not contain any quantum operations.

body_fun : callable

A function describing the body of the loop.

init_val : object

The initial carry value of the loop.

Returns:
val

The result of body_fun after the last iteration.

Raises:
Exception

Tried to modify quantum state during while condition evaluation.

Examples

We write a dynamic loop that collects the measurement results of the qubits of a QuantumFloat into an accumulator. Note that the accumulator variable is a carry value, which implies that this loop could not be implemented using Loops.

from qrisp import *
from qrisp.jasp import *

@jaspify
def main(k):

    qf = QuantumFloat(6)

    def body_fun(val):
        i, acc, qf = val
        x(qf[i])
        acc += measure(qf[i])
        i += 1
        return i, acc, qf

    def cond_fun(val):
        return val[0] < 5

    i, acc, qf = q_while_loop(cond_fun, body_fun, (0, 0, qf))

    return acc, measure(qf)

print(main(6))
# Yields
# (Array(5, dtype=int64), Array(31., dtype=float64))
q_cond(pred, true_fun, false_fun, *operands)

Jasp compatible version of jax.lax.cond. The parameters and semantics are the same as for the Jax version.

It implements the following semantics:

def q_cond(pred, true_fun, false_fun, *operands):
    if pred:
        return true_fun(*operands)
    else:
        return false_fun(*operands)
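A function like this is needed because, under tracing, the predicate is an abstract value. The following classical sketch uses jax.jit as a stand-in for jaspify (an analogy, not the same transformation): pred becomes a tracer, so a Python if could not branch on it, and jax.lax.cond instead stages both branches and selects one at run time. Both branches receive the same operands and return values of the same type; true_fun, false_fun and f are illustrative names.

```python
import jax

def true_fun(a, b):
    return a + b

def false_fun(a, b):
    return a - b

@jax.jit
def f(pred, a, b):
    # Under jit, pred is a tracer, so `if pred:` would raise an error here;
    # lax.cond stages both branches and picks one at run time
    return jax.lax.cond(pred, true_fun, false_fun, a, b)

print(int(f(True, 5, 3)))   # 8
print(int(f(False, 5, 3)))  # 2
```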
Parameters:
pred : bool or jax.core.Tracer

A boolean value, deciding which function gets executed.

true_fun : callable

The function that is executed when pred is True.

false_fun : callable

The function that is executed when pred is False.

*operands : tuple

The input values for both functions.

Returns:
object

The return value of the respective function.

Examples

We write a script that brings a QuantumBool into superposition and subsequently measures it. If the measurement result is False, we flip it, so that in the end the bool is always in the \(\ket{\text{True}}\) state.

from qrisp import *
from qrisp.jasp import *

@jaspify
def main():

    def false_fun(qbl):
        qbl.flip()
        return qbl

    def true_fun(qbl):
        return qbl

    qbl = QuantumBool()
    h(qbl)
    pred = measure(qbl)

    qbl = q_cond(pred,
                 true_fun,
                 false_fun,
                 qbl)

    return measure(qbl)

print(main())
# Yields:
# True
q_switch(index, branches, *operands, branch_amount=None, method='auto')

Classical index

Jasp compatible version of jax.lax.switch. The parameters and semantics are the same as for the Jax version.

It implements the following semantics:

def q_switch(index, branches, *operands):
    return branches[index](*operands)
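A classical sketch of the documented jax.lax.switch semantics, mirroring the "add 3-k" example further below with plain integers instead of a QuantumFloat. One detail of the Jax version worth knowing: an out-of-range index is clamped into [0, len(branches) - 1] rather than raising an error. That q_switch behaves the same way for a classical index is an assumption based on its stated equivalence with jax.lax.switch; this sketch only exercises the Jax function, and apply_branch is an illustrative name.

```python
import jax

# Branch k adds 3 - k, so every branch maps k to 3
branches = [
    lambda x: x + 3,
    lambda x: x + 2,
    lambda x: x + 1,
    lambda x: x,
]

@jax.jit
def apply_branch(index, x):
    # Equivalent to branches[index](x), but index may be a traced value
    return jax.lax.switch(index, branches, x)

print([int(apply_branch(k, k)) for k in range(4)])  # [3, 3, 3, 3]

# jax.lax.switch clamps out-of-range indices
print(int(apply_branch(7, 0)))   # 0 (index clamped to 3)
print(int(apply_branch(-1, 0)))  # 3 (index clamped to 0)
```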

Quantum index

Executes a quantum switch-case statement that distinguishes between the given in-place functions.

Implements the operation

\[\text{SELECT} = \sum_i \ket{i}\bra{i} \otimes U_i\]

for unitaries (branches) \(U_i\), applying the \(i\)-th unitary conditioned on the index variable being in state \(\ket{i}\).
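To make the SELECT formula concrete, the following NumPy sketch builds the matrix \(\sum_i \ket{i}\bra{i} \otimes U_i\) for four single-qubit unitaries (an arbitrary choice for illustration) and checks two properties: the result is unitary, and acting on \(\ket{i} \otimes \ket{\psi}\) applies only \(U_i\) to \(\ket{\psi}\). The index register is the left Kronecker factor, exactly as written in the formula. This is a linear-algebra illustration only, not how q_switch constructs its circuit, and it makes no claim about Qrisp's qubit-ordering conventions; the helper select_matrix is hypothetical.

```python
import numpy as np

def select_matrix(unitaries):
    # Build SELECT = sum_i |i><i| (x) U_i for a list of equally sized unitaries
    n = len(unitaries)
    dim = unitaries[0].shape[0]
    select = np.zeros((n * dim, n * dim), dtype=complex)
    for i, U in enumerate(unitaries):
        proj = np.zeros((n, n))
        proj[i, i] = 1.0  # projector |i><i| on the index register
        select += np.kron(proj, U)
    return select

I2 = np.eye(2)
X = np.array([[0, 1], [1, 0]])
H = np.array([[1, 1], [1, -1]]) / np.sqrt(2)
Z = np.diag([1, -1])

S = select_matrix([I2, X, H, Z])

# SELECT is block diagonal with unitary blocks, hence unitary
print(np.allclose(S.conj().T @ S, np.eye(8)))  # True

# On |i> (x) |psi>, only U_i acts on |psi>
psi = np.array([1, 0])
i = 1
e_i = np.eye(4)[i]
out = S @ np.kron(e_i, psi)
print(np.allclose(out, np.kron(e_i, X @ psi)))  # True
```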

Parameters:
index : int or jax.core.Tracer or QuantumVariable or list[Qubit]

The index deciding which function gets executed. A classical integer selects a single branch; a QuantumVariable or a list of Qubits applies the branches conditioned on the state of the index.

branches : list[callable] or callable

A list of functions to be executed based on index, or a single function that takes the index as its first argument.

*operands : tuple

The input values for whichever function is applied.

branch_amount : int, optional

The number of branches. Only needed if index is a QuantumVariable and branches is a function. It is automatically inferred from the length of branches if branches is a list.

method : str, optional

Only needed if index is a QuantumVariable. The method used to implement the quantum switch: "auto", "sequential", "parallel", or "tree". The default is "auto". Method "tree" uses balanced binary trees. Method "parallel" is exponentially faster but requires more qubits.

Returns:
object

The return value of the respective function.

Examples

Classical index

We write a script that brings a QuantumFloat into superposition and subsequently measures it. If the measurement result is k, we add 3-k, so that in the end the float is always in the \(\ket{3}\) state.

from qrisp import *
from qrisp.jasp import *
import jax.numpy as jnp

@jaspify
def main():

    def f0(x): x += 3
    def f1(x): x += 2
    def f2(x): x += 1
    def f3(x): pass
    branches = [f0, f1, f2, f3]

    operand = QuantumFloat(2)
    h(operand)
    index = jnp.int32(measure(operand))

    q_switch(index, branches, operand)
    return measure(operand)

print(main())
# 3.0

Quantum index

We write a script that uses a QuantumFloat as an index to select different operations on another QuantumFloat, the operand. The index variable is put into superposition, so that all branches are executed in superposition.

from qrisp import *
from qrisp.jasp import *

@terminal_sampling
def main():

    def f0(x): x += 1
    def f1(x): x += 2
    def f2(x): pass
    def f3(x): h(x[1])
    branches = [f0, f1, f2, f3]

    operand = QuantumFloat(4)
    operand[:] = 1
    index = QuantumFloat(2)
    h(index)

    q_switch(index, branches, operand)
    return index, operand

print(main())
# {(0.0, 2.0): 0.25000000372529035, (1.0, 3.0): 0.25000000372529035,
# (2.0, 1.0): 0.25000000372529035, (3.0, 1.0): 0.12499999441206447,
# (3.0, 3.0): 0.12499999441206447}