Quantum Switch Case
- q_switch(index, branches, *operands, branch_amount=None, method='auto')
Classical index
Jasp-compatible version of jax.lax.switch. The parameters and semantics are the same as for the JAX version.
It implements the following semantics:
```python
def q_switch(index, branches, *operands):
    return branches[index](*operands)
```
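The JAX documentation for jax.lax.switch adds one detail that the snippet above omits: an out-of-range index is clamped into the range [0, len(branches) - 1] before dispatch. Assuming q_switch follows jax.lax.switch here too, as the statement above suggests, a plain-Python model of the classical semantics could look like the sketch below. switch_reference and the integer branches are hypothetical illustrations, not Qrisp code.

```python
def switch_reference(index, branches, *operands):
    """Plain-Python model of jax.lax.switch-style dispatch (not Qrisp code).

    The index is clamped into the valid range, as documented for
    jax.lax.switch; that q_switch clamps the same way is an assumption.
    """
    clamped = max(0, min(int(index), len(branches) - 1))
    return branches[clamped](*operands)


# Hypothetical branches acting on plain integers.
branches = [lambda x: x + 3, lambda x: x + 2, lambda x: x + 1, lambda x: x]

print(switch_reference(1, branches, 10))   # branch 1 -> 12
print(switch_reference(7, branches, 10))   # clamped to branch 3 -> 10
print(switch_reference(-2, branches, 10))  # clamped to branch 0 -> 13
```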
Quantum index
Executes a quantum switch-case statement that selects between the given in-place functions.
Implements the operation

\[\text{SELECT} = \sum_i \ket{i}\bra{i} \otimes U_i\]

for unitaries (branches) \(U_i\), applying the \(i\)-th unitary conditioned on the index variable being in state \(\ket{i}\).
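To make the SELECT formula concrete, the sketch below builds its dense matrix for four single-qubit unitaries (I, X, Z and H, chosen only for illustration) and checks that \(\ket{i}\ket{0} \mapsto \ket{i} U_i\ket{0}\) for every index value. It is plain Python with hypothetical helper names (select_matrix, apply), not part of Qrisp, and it assumes the index register is the left tensor factor, so SELECT is block diagonal with the \(U_i\) as its blocks.

```python
import math


def select_matrix(unitaries):
    """Dense matrix of SELECT = sum_i |i><i| (x) U_i.

    Basis state |i>|j> sits at position i * d + j, so the result is
    block diagonal with block i equal to U_i.
    """
    d = len(unitaries[0])
    size = len(unitaries) * d
    m = [[0j] * size for _ in range(size)]
    for i, u in enumerate(unitaries):
        for r in range(d):
            for c in range(d):
                m[i * d + r][i * d + c] = u[r][c]
    return m


def apply(m, v):
    """Matrix-vector product on plain lists of numbers."""
    return [sum(row[c] * v[c] for c in range(len(v))) for row in m]


# Four single-qubit unitaries, selected by a 2-qubit index.
s = 1 / math.sqrt(2)
U = [
    [[1, 0], [0, 1]],    # I
    [[0, 1], [1, 0]],    # X
    [[1, 0], [0, -1]],   # Z
    [[s, s], [s, -s]],   # H
]
SELECT = select_matrix(U)

# Check |i>|0> -> |i> (U_i |0>) for every index value i.
for i, u in enumerate(U):
    v = [0j] * 8
    v[i * 2] = 1
    expected = [0j] * 8
    expected[i * 2] = u[0][0]
    expected[i * 2 + 1] = u[1][0]
    out = apply(SELECT, v)
    assert all(abs(a - b) < 1e-12 for a, b in zip(out, expected))

print("SELECT maps |i>|0> to |i> U_i|0> for all i")
```

Because every block is unitary, the block-diagonal SELECT is unitary as well.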
- Parameters:
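- index : int or jax.core.Tracer or QuantumVariable or list[Qubit]
The index deciding which function gets executed. A classical index (int or tracer) selects a single branch; a quantum index (QuantumVariable or list of Qubits) applies the branches conditioned on its state, as described above.
- branches : list[callable] or callable
A list of functions to be executed based on index, or a single function that takes the index as its first argument.
- *operands : tuple
The input values for whichever function is applied.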
- branch_amount : int, optional
The number of branches. Only needed if index is a QuantumVariable and branches is a function. It is inferred automatically from the length of branches if branches is a list.
- method : str, optional
Only needed if index is a QuantumVariable. The method used to implement the quantum switch: "auto", "sequential", "parallel", or "tree". Default is "auto". Method "tree" uses balanced binary trees. Method "parallel" is exponentially faster but requires more qubits.
- Returns:
- object
The return value of the respective function.
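When branches is a single callable, the index is passed as its first argument and branch_amount says how many index values exist. One way to picture this is as expanding the callable into an explicit list of branches. The plain-Python sketch below does exactly that; expand_branches and add_index are hypothetical names, and the calling convention branches(i, *operands) is inferred from the parameter description above rather than taken from Qrisp's source.

```python
def expand_branches(branches, branch_amount):
    """Turn f(i, *operands) into [g_0, ..., g_{n-1}] with g_i(*ops) == f(i, *ops).

    Hypothetical illustration only, not Qrisp's internal implementation.
    """
    # Bind i as a keyword default so each lambda keeps its own index value
    # instead of all of them seeing the last loop value.
    return [lambda *ops, i=i: branches(i, *ops) for i in range(branch_amount)]


def add_index(i, x):
    """Example branch: the branch index arrives as the first argument."""
    return x + i


expanded = expand_branches(add_index, branch_amount=4)
print([g(10) for g in expanded])  # [10, 11, 12, 13]
```

In q_switch the branches act in place on quantum operands; the sketch returns values only so that the expansion can be checked with ordinary integers.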
Examples
Classical index
We write a script that brings a QuantumFloat into superposition and subsequently measures it. If the measurement result is k, we add 3-k, such that in the end the float will always be in the \(\ket{3}\) state.
```python
from qrisp import *
from qrisp.jasp import *
import jax.numpy as jnp

@jaspify
def main():

    # Branch k adds 3 - k to its (in-place) argument.
    def f0(x):
        x += 3

    def f1(x):
        x += 2

    def f2(x):
        x += 1

    def f3(x):
        pass

    branches = [f0, f1, f2, f3]

    # Put a 2-qubit QuantumFloat into uniform superposition over 0..3.
    operand = QuantumFloat(2)
    h(operand)

    # Measure it to obtain a classical index k.
    index = jnp.int32(measure(operand))

    # Apply branch k to the (now collapsed) operand.
    q_switch(index, branches, operand)

    return measure(operand)

print(main())
# 3.0
```
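Why the classical example always prints 3 can be checked without a quantum simulator: the 2-qubit QuantumFloat is measured as some k in {0, 1, 2, 3}, each with probability 1/4, and branch k adds 3 - k. The plain-Python sketch below mirrors that bookkeeping, with ordinary integers standing in for the measured value; it illustrates the example's logic and is not Qrisp code.

```python
from fractions import Fraction

# Branch k adds 3 - k, matching f0..f3 in the example above.
branches = [lambda x: x + 3, lambda x: x + 2, lambda x: x + 1, lambda x: x]

# After h(operand), each outcome k of the 2-qubit float has probability 1/4.
outcomes = {}
for k in range(4):
    final = branches[k](k)  # the measured value k is the operand after collapse
    outcomes[final] = outcomes.get(final, 0) + Fraction(1, 4)

print(outcomes)  # {3: Fraction(1, 1)}
```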
Quantum index
We write a script that uses a QuantumFloat as index to select different operations on another operand QuantumFloat. The index variable is put into superposition such that all branches are executed in superposition.
```python
from qrisp import *
from qrisp.jasp import *

@terminal_sampling
def main():

    def f0(x):
        x += 1

    def f1(x):
        x += 2

    def f2(x):
        pass

    def f3(x):
        h(x[1])

    branches = [f0, f1, f2, f3]

    # Operand register initialized to the value 1.
    operand = QuantumFloat(4)
    operand[:] = 1

    # Index register in uniform superposition over 0..3.
    index = QuantumFloat(2)
    h(index)

    # Branch i is applied to the operand conditioned on index being |i>.
    q_switch(index, branches, operand)

    return index, operand

print(main())
# {(0.0, 2.0): 0.25000000372529035, (1.0, 3.0): 0.25000000372529035,
#  (2.0, 1.0): 0.25000000372529035, (3.0, 1.0): 0.12499999441206447,
#  (3.0, 3.0): 0.12499999441206447}
```
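The distribution printed above can be reproduced with a small classical statevector calculation of \(\text{SELECT} = \sum_i \ket{i}\bra{i} \otimes U_i\). The sketch below is plain Python, not Qrisp: it stores amplitudes in a dictionary keyed by (index, operand) basis values, models the adders as addition modulo 16 on the 4-qubit register (the wrap-around is an assumption that does not matter here, since no branch overflows), and treats h(x[1]) as a Hadamard on the qubit of weight 2. All helper names are hypothetical.

```python
import math
from collections import defaultdict

N_OPERAND = 4  # operand = QuantumFloat(4): basis values 0..15


def add_const(c):
    """Branch x += c as a map: basis state -> {new basis state: amplitude}."""
    return lambda x: {(x + c) % 2**N_OPERAND: 1.0}


def identity(x):
    return {x: 1.0}


def hadamard_on_bit(b):
    """Branch h(x[b]): Hadamard on the qubit with weight 2**b."""
    s = 1 / math.sqrt(2)

    def branch(x):
        x0 = x & ~(1 << b)                       # bit b cleared
        x1 = x | (1 << b)                        # bit b set
        sign = -1.0 if (x >> b) & 1 else 1.0     # H|1> = (|0> - |1>)/sqrt(2)
        return {x0: s, x1: sign * s}

    return branch


branches = [add_const(1), add_const(2), identity, hadamard_on_bit(1)]

# Initial state: index uniform over 0..3 (amplitude 1/2 each), operand = 1.
state = {(i, 1): 0.5 for i in range(4)}

# SELECT: apply branch i only to the operand component paired with |i>.
new_state = defaultdict(float)
for (i, x), amp in state.items():
    for y, a in branches[i](x).items():
        new_state[(i, y)] += amp * a

probs = {k: abs(v) ** 2 for k, v in new_state.items() if abs(v) > 1e-12}
rounded = {k: round(p, 6) for k, p in sorted(probs.items())}
print(rounded)
# {(0, 2): 0.25, (1, 3): 0.25, (2, 1): 0.25, (3, 1): 0.125, (3, 3): 0.125}
```

The small deviations from 0.25 and 0.125 in the Qrisp output above are floating-point noise from the simulator; the exact values follow from the amplitudes 1/2 and 1/(2√2).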