Source code for qrisp.jasp.tracing_logic.quantum_kernel

"""
\********************************************************************************
* Copyright (c) 2025 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************/
"""

import jax

from qrisp.jasp.primitives import quantum_kernel_p
from qrisp.jasp.tracing_logic import TracingQuantumSession, qache

def quantum_kernel(func):
    """
    This decorator allows you to annotate a subroutine as a "quantum kernel".
    Quantum kernels are functions that are restricted in the sense that they
    cannot have quantum inputs or outputs, yet their inner workings can be
    quantum.

    What is the benefit of that? Future execution environments might host
    several QPUs that can operate in parallel, much like many of today's HPC
    environments can access multiple GPUs. Annotating a function as a quantum
    kernel therefore allows the execution environment to identify the
    subroutine as a separate quantum state and assign it to a dedicated QPU.

    .. note::

        While not many quantum algorithms exist that directly allow such a
        parallelization, any sampling task can be performed in this manner:
        If you need to execute 1000 shots of a certain quantum circuit, but
        you have 4 QPUs available, you can execute the task 4 times faster by
        assigning 250 shots to each QPU. As such, the :ref:`sample <sample>`
        and :ref:`expectation_value <expectation_value>` functions
        automatically wrap the state preparation and measurement into a
        dedicated quantum kernel.

    Parameters
    ----------
    func : callable
        A function that receives only classical values as inputs and returns
        classical values as outputs. The function's body can, however, perform
        quantum logic.

    Returns
    -------
    quantum_kernel : callable
        A function that performs the task of the input, but which the compiler
        can identify as a closed quantum procedure without any external
        entanglement.

    Examples
    --------

    We demonstrate a naive implementation of an expectation value; please use
    :ref:`expectation_value` if you require this functionality. For this we
    define a state preparation procedure and call it from a kernelized
    function.

    ::

        from qrisp import *
        from qrisp.jasp import *

        def state_prep(k):
            qf = QuantumFloat(5)
            h(qf[k])
            return qf

        @quantum_kernel
        def sampling_kernel(k):
            # Receives a classical (!) integer k
            qf = state_prep(k)

            # Returns a classical integer
            return measure(qf)

    We now call the kernel within a purely classical Jax script.

    ::

        @jaspify
        def main(k):

            shots = 100
            res = 0
            for i in range(shots):
                res += sampling_kernel(k)

            return res/shots

    Perform some experiments:

    ::

        print(main(3))
        # Yields: 3.92
        # Expected: 2**3/2 = 4

        print(main(4))
        # Yields: 8.96
        # Expected: 2**4/2 = 8

    """

    func = qache(func)

    def return_function(*args, **kwargs):
        from qrisp.jasp.jasp_expression.centerclass import Jaspr, collect_environments

        # Open a fresh tracing session for the kernel body.
        qs = TracingQuantumSession.get_instance()
        qs.start_tracing(quantum_kernel_p.bind())

        try:
            res = func(*args, **kwargs)
        except Exception as e:
            # Close the tracing session before propagating the error.
            qs.conclude_tracing()
            raise e

        # Take the most recently recorded equation and replace its "jaxpr"
        # parameter with a version whose environments are collected and flattened.
        eqn = jax._src.core.thread_local_state.trace_state.trace_stack.dynamic.jaxpr_stack[0].eqns[-1]
        flattened_jaspr = Jaspr.from_cache(collect_environments(eqn.params["jaxpr"].jaxpr)).flatten_environments()
        eqn.params["jaxpr"] = jax.core.ClosedJaxpr(flattened_jaspr, eqn.params["jaxpr"].consts)

        qs.conclude_tracing()
        return res

    return return_function
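
# Illustrative sketch only, not part of the Qrisp API: the note in the
# docstring above describes splitting 1000 shots across 4 QPUs. The plain
# Python below shows one way such a split and the recombination of per-QPU
# averages could look. The names split_shots and combine_means are
# hypothetical, and no quantum backend is involved.

```python
def split_shots(total_shots, n_qpus):
    """Distribute total_shots as evenly as possible over n_qpus devices."""
    if n_qpus < 1:
        raise ValueError("n_qpus must be at least 1")
    if total_shots < 0:
        raise ValueError("total_shots must be non-negative")
    base, remainder = divmod(total_shots, n_qpus)
    # The first `remainder` devices each take one extra shot.
    return [base + 1 if i < remainder else base for i in range(n_qpus)]


def combine_means(means, shots):
    """Combine per-device sample means into one shot-weighted mean."""
    total = sum(shots)
    if total == 0:
        raise ValueError("at least one shot is required")
    return sum(m * s for m, s in zip(means, shots)) / total


if __name__ == "__main__":
    print(split_shots(1000, 4))
```

# Weighting each partial mean by its shot count keeps the combined estimate
# identical to the mean over all shots, even when the split is uneven.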