# Source code for qrisp.jasp.jasp_expression.centerclass

"""
\********************************************************************************
* Copyright (c) 2025 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************/
"""
from functools import lru_cache

import jax
from jax import make_jaxpr
from jax.core import Jaxpr, Literal
from jax.tree_util import tree_flatten, tree_unflatten


from qrisp.jasp.jasp_expression import invert_jaspr, collect_environments
from qrisp.jasp import eval_jaxpr, pjit_to_gate, flatten_environments, cond_to_cl_control
from qrisp.jasp.primitives import AbstractQuantumCircuit


class Jaspr(Jaxpr):
    """
    The ``Jaspr`` class enables an efficient representation of a wide variety
    of (hybrid) algorithms. For many applications, the representation is
    agnostic to the scale of the problem, implying function calls with 10 or
    10000 qubits can be represented by the same object. The actual unfolding to
    a circuit-level description is outsourced to
    `established, classical compilation infrastructure <https://mlir.llvm.org/>`_,
    implying state-of-the-art compilation speed can be reached.

    As a subtype of ``jax.core.Jaxpr``, Jasprs are embedded into the well matured
    `Jax ecosystem <https://github.com/n2cholas/awesome-jax>`_, which facilitates
    the compilation of classical
    `real-time computation <https://arxiv.org/abs/2206.12950>`_ using some of the
    most advanced libraries in the world such as
    `CUDA <https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html>`_.
    Especially `machine learning <https://ai.google.dev/gemma/docs/jax_inference>`_
    and other scientific computation tasks are particularly well supported.

    To get a better understanding of the syntax and semantics of Jaxpr (and with
    that also Jaspr) please check
    `this link <https://jax.readthedocs.io/en/latest/jaxpr.html>`__.

    Similar to Jaxpr, Jaspr objects represent (hybrid) quantum algorithms in the
    form of a
    `functional programming language <https://en.wikipedia.org/wiki/Functional_programming>`_
    in `SSA-form <https://en.wikipedia.org/wiki/Static_single-assignment_form>`_.

    It is possible to compile Jaspr objects into QIR, which is facilitated by the
    `Catalyst framework <https://docs.pennylane.ai/projects/catalyst/en/stable/index.html>`__
    (check :meth:`qrisp.jasp.jaspr.to_qir` for more details).

    Qrisp scripts can be turned into Jaspr objects by calling the ``make_jaspr``
    function, which has similar semantics as
    `jax.make_jaxpr <https://jax.readthedocs.io/en/latest/_autosummary/jax.make_jaxpr.html>`_.

    ::

        from qrisp import *
        from qrisp.jasp import make_jaspr

        def test_fun(i):
            qv = QuantumFloat(i, -1)
            x(qv[0])
            cx(qv[0], qv[i-1])
            meas_res = measure(qv)
            meas_res += 1
            return meas_res

        jaspr = make_jaspr(test_fun)(4)
        print(jaspr)

    This will give you the following output:

    .. code-block::

        { lambda ; a:QuantumCircuit b:i32[]. let
            c:QuantumCircuit d:QubitArray = create_qubits a b
            e:Qubit = get_qubit d 0
            f:QuantumCircuit = x c e
            g:i32[] = sub b 1
            h:Qubit = get_qubit d g
            i:QuantumCircuit = cx f e h
            j:QuantumCircuit k:i32[] = measure i d
            l:f32[] = convert_element_type[new_dtype=float64 weak_type=True] k
            m:f32[] = mul l 0.5
            n:f32[] = add m 1.0
          in (j, n) }

    A defining feature of the Jaspr class is that the first input and the first
    output are always of QuantumCircuit type. Therefore, Jaspr objects always
    represent some (hybrid) quantum operation.

    Qrisp comes with a built-in Jaspr interpreter. For that you simply have to
    call the object like a function:

    >>> print(jaspr(2))
    2.5
    >>> print(jaspr(4))
    5.5
    """

    # permeability   : dict mapping constvars/invars/outvars to a permeability value (or None)
    # isqfree        : flag passed in by the creator (None if unknown)
    # hashvalue      : hash of the wrapped jaxpr, or id(self) when built from fields
    # ctrl_jaspr     : optional custom controlled version used by ``control``
    # envs_flattened : whether the environments have been flattened
    # consts         : constant values matching ``constvars`` (set by ``make_jaspr``)
    __slots__ = "permeability", "isqfree", "hashvalue", "ctrl_jaspr", "envs_flattened", "consts"

    def __init__(self, *args, permeability = None, isqfree = None, ctrl_jaspr = None, **kwargs):
        """
        Create a Jaspr either from an existing Jaxpr or from raw Jaxpr fields.

        Parameters
        ----------
        *args : tuple
            Optionally a single Jaxpr to wrap (equivalent to ``jaxpr=...``).
        permeability : dict, optional
            Permeability information per variable. Missing entries become None.
        isqfree : optional
            Stored as-is in ``self.isqfree``.
        ctrl_jaspr : Jaspr, optional
            A custom controlled version of this Jaspr.
        **kwargs :
            Either ``jaxpr=...`` or the keyword arguments of ``jax.core.Jaxpr``.

        Raises
        ------
        Exception
            If the first input or the first output is not of QuantumCircuit type.
        """
        # A single positional argument is interpreted as the Jaxpr to wrap.
        if len(args) == 1:
            kwargs["jaxpr"] = args[0]

        if "jaxpr" in kwargs:
            jaxpr = kwargs["jaxpr"]
            self.hashvalue = hash(jaxpr)
            Jaxpr.__init__(self,
                           constvars = jaxpr.constvars,
                           invars = jaxpr.invars,
                           outvars = jaxpr.outvars,
                           eqns = jaxpr.eqns,
                           effects = jaxpr.effects,
                           debug_info = jaxpr.debug_info)
        else:
            self.hashvalue = id(self)
            Jaxpr.__init__(self, **kwargs)

        self.permeability = {}
        if permeability is None:
            permeability = {}
        for var in self.constvars + self.invars + self.outvars:
            # Literals are not variables and carry no permeability.
            if isinstance(var, Literal):
                continue
            self.permeability[var] = permeability.get(var, None)

        self.isqfree = isqfree
        self.ctrl_jaspr = ctrl_jaspr
        self.envs_flattened = False
        self.consts = []

        if not isinstance(self.invars[0].aval, AbstractQuantumCircuit):
            raise Exception(f"Tried to create a Jaspr from data that doesn't have a QuantumCircuit as first argument (got {type(self.invars[0].aval)} instead)")

        if not isinstance(self.outvars[0].aval, AbstractQuantumCircuit):
            raise Exception(f"Tried to create a Jaspr from data that doesn't have a QuantumCircuit as first entry of return type (got {type(self.outvars[0].aval)} instead)")

    def __hash__(self):
        """Return the precomputed hash value (see ``__init__``)."""
        return self.hashvalue

    def __eq__(self, other):
        """Jasprs compare by identity; non-Jaxpr objects are never equal."""
        if not isinstance(other, Jaxpr):
            return False
        return id(self) == id(other)

    def copy(self):
        """
        Return a shallow copy of this Jaspr.

        The variable and equation lists are copied. The controlled version
        (if any) is copied recursively. The flattening flag and the constants
        are carried over.
        """
        if self.ctrl_jaspr is None:
            ctrl_jaspr = None
        else:
            ctrl_jaspr = self.ctrl_jaspr.copy()

        res = Jaspr(permeability = self.permeability,
                    isqfree = self.isqfree,
                    ctrl_jaspr = ctrl_jaspr,
                    constvars = list(self.constvars),
                    invars = list(self.invars),
                    outvars = list(self.outvars),
                    eqns = list(self.eqns),
                    effects = self.effects,
                    debug_info = self.debug_info)

        res.envs_flattened = self.envs_flattened
        # Without this the copy would lose the values bound to its constvars.
        res.consts = list(self.consts)
        return res

    def inverse(self):
        """
        Returns the inverse Jaspr (if applicable). For Jaspr that contain realtime
        computations or measurements, the inverse does not exist.

        Returns
        -------
        Jaspr
            The daggered Jaspr.

        Examples
        --------

        We create a simple script and inspect the daggered version:

        ::

            from qrisp import *
            from qrisp.jasp import make_jaspr

            def example_function(i):
                qv = QuantumVariable(i)
                cx(qv[0], qv[1])
                t(qv[1])
                return qv

            jaspr = make_jaspr(example_function)(2)
            print(jaspr.inverse())
            # Yields
            # { lambda ; a:QuantumCircuit b:i32[]. let
            #     c:QuantumCircuit d:QubitArray = create_qubits a b
            #     e:Qubit = get_qubit d 0
            #     f:Qubit = get_qubit d 1
            #     g:QuantumCircuit = t_dg c f
            #     h:QuantumCircuit = cx g e f
            #   in (h, d) }
        """
        return invert_jaspr(self)

    def control(self, num_ctrl, ctrl_state = -1):
        """
        Returns the controlled version of the Jaspr. The control qubits are added
        to the signature of the Jaspr as the arguments after the QuantumCircuit.

        Parameters
        ----------
        num_ctrl : int
            The amount of controls to be added.
        ctrl_state : int or str, optional
            The control state on which to activate. Negative integers are taken
            modulo ``2**num_ctrl``. The default is -1 (all controls in state 1).

        Returns
        -------
        Jaspr
            The controlled Jaspr.

        Examples
        --------

        We create a simple script and inspect the controlled version:

        ::

            from qrisp import *
            from qrisp.jasp import make_jaspr

            def example_function(i):
                qv = QuantumVariable(i)
                cx(qv[0], qv[1])
                t(qv[1])
                return qv

            jaspr = make_jaspr(example_function)(2)
            print(jaspr.control(2))
            # Yields
            # { lambda ; a:QuantumCircuit b:Qubit c:Qubit d:i32[]. let
            #     e:QuantumCircuit f:QubitArray = create_qubits a 1
            #     g:Qubit = get_qubit f 0
            #     h:QuantumCircuit = ccx e b c g
            #     i:QuantumCircuit j:QubitArray = create_qubits h d
            #     k:Qubit = get_qubit j 0
            #     l:Qubit = get_qubit j 1
            #     m:QuantumCircuit = ccx i g k l
            #     n:QuantumCircuit = ct m g l
            #     o:QuantumCircuit = ccx n b c g
            #   in (o, j) }

        We see that the control qubits are part of the function signature
        (``b`` and ``c``).
        """
        # Normalise the control state into a bitstring of length num_ctrl.
        if isinstance(ctrl_state, int):
            if ctrl_state < 0:
                ctrl_state += 2**num_ctrl
            ctrl_state = bin(ctrl_state)[2:].zfill(num_ctrl)
        else:
            ctrl_state = str(ctrl_state)

        # A custom controlled version covers exactly one control activating on 1.
        # Checking after normalisation also lets 1 and "1" (not only -1) use it.
        if self.ctrl_jaspr is not None and num_ctrl == 1 and ctrl_state == "1":
            return self.ctrl_jaspr

        from qrisp.jasp import ControlledJaspr
        return ControlledJaspr.from_cache(self, ctrl_state)

    def to_qc(self, *args):
        """
        Converts the Jaspr into a :ref:`QuantumCircuit` if applicable. Circuit
        conversion of algorithms involving realtime computations is not possible.

        Parameters
        ----------
        *args : tuple
            The arguments to call the Jaspr with.

        Returns
        -------
        :ref:`QuantumCircuit`
            The resulting QuantumCircuit.
        return_values : tuple
            The return values of the Jaspr. QuantumVariable return types are
            returned as lists of Qubits.

        Examples
        --------

        We create a simple script and inspect the QuantumCircuit:

        ::

            from qrisp import *
            from qrisp.jasp import make_jaspr

            def example_function(i):
                qv = QuantumVariable(i)
                cx(qv[0], qv[1])
                t(qv[1])
                return qv

            jaspr = make_jaspr(example_function)(2)

            qc, qb_list = jaspr.to_qc(2)
            print(qc)
            # Yields
            # qb_0: ──■───────
            #       ┌─┴─┐┌───┐
            # qb_1: ┤ X ├┤ T ├
            #       └───┘└───┘
        """
        from qrisp import QuantumCircuit, Clbit

        def eqn_evaluator(eqn, context_dic):
            # Nested Jasprs become gates; conditionals become classical control.
            if eqn.primitive.name == "pjit" and isinstance(eqn.params["jaxpr"].jaxpr, Jaspr):
                return pjit_to_gate(eqn, context_dic, eqn_evaluator)
            elif eqn.primitive.name == "cond":
                return cond_to_cl_control(eqn, context_dic, eqn_evaluator)
            elif eqn.primitive.name == "convert_element_type":
                # Clbits are forwarded unchanged instead of being type-converted.
                if isinstance(context_dic[eqn.invars[0]], Clbit):
                    context_dic[eqn.outvars[0]] = context_dic[eqn.invars[0]]
                    return
            return True

        return eval_jaxpr(self, eqn_evaluator = eqn_evaluator)(*([QuantumCircuit()] + list(args)))

    def eval(self, *args, eqn_evaluator = lambda x, y : True):
        """
        Evaluate the Jaspr with ``eval_jaxpr``, forwarding ``eqn_evaluator``.

        The default evaluator returns True for every equation. In ``to_qc``,
        returning True marks equations that are not handled specially.
        """
        return eval_jaxpr(self, eqn_evaluator = eqn_evaluator)(*args)

    def flatten_environments(self):
        """
        Flattens all environments by applying the corresponding compilation
        routines such that no more ``q_env`` primitives are left.

        Returns
        -------
        Jaspr
            The Jaspr with flattened environments.

        Examples
        --------

        We create a Jaspr containing an :ref:`InversionEnvironment` and flatten:

        ::

            def test_function(i):
                qv = QuantumVariable(i)

                with invert():
                    t(qv[0])
                    cx(qv[0], qv[1])

                return qv

            jaspr = make_jaspr(test_function)(2)
            print(jaspr)

        ::

            { lambda ; a:QuantumCircuit b:i32[]. let
                c:QuantumCircuit d:QubitArray = create_qubits a b
                e:QuantumCircuit = q_env[
                  jaspr={ lambda ; f:QuantumCircuit d:QubitArray. let
                      g:Qubit = get_qubit d 0
                      h:QuantumCircuit = t f g
                      i:Qubit = get_qubit d 1
                      j:QuantumCircuit = cx h g i
                    in (j,) }
                  type=InversionEnvironment
                ] c d
              in (e, d) }

        You can see how the body of the :ref:`InversionEnvironment` is *collected*
        into another Jaspr. This reflects the fact that at their core,
        :ref:`QuantumEnvironment <QuantumEnvironment>` describe
        `higher-order quantum functions <https://en.wikipedia.org/wiki/Higher-order_function>`_
        (ie. functions that operate on functions). In order to apply the
        transformations induced by the QuantumEnvironment, we can call
        ``jaspr.flatten_environments()``:

        >>> print(jaspr.flatten_environments())
        { lambda ; a:QuantumCircuit b:i32[]. let
            c:QuantumCircuit d:QubitArray = create_qubits a b
            e:Qubit = get_qubit d 0
            f:Qubit = get_qubit d 1
            g:QuantumCircuit = cx c e f
            h:QuantumCircuit = t_dg g e
          in (h, d) }

        We see that as expected, the order of the ``cx`` and the ``t`` gate has
        been switched and the ``t`` gate has been turned into a ``t_dg``.
        """
        # This calls the module-level ``flatten_environments`` function, not this method.
        res = flatten_environments(self)
        if self.ctrl_jaspr is not None:
            res.ctrl_jaspr = self.ctrl_jaspr.flatten_environments()
        return res

    def __call__(self, *args):
        """Simulate the Jaspr on ``args`` via ``simulate_jaspr``."""
        from qrisp.jasp.evaluation_tools.jaspification import simulate_jaspr
        return simulate_jaspr(self, *args)

    @staticmethod
    def _split_qc_result(res):
        """Split an evaluation result into ``(new_abs_qc, return_values)``."""
        # A tuple result is (circuit, *return values); otherwise only the
        # circuit was returned.
        if isinstance(res, tuple):
            return res[0], res[1:]
        return res, None

    def inline(self, *args):
        """
        Evaluate the Jaspr directly into the currently tracing quantum session.

        The session's abstract circuit is replaced by the resulting circuit. The
        remaining return values are returned (None if there are none).
        """
        from qrisp.jasp import TracingQuantumSession
        qs = TracingQuantumSession.get_instance()
        abs_qc = qs.abs_qc

        res = eval_jaxpr(self)(*([abs_qc] + list(args)))

        new_abs_qc, res = self._split_qc_result(res)
        qs.abs_qc = new_abs_qc
        return res

    def embedd(self, *args, name = None, inline = False):
        """
        Embed the Jaspr into the currently tracing quantum session.

        Parameters
        ----------
        *args : tuple
            The (non-circuit) arguments to call the Jaspr with.
        name : str, optional
            If given (and ``inline`` is False), it is used as the name of the
            emitted pjit equation.
        inline : bool, optional
            If True, the Jaspr is evaluated inline instead of as a jitted call.

        Returns
        -------
        The return values of the Jaspr without the circuit (None if there are none).
        """
        from qrisp.jasp import TracingQuantumSession
        qs = TracingQuantumSession.get_instance()
        abs_qc = qs.abs_qc

        if inline:
            res = eval_jaxpr(self)(*([abs_qc] + list(args)))
        else:
            res = jax.jit(eval_jaxpr(self))(*([abs_qc] + list(args)))
            # NOTE(review): relies on private JAX internals. This assumes the
            # equation emitted by the jax.jit call above is the last equation
            # of jaxpr_stack[0]. Confirm for the pinned JAX version.
            eqn = jax._src.core.thread_local_state.trace_state.trace_stack.dynamic.jaxpr_stack[0].eqns[-1]
            # Swap the traced body for this Jaspr so the Jaspr type is kept.
            eqn.params["jaxpr"] = jax.core.ClosedJaxpr(self, eqn.params["jaxpr"].consts)
            if name is not None:
                eqn.params["name"] = name

        new_abs_qc, res = self._split_qc_result(res)
        qs.abs_qc = new_abs_qc
        return res

    def qjit(self, *args, function_name = "jaspr_function"):
        """
        Leverages the Catalyst pipeline to compile a QIR representation of this
        function and executes that function using the Catalyst QIR runtime.

        Parameters
        ----------
        *args : iterable
            The arguments to call the function with.
        function_name : str, optional
            The name given to the compiled function. The default is
            ``"jaspr_function"``.

        Returns
        -------
        The values returned by the compiled, executed function. A single-element
        tuple/list is unpacked into its element.
        """
        from qrisp.jasp.evaluation_tools.catalyst_interface import jaspr_to_catalyst_qjit
        qjit_obj = jaspr_to_catalyst_qjit(self, function_name = function_name)
        res = qjit_obj.compiled_function(*args)

        if isinstance(res, (tuple, list)) and len(res) == 1:
            return res[0]
        return res

    @classmethod
    @lru_cache(maxsize = int(1E5))
    def from_cache(cls, jaxpr):
        """
        Cached constructor: wrap ``jaxpr`` in a Jaspr, reusing previous results.

        The cache is keyed on ``(cls, jaxpr)`` and holds up to 1e5 entries.
        Cached Jasprs are shared, so mutating one affects all later cache hits.
        """
        return Jaspr(jaxpr = jaxpr)

    def update_eqns(self, eqns):
        """
        Return a new Jaspr with the same variables but the given equations.

        NOTE(review): only constvars/invars/outvars/eqns are carried over.
        Permeability, isqfree, ctrl_jaspr, consts, effects and debug_info are
        not copied. Confirm this is intended by callers.
        """
        return Jaspr(constvars = list(self.constvars),
                     invars = list(self.invars),
                     outvars = list(self.outvars),
                     eqns = list(eqns))

    def to_qir(self):
        """
        Compiles the Jaspr to QIR using the
        `Catalyst framework <https://docs.pennylane.ai/projects/catalyst/en/stable/index.html>`__.
        Environments are flattened before compilation.

        Parameters
        ----------
        None

        Returns
        -------
        str
            The QIR string.

        Examples
        --------

        We create a simple script and inspect the QIR string:

        ::

            from qrisp import *
            from qrisp.jasp import make_jaspr

            def example_function(i):
                qv = QuantumFloat(i)
                cx(qv[0], qv[1])
                t(qv[1])
                meas_res = measure(qv)
                meas_res += 1
                return meas_res

            jaspr = make_jaspr(example_function)(2)
            print(jaspr.to_qir())

        Yields (abridged; the full module additionally contains memory-management
        code, the measurement loop and machine-specific device paths):

        ::

            ; ModuleID = 'LLVMDialectModule'
            source_filename = "LLVMDialectModule"
            ...
            declare void @__catalyst__rt__finalize() local_unnamed_addr
            declare void @__catalyst__rt__initialize() local_unnamed_addr
            declare ptr @__catalyst__qis__Measure(ptr, i32) local_unnamed_addr
            declare void @__catalyst__qis__T(ptr, ptr) local_unnamed_addr
            declare void @__catalyst__qis__CNOT(ptr, ptr, ptr) local_unnamed_addr
            declare ptr @__catalyst__rt__array_get_element_ptr_1d(ptr, i64) local_unnamed_addr
            declare ptr @__catalyst__rt__qubit_allocate_array(i64) local_unnamed_addr
            declare void @__catalyst__rt__device_init(ptr, ptr, ptr) local_unnamed_addr
            ...
            define { ptr, ptr, i64 } @jit_jaspr_function(ptr nocapture readnone %0, ptr nocapture readonly %1, i64 %2) local_unnamed_addr {
              tail call void @__catalyst__rt__device_init(...)
              %4 = tail call ptr @__catalyst__rt__qubit_allocate_array(i64 20)
              %5 = tail call ptr @__catalyst__rt__array_get_element_ptr_1d(ptr %4, i64 0)
              %6 = load ptr, ptr %5, align 8
              %7 = tail call ptr @__catalyst__rt__array_get_element_ptr_1d(ptr %4, i64 1)
              %8 = load ptr, ptr %7, align 8
              tail call void @__catalyst__qis__CNOT(ptr %6, ptr %8, ptr null)
              %9 = tail call ptr @__catalyst__rt__array_get_element_ptr_1d(ptr %4, i64 1)
              %10 = load ptr, ptr %9, align 8
              tail call void @__catalyst__qis__T(ptr %10, ptr null)
              ...
              %54 = tail call ptr @__catalyst__qis__Measure(ptr %53, i32 -1)
              ...
            }
            ...
            define void @setup() local_unnamed_addr {
              tail call void @__catalyst__rt__initialize()
              ret void
            }

            define void @teardown() local_unnamed_addr {
              tail call void @__catalyst__rt__finalize()
              ret void
            }
        """
        from qrisp.jasp.evaluation_tools.catalyst_interface import jaspr_to_qir
        return jaspr_to_qir(self.flatten_environments())

    def to_mlir(self):
        """
        Compiles the Jaspr to MLIR using the
        `Catalyst dialect <https://docs.pennylane.ai/projects/catalyst/en/stable/index.html>`__.
        Environments are flattened before compilation.

        Parameters
        ----------
        None

        Returns
        -------
        str
            The MLIR string.

        Examples
        --------

        We create a simple script and inspect the MLIR string:

        ::

            from qrisp import *
            from qrisp.jasp import make_jaspr

            def example_function(i):
                qv = QuantumFloat(i)
                cx(qv[0], qv[1])
                t(qv[1])
                meas_res = measure(qv)
                meas_res += 1
                return meas_res

            jaspr = make_jaspr(example_function)(2)
            print(jaspr.to_mlir())

        Yields (the device library path is machine-specific):

        ::

            module @jaspr_function {
              func.func public @jit_jaspr_function(%arg0: tensor<i64>) -> tensor<i32> attributes {llvm.emit_c_interface} {
                %0 = stablehlo.constant dense<1> : tensor<i32>
                %1 = stablehlo.constant dense<2> : tensor<i64>
                %2 = stablehlo.constant dense<1> : tensor<i64>
                %3 = stablehlo.constant dense<0> : tensor<i64>
                quantum.device["<...>/catalyst/utils/../lib/librtd_lightning.so", "lightning.qubit", "{'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
                %4 = quantum.alloc( 20) : !quantum.reg
                %5 = quantum.extract %4[ 0] : !quantum.reg -> !quantum.bit
                %6 = quantum.extract %4[ 1] : !quantum.reg -> !quantum.bit
                %out_qubits:2 = quantum.custom "CNOT"() %5, %6 : !quantum.bit, !quantum.bit
                %7 = quantum.insert %4[ 0], %out_qubits#0 : !quantum.reg, !quantum.bit
                %8 = quantum.insert %7[ 1], %out_qubits#1 : !quantum.reg, !quantum.bit
                %9 = quantum.extract %8[ 1] : !quantum.reg -> !quantum.bit
                %out_qubits_0 = quantum.custom "T"() %9 : !quantum.bit
                %10 = quantum.insert %8[ 1], %out_qubits_0 : !quantum.reg, !quantum.bit
                %11 = stablehlo.add %3, %arg0 : tensor<i64>
                %12:3 = scf.while (%arg1 = %3, %arg2 = %3, %arg3 = %10) : (tensor<i64>, tensor<i64>, !quantum.reg) -> (tensor<i64>, tensor<i64>, !quantum.reg) {
                  %16 = stablehlo.compare GE, %arg1, %11, SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1>
                  %extracted = tensor.extract %16[] : tensor<i1>
                  scf.condition(%extracted) %arg1, %arg2, %arg3 : tensor<i64>, tensor<i64>, !quantum.reg
                } do {
                ^bb0(%arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: !quantum.reg):
                  %extracted = tensor.extract %arg1[] : tensor<i64>
                  %16 = quantum.extract %arg3[%extracted] : !quantum.reg -> !quantum.bit
                  %mres, %out_qubit = quantum.measure %16 : i1, !quantum.bit
                  %from_elements = tensor.from_elements %mres : tensor<i1>
                  %extracted_1 = tensor.extract %arg1[] : tensor<i64>
                  %17 = quantum.insert %arg3[%extracted_1], %out_qubit : !quantum.reg, !quantum.bit
                  %18 = stablehlo.subtract %arg1, %3 : tensor<i64>
                  %19 = stablehlo.shift_left %1, %18 : tensor<i64>
                  %20 = stablehlo.convert %from_elements : (tensor<i1>) -> tensor<i64>
                  %21 = stablehlo.multiply %19, %20 : tensor<i64>
                  %22 = stablehlo.add %arg2, %21 : tensor<i64>
                  %23 = stablehlo.add %arg1, %2 : tensor<i64>
                  scf.yield %23, %22, %17 : tensor<i64>, tensor<i64>, !quantum.reg
                }
                %13 = stablehlo.convert %12#1 : (tensor<i64>) -> tensor<i32>
                %14 = stablehlo.multiply %13, %0 : tensor<i32>
                %15 = stablehlo.add %14, %0 : tensor<i32>
                return %15 : tensor<i32>
              }
              func.func @setup() {
                quantum.init
                return
              }
              func.func @teardown() {
                quantum.finalize
                return
              }
            }
        """
        from qrisp.jasp.evaluation_tools.catalyst_interface import jaspr_to_mlir
        return jaspr_to_mlir(self.flatten_environments())

    def to_qasm(self, *args):
        """
        Compiles the Jaspr into an OpenQASM 2 string. Real-time control is
        possible as long as no computations on the measurement results are
        performed.

        Parameters
        ----------
        *args : list
            The arguments to call the :ref:`QuantumCircuit` evaluation with.

        Returns
        -------
        str
            The OpenQASM 2 string.

        Examples
        --------

        We create a simple script and inspect the QASM 2 string:

        ::

            from qrisp import *
            from qrisp.jasp import make_jaspr

            def main(i):
                qv = QuantumVariable(i)
                cx(qv[0], qv[1])
                t(qv[1])
                return qv

            jaspr = make_jaspr(main)(2)

            qasm_str = jaspr.to_qasm(2)
            print(qasm_str)
            # Yields
            # OPENQASM 2.0;
            # include "qelib1.inc";
            # qreg qb_59[1];
            # qreg qb_60[1];
            # cx qb_59[0],qb_60[0];
            # t qb_60[0];

        It is also possible to compile simple real-time control features:

        ::

            def main(phi):
                qf = QuantumFloat(5)
                h(qf)
                bl = measure(qf[0])
                with control(bl):
                    rz(phi, qf[1])
                    x(qf[1])
                return

            jaspr = make_jaspr(main)(0.5)
            print(jaspr.to_qasm(0.5))

        This gives:

        ::

            OPENQASM 2.0;
            include "qelib1.inc";
            qreg qb_59[1];
            qreg qb_60[1];
            qreg qb_61[1];
            qreg qb_62[1];
            qreg qb_63[1];
            creg cb_0[1];
            h qb_59[0];
            h qb_60[0];
            h qb_61[0];
            reset qb_61[0];
            h qb_62[0];
            reset qb_62[0];
            h qb_63[0];
            reset qb_63[0];
            measure qb_59[0] -> cb_0[0];
            reset qb_59[0];
            if(cb_0==1) rz(0.5) qb_60[0];
            if(cb_0==1) x qb_60[0];
            reset qb_60[0];
        """
        res = self.to_qc(*args)
        # With a single output (the circuit itself) to_qc returns it unwrapped.
        if len(self.outvars) == 1:
            res = [res]
        qrisp_qc = res[0]
        return qrisp_qc.qasm()

    def to_catalyst_jaxpr(self):
        """
        Compiles the jaspr to the corresponding
        `Catalyst jaxpr <https://docs.pennylane.ai/projects/catalyst/en/stable/index.html>`__.
        Environments are flattened before compilation.

        Parameters
        ----------
        None

        Returns
        -------
        jax.core.Jaxpr
            The Jaxpr using Catalyst primitives.

        Examples
        --------

        We create a simple script and inspect the Catalyst Jaxpr:

        ::

            from qrisp import *
            from qrisp.jasp import make_jaspr

            def example_function(i):
                qv = QuantumFloat(i)
                cx(qv[0], qv[1])
                t(qv[1])
                meas_res = measure(qv)
                meas_res += 1
                return meas_res

            jaspr = make_jaspr(example_function)(2)
            print(jaspr.to_catalyst_jaxpr())
            # Yields
            # { lambda ; a:AbstractQreg() b:i64[] c:i32[]. let
            #     d:i64[] = convert_element_type[new_dtype=int64 weak_type=True] c
            #     e:i64[] = add b d
            #     f:i64[] = add b 0
            #     g:i64[] = add b 1
            #     h:AbstractQbit() = qextract a f
            #     i:AbstractQbit() = qextract a g
            #     j:AbstractQbit() k:AbstractQbit() = qinst[op=CNOT qubits_len=2] h i
            #     l:AbstractQreg() = qinsert a f j
            #     m:AbstractQreg() = qinsert l g k
            #     n:AbstractQbit() = qextract m g
            #     o:AbstractQbit() = qinst[op=T qubits_len=1] n
            #     p:AbstractQreg() = qinsert m g o
            #     q:i64[] = convert_element_type[new_dtype=int64 weak_type=True] c
            #     r:i64[] = add b q
            #     _:i64[] s:i64[] t:AbstractQreg() _:i64[] _:i64[] = while_loop[
            #       body_jaxpr={ lambda ; u:i64[] v:i64[] w:AbstractQreg() x:i64[] y:i64[]. let
            #           z:AbstractQbit() = qextract w u
            #           ba:bool[] bb:AbstractQbit() = qmeasure z
            #           bc:AbstractQreg() = qinsert w u bb
            #           bd:i64[] = sub u x
            #           be:i64[] = shift_left 2 bd
            #           bf:i64[] = convert_element_type[new_dtype=int64 weak_type=True] ba
            #           bg:i64[] = mul be bf
            #           bh:i64[] = add v bg
            #           bi:i64[] = add u 1
            #         in (bi, bh, bc, x, y) }
            #       body_nconsts=0
            #       cond_jaxpr={ lambda ; bj:i64[] bk:i64[] bl:AbstractQreg() bm:i64[] bn:i64[]. let
            #           bo:bool[] = ge bj bn
            #         in (bo,) }
            #       cond_nconsts=0
            #       nimplicit=0
            #       preserve_dimensions=True
            #     ] b 0 p b r
            #     bp:i32[] = convert_element_type[new_dtype=int64 weak_type=False] s
            #     bq:i32[] = mul bp 1
            #     br:i32[] = add bq 1
            #   in (t, e, br) }
        """
        # Use the same module as to_qir/to_mlir/qjit. Fall back to the legacy
        # location for installations that still provide it.
        try:
            from qrisp.jasp.evaluation_tools.catalyst_interface import jaspr_to_catalyst_jaxpr
        except ImportError:
            from qrisp.jasp.catalyst_interface import jaspr_to_catalyst_jaxpr
        return jaspr_to_catalyst_jaxpr(self.flatten_environments())
