# Source code for qrisp.jasp.program_control.sampling

"""
\********************************************************************************
* Copyright (c) 2025 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************/
"""

import jax
import jax.numpy as jnp

from qrisp.jasp.tracing_logic import quantum_kernel, check_for_tracing_mode

# The following function implements the sample feature.

# The basic functionality would be relatively straightforward to implement,
# however there are some complications. The reason for that is that the resulting
# jaxpr should be "readable" by the terminal sampling interpreter.
# Terminal sampling means that instead of performing the simulation "shots"-times
# it is performed once and the shots are then sampled from the resulting distribution.
# Naturally this implies a massive performance increase, which is why a lot
# of effort is spent to realize a smooth implementation.

# The underlying idea to make the feature easily "readable" by the terminal
# sampling interpreter is to structure one iteration of sampling into three
# steps.

# 1. Evaluating the user function, which generates the distribution.
# 2. Sampling from that distribution via the "measure" function.
# 3. Decoding and postprocessing the measurement results.

# For the final two steps we deploy some custom logic to realize the terminal
# sampling behavior. To simplify the automatic processing of these steps,
# we capture each into individual pjit calls.

# The terminal sampling interpreter then identifies each step via the
# eqn.params["name"] attribute and executes the custom logic.


def sample(state_prep=None, shots=0, post_processor=None):
    r"""
    The ``sample`` function allows to take samples from a state that is specified
    by a preparation procedure. This preparation procedure can be supplied via
    a Python function that returns one or more :ref:`QuantumVariables <QuantumVariable>`.

    The samples are returned in the form of a
    `Jax Array <https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html>`_
    which is shaped according to the ``shots`` parameter. Because of this,
    shots can only be a **static integer** (no dynamic values!). If you want
    to sample with a dynamic shot amount, look into :ref:`expectation_value`.

    ``sample`` can also be used as a decorator factory: if ``state_prep`` is
    omitted, or an integer is passed in its place (which is then used as
    ``shots``), a function is returned that expects the state preparation
    function as its only argument.

    Parameters
    ----------
    state_prep : callable
        A function returning one or more :ref:`QuantumVariables <QuantumVariable>`.
        The state from these QuantumVariables will be sampled.
        The state preparation function can only take classical values as arguments.
        This is because a quantum value would need to be copied for each sampling
        iteration, which is prohibited by the no-cloning theorem.
    shots : int
        The number of samples to take.
    post_processor : callable, optional
        A function to apply to the samples directly after measuring.
        By default no post processing is applied.

    Raises
    ------
    Exception
        Tried to sample with dynamic shots value (static integer required)
    Exception
        Tried to sample with shots value of non-integer type
    Exception
        Tried to sample from state preparation function taking a quantum value
    Exception
        Tried to sample from function not returning a QuantumVariable

    Returns
    -------
    callable
        A classical, Jax traceable function returning a jax array containing
        the measurement results of each shot.

    Examples
    --------

    We prepare the state

    .. math::

        \ket{\psi} = \frac{1}{\sqrt{2}} \left(\ket{0}\ket{0}\ket{\text{False}}
        + \ket{k}\ket{k}\ket{\text{True}}\right)

    ::

        from qrisp import *
        from qrisp.jasp import *

        def state_prep(k):
            a = QuantumFloat(4)
            b = QuantumFloat(4)

            qbl = QuantumBool()
            h(qbl)

            with control(qbl[0]):
                a[:] = k

            cx(a, b)

            return a, b

    And subsequently sample from the QuantumFloats:

    ::

        @jaspify
        def main(k):

            sampling_function = sample(state_prep, shots = 10)

            return sampling_function(k)

        print(main(3))
        # Yields
        # [[3. 3.]
        #  [0. 0.]
        #  [0. 0.]
        #  [3. 3.]
        #  [0. 0.]
        #  [0. 0.]
        #  [3. 3.]
        #  [3. 3.]
        #  [0. 0.]
        #  [0. 0.]]

    To demonstrate the post processing feature, we write a simple post
    processing function:

    ::

        def post_processor(x, y):
            return 2*x + y//2

        @jaspify
        def main(k):

            sampling_function = sample(state_prep,
                                       shots = 10,
                                       post_processor = post_processor)

            return sampling_function(k)

        print(main(4))
        # Yields
        # [10. 10.  0.  0.  0.  0.  0.  0. 10. 10.]

    """

    # A leading integer means the call was ``sample(shots)``: treat it as the
    # shot count and fall through to the decorator branch below.
    if isinstance(state_prep, int):
        shots = state_prep
        state_prep = None

    # Validate the shot count up front so that an invalid value is reported
    # at the call site, also when ``sample`` is used as a decorator factory.
    if isinstance(shots, jax.core.Tracer):
        raise Exception("Tried to sample with dynamic shots value (static integer required)")
    elif not isinstance(shots, int):
        raise Exception(f"Tried to sample with shots value of non-integer type {type(shots)}")

    # Decorator mode: wait for the state preparation function.
    if state_prep is None:
        return lambda x: sample(x, shots, post_processor=post_processor)

    # Imported lazily inside the function, as in the original implementation.
    from qrisp.jasp import qache, terminal_sampling
    from qrisp.core import QuantumVariable, measure

    if post_processor is None:
        # Default post processor: pass a single value through unchanged and
        # return several values as a tuple.
        def identity(*args):
            if len(args) == 1:
                return args[0]
            return args

        post_processor = identity

    # Qache the user function
    @qache
    def user_func(*args):
        return state_prep(*args)

    # This function evaluates the sampling process.
    # NOTE: the names of the nested functions are kept as they are, since the
    # terminal sampling interpreter identifies the individual steps via
    # eqn.params["name"].
    @jax.jit
    def sampling_eval_function(tracerized_shots, *args):

        if any(isinstance(arg, QuantumVariable) for arg in args):
            raise Exception("Tried to sample from state preparation function taking a quantum value")

        # We now construct a loop to collect the samples by
        # inserting the postprocessed measurement result into an array.

        # The following function is the loop body, which is kernelized.
        @quantum_kernel
        def sampling_body_func(i, args):

            acc = args[0]

            # Evaluate the user function
            qv_tuple = user_func(*args[1:])

            if not isinstance(qv_tuple, tuple):
                qv_tuple = (qv_tuple,)

            for qv in qv_tuple:
                if not isinstance(qv, QuantumVariable):
                    raise Exception("Tried to sample from function not returning a QuantumVariable")

            # Trace the DynamicQubitArray measurements.
            # Since we execute the measurements on the .reg attribute, no decoding
            # is applied. The decoding happens in sampling_helper_2.
            @qache
            def sampling_helper_1(*args):
                return tuple(measure(reg) for reg in args)

            measurement_ints = sampling_helper_1(*[qv.reg for qv in qv_tuple])

            # Trace the decoding
            @jax.jit
            def sampling_helper_2(acc, i, *meas_ints):

                decoded_values = [
                    qv.jdecoder(meas_int)
                    for qv, meas_int in zip(qv_tuple, meas_ints)
                ]

                decoded_values = post_processor(*decoded_values)

                if isinstance(decoded_values, tuple):
                    # Save the return amount (for more details check the
                    # comment at the initialization of return_amount).
                    return_amount.append(len(decoded_values))

                    # A one-dimensional accumulator cannot hold several values
                    # per shot: signal the caller to retry with a 2D one.
                    if len(acc.shape) == 1:
                        raise AuxException()

                # Insert into the accumulating array
                acc = acc.at[i].set(decoded_values)

                return acc

            acc = sampling_helper_2(acc, i, *measurement_ints)

            return (acc, *args[1:])

        # This list captures the amount of return values. The strategy here is
        # to initially assume that a single value is produced per shot, which
        # is written into a one-dimensional accumulator. If more than one is
        # returned, the amount is saved in this list and an exception is
        # raised, which subsequently causes another call but this time with
        # the correct accumulator dimension.
        return_amount = []

        try:
            loop_res = jax.lax.fori_loop(0,
                                         tracerized_shots,
                                         sampling_body_func,
                                         (jnp.zeros(shots), *args))
            return loop_res[0]
        except AuxException:
            loop_res = jax.lax.fori_loop(0,
                                         tracerized_shots,
                                         sampling_body_func,
                                         (jnp.zeros((shots, return_amount[0])), *args))
            return loop_res[0]

    def return_function(*args):

        if check_for_tracing_mode():
            return sampling_eval_function(shots, *args)
        else:
            # NOTE(review): post_processor is not forwarded in this branch -
            # confirm that this is intended for the non-tracing case.
            return terminal_sampling(state_prep, shots)(*args)

    return return_function


class AuxException(Exception):
    """
    Internal control-flow signal used by ``sample``.

    It is raised while tracing the sampling loop when the post processed
    result is a tuple but the accumulator array is one-dimensional, and
    caught by ``sample`` to retry with a two-dimensional accumulator.
    """