Quantum Switch Case

q_switch(index, branches, *operands, branch_amount=None, method='auto')

Classical index

Jasp-compatible version of jax.lax.switch. The parameters and semantics are the same as for the JAX version.

Implements the following semantics:

def q_switch(index, branches, *operands):
    return branches[index](*operands)
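For a classical index, the behavior can be modeled in plain Python. The sketch below is an illustration, not Qrisp's implementation: `switch_model` is a hypothetical helper. The index clamping follows the documented behavior of jax.lax.switch (out-of-range indices are clamped to the valid range), and we assume q_switch inherits it because it shares JAX's semantics.

```python
def switch_model(index, branches, *operands):
    """Hypothetical reference model of jax.lax.switch semantics.

    Out-of-range indices are clamped to [0, len(branches) - 1], as documented
    for jax.lax.switch (assumed to carry over to q_switch with a classical index).
    """
    clamped = min(max(index, 0), len(branches) - 1)
    return branches[clamped](*operands)

# Classical analogue of the first example below: branch k adds 3 - k.
branches = [lambda x: x + 3, lambda x: x + 2, lambda x: x + 1, lambda x: x]

print([switch_model(k, branches, k) for k in range(4)])  # [3, 3, 3, 3]
print(switch_model(7, branches, 0))   # clamped to the last branch: 0
print(switch_model(-2, branches, 0))  # clamped to the first branch: 3
```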

Quantum index

Executes a quantum switch-case statement that selects between the given in-place functions.

Implements the operation

\[\text{SELECT} = \sum_i \ket{i}\bra{i} \otimes U_i\]

for unitaries (branches) \(U_i\), applying the \(i\)-th unitary conditioned on the index variable being in state \(\ket{i}\).
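As a small worked example, take a one-qubit index with \(U_0 = I\) and \(U_1 = X\). Then \(\text{SELECT} = \ket{0}\bra{0} \otimes I + \ket{1}\bra{1} \otimes X\), which is the CNOT gate with the index as control. The pure-Python sketch below builds the matrix of SELECT from a list of unitaries. The helper names are hypothetical, and the index register is assumed to be the left (most significant) tensor factor, as in the formula.

```python
# Pure-Python sketch (hypothetical helper names): build the dense matrix of
# SELECT = sum_i |i><i| (x) U_i, with the index as the left tensor factor.

def kron(a, b):
    """Kronecker product of two square matrices stored as nested lists."""
    n, m = len(a), len(b)
    return [[a[r // m][c // m] * b[r % m][c % m] for c in range(n * m)]
            for r in range(n * m)]

def projector(i, dim):
    """The dim x dim matrix |i><i|."""
    return [[1 if r == i and c == i else 0 for c in range(dim)] for r in range(dim)]

def mat_add(a, b):
    """Element-wise sum of two equally sized matrices."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

def select_matrix(unitaries):
    """Sum of |i><i| (x) U_i over all branches U_i (all of equal size)."""
    k, d = len(unitaries), len(unitaries[0])
    result = [[0] * (k * d) for _ in range(k * d)]
    for i, u in enumerate(unitaries):
        result = mat_add(result, kron(projector(i, k), u))
    return result

I = [[1, 0], [0, 1]]
X = [[0, 1], [1, 0]]

select = select_matrix([I, X])
for row in select:
    print(row)
# [1, 0, 0, 0]
# [0, 1, 0, 0]
# [0, 0, 0, 1]
# [0, 0, 1, 0]
```

The result is block-diagonal: the identity acts on the target when the index is \(\ket{0}\), and \(X\) acts when it is \(\ket{1}\), which is exactly CNOT.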

Parameters:
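index : int or jax.core.Tracer or QuantumVariable or list[Qubit]

The value deciding which function is executed. A classical (or traced) integer selects a single branch; a QuantumVariable or a list of Qubits selects the branches in superposition.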

branches : list[callable] or callable

Either a list of functions, one of which is executed depending on index, or a single function that takes the index as its first argument.

*operands : tuple

The input values for whichever function is applied.

branch_amount : int, optional

The number of branches. Only needed if index is a QuantumVariable and branches is a function. It is inferred automatically from the length of branches if branches is a list.

method : str, optional

Only needed if index is a QuantumVariable. The method used to implement the quantum switch. Can be "auto", "sequential", "parallel", or "tree". Default is "auto". Method "tree" uses balanced binary trees. Method "parallel" is exponentially faster but requires more qubits.
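The branch_amount parameter exists because a list of branches carries its own length, while a single function does not. The following plain-Python model is not Qrisp's implementation; `expand_branches` and `shift` are hypothetical names. It shows how a function that takes the index as its first argument corresponds to an explicit list once the number of branches is known.

```python
from functools import partial

def expand_branches(branches, branch_amount=None):
    """Hypothetical helper: normalize `branches` to a list of callables.

    A list already carries its length; a single callable f(i, *operands)
    does not, so the number of branches must be supplied explicitly.
    """
    if callable(branches):
        if branch_amount is None:
            raise ValueError("branch_amount is required when branches is a callable")
        # Bind the branch index i as the first argument of each case.
        return [partial(branches, i) for i in range(branch_amount)]
    return list(branches)

def shift(i, x):
    # Single-function form: receives the branch index as its first argument.
    return x + (3 - i)

as_list = [lambda x: x + 3, lambda x: x + 2, lambda x: x + 1, lambda x: x]
from_fn = expand_branches(shift, branch_amount=4)

print([f(10) for f in as_list])  # [13, 12, 11, 10]
print([f(10) for f in from_fn])  # [13, 12, 11, 10]
```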

Returns:
object

The return value of the respective function.

Examples

Classical index

We write a script that brings a QuantumFloat into superposition and subsequently measures it. If the measurement result is \(k\), we add \(3-k\), so that in the end the float is always in the state \(\ket{3}\).

from qrisp import *
from qrisp.jasp import *
import jax.numpy as jnp

@jaspify
def main():

    def f0(x): x += 3
    def f1(x): x += 2
    def f2(x): x += 1
    def f3(x): pass
    branches = [f0, f1, f2, f3]

    operand = QuantumFloat(2)
    h(operand)
    index = jnp.int32(measure(operand))

    q_switch(index, branches, operand)
    return measure(operand)

print(main())
# 3.0

Quantum index

We write a script that uses a QuantumFloat as the index to select between different operations on a second QuantumFloat, the operand. The index is put into superposition, so all branches are executed in superposition.

from qrisp import *
from qrisp.jasp import *

@terminal_sampling
def main():

    def f0(x): x += 1
    def f1(x): x += 2
    def f2(x): pass
    def f3(x): h(x[1])
    branches = [f0, f1, f2, f3]

    operand = QuantumFloat(4)
    operand[:] = 1
    index = QuantumFloat(2)
    h(index)

    q_switch(index, branches, operand)
    return index, operand

print(main())
# {(0.0, 2.0): 0.25000000372529035, (1.0, 3.0): 0.25000000372529035,
# (2.0, 1.0): 0.25000000372529035, (3.0, 1.0): 0.12499999441206447,
# (3.0, 3.0): 0.12499999441206447}
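The probabilities can be checked by hand. Since SELECT is block-diagonal in the index, each index value \(i\) (amplitude \(1/2\)) independently has branch \(f_i\) applied to the operand state \(\ket{1}\). Only f3 creates a superposition, splitting index 3 into operand values 1 and 3 with probability \(1/8\) each. The pure-Python state-vector sketch below reproduces this up to floating-point rounding. It is an independent model, not how terminal_sampling works internally. It assumes that qubit x[1] carries weight 2 (consistent with the output above) and models addition modulo 16 (no overflow occurs here).

```python
import math
from collections import defaultdict

N_QUBITS = 4          # operand = QuantumFloat(4)
MOD = 2 ** N_QUBITS   # assumed modular arithmetic; no overflow occurs in this example

def add_const(c):
    """Branch that adds the constant c to the operand (a basis-state permutation)."""
    def branch(state):
        out = defaultdict(complex)
        for x, amp in state.items():
            out[(x + c) % MOD] += amp
        return dict(out)
    return branch

def do_nothing(state):
    return dict(state)

def hadamard_on(bit):
    """Branch applying H to qubit `bit` (assumed to carry weight 2**bit)."""
    def branch(state):
        out = defaultdict(complex)
        s = 1 / math.sqrt(2)
        for x, amp in state.items():
            sign = -1 if (x >> bit) & 1 else 1
            out[x & ~(1 << bit)] += s * amp
            out[x | (1 << bit)] += sign * s * amp
        return dict(out)
    return branch

branches = [add_const(1), add_const(2), do_nothing, hadamard_on(1)]

# Index register in uniform superposition: amplitude 1/2 on each of |0>..|3>.
# SELECT is block-diagonal, so each index value gets its own branch applied
# to the operand state |1>.
probs = {}
for i, branch in enumerate(branches):
    for x, amp in branch({1: 1.0}).items():
        p = abs(0.5 * amp) ** 2
        if p > 1e-12:
            probs[(i, x)] = p

print(probs)
```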