def _shift_static_argnums(static_argnums):
    """Shift ``static_argnums`` past the prepended AbstractQuantumCircuit argument."""
    # Non-negative indices move one slot to the right because the traced
    # function receives the AbstractQuantumCircuit as its first positional
    # argument. Negative indices count from the end and stay valid unchanged.
    def shift(index):
        return index + 1 if index >= 0 else index

    if isinstance(static_argnums, int):
        return shift(static_argnums)
    # JAX accepts any iterable of ints (list, tuple, ...).
    return tuple(shift(index) for index in static_argnums)


def make_jaspr(fun, garbage_collection = "auto", flatten_envs = True, **jax_kwargs):
    """
    Create a function that traces ``fun`` into a :class:`Jaspr`.

    Parameters
    ----------
    fun : callable
        The Qrisp function to trace.
    garbage_collection : str, optional
        Forwarded to ``TracingQuantumSession.start_tracing``. The default is "auto".
    flatten_envs : bool, optional
        If True, the environments of the resulting Jaspr are flattened.
        The default is True.
    **jax_kwargs :
        Forwarded to ``jax.make_jaxpr``. ``static_argnums`` (int or iterable of
        ints) is adjusted to account for the prepended QuantumCircuit argument.

    Returns
    -------
    callable
        A function which, when called with the arguments for ``fun``, returns
        the corresponding Jaspr.
    """
    from qrisp.jasp import AbstractQuantumCircuit, TracingQuantumSession, check_for_tracing_mode
    # NOTE(review): the names below are unused here. The import may be kept for
    # its side effects (e.g. pytree registration). Confirm before removing.
    from qrisp.core.quantum_variable import QuantumVariable, flatten_qv, unflatten_qv
    from qrisp.core import recursive_qv_search

    # We trace the "ammended function", whose first parameter is the
    # AbstractQuantumCircuit, so the static_argnums indicator has to move.
    if "static_argnums" in jax_kwargs:
        jax_kwargs = dict(jax_kwargs)
        jax_kwargs["static_argnums"] = _shift_static_argnums(jax_kwargs["static_argnums"])

    def jaspr_creator(*args, **kwargs):

        qs = TracingQuantumSession.get_instance()

        # Close any tracing quantum sessions that might have not been
        # properly closed due to whatever reason.
        if not check_for_tracing_mode():
            while qs.abs_qc is not None:
                qs.conclude_tracing()

        # This function will be traced by Jax.
        # Note that we add the abs_qc keyword as the tracing quantum circuit.
        def ammended_function(abs_qc, *args, **kwargs):

            qs.start_tracing(abs_qc, garbage_collection)

            # If the signature contains QuantumVariables, these QuantumVariables
            # went through a flattening/unflattening procedure. The unflattening
            # creates a copy of the QuantumVariable object, which is however not
            # yet registered in any QuantumSession. We register these
            # QuantumVariables in the current QuantumSession.
            arg_qvs = recursive_qv_search(args)
            for qv in arg_qvs:
                qs.register_qv(qv, None)

            try:
                res = fun(*args, **kwargs)
            except Exception:
                # Leave the session in a clean state before propagating.
                qs.conclude_tracing()
                raise

            res_qvs = recursive_qv_search(res)

            qs.garbage_collection(spare_qv_list = arg_qvs + res_qvs)

            res_qc = qs.conclude_tracing()

            return res_qc, res

        closed_jaxpr = make_jaxpr(ammended_function, **jax_kwargs)(AbstractQuantumCircuit(), *args, **kwargs)

        jaxpr = closed_jaxpr.jaxpr

        # Collect the environments.
        # This means that the quantum environments no longer appear as
        # enter/exit primitives but as primitives that "call" a certain Jaspr.
        res = Jaspr.from_cache(collect_environments(jaxpr))

        if flatten_envs:
            res = res.flatten_environments()

        res.consts = closed_jaxpr.consts

        return res

    return jaspr_creator


def check_aval_equivalence(invars_1, invars_2):
    """
    Return True if both variable lists have the same length and pairwise
    identical ``aval`` types.
    """
    # Differently sized signatures can never be equivalent. The original
    # index-based loop raised IndexError or ignored trailing entries instead.
    if len(invars_1) != len(invars_2):
        return False
    return all(type(var_1.aval) is type(var_2.aval) for var_1, var_2 in zip(invars_1, invars_2))