"""
\********************************************************************************
* Copyright (c) 2025 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************/
"""
from functools import lru_cache
import jax
from jax import make_jaxpr
from jax.core import Jaxpr, Literal
from jax.tree_util import tree_flatten, tree_unflatten
from qrisp.jasp.jasp_expression import invert_jaspr, collect_environments
from qrisp.jasp import eval_jaxpr, pjit_to_gate, flatten_environments, cond_to_cl_control
from qrisp.jasp.primitives import AbstractQuantumCircuit
[docs]
class Jaspr(Jaxpr):
"""
The ``Jaspr`` class enables an efficient representation of a wide variety
of (hybrid) algorithms. For many applications, the representation is agnostic
to the scale of the problem, implying function calls with 10 or 10000 qubits
can be represented by the same object. The actual unfolding to a circuit-level
description is outsourced to
`established, classical compilation infrastructure <https://mlir.llvm.org/>`_,
implying state-of-the-art compilation speed can be reached.
As a subtype of ``jax.core.Jaxpr``, Jasprs are embedded into the well matured
`Jax ecosystem <https://github.com/n2cholas/awesome-jax>`_,
which facilitates the compilation of classical `real-time computation <https://arxiv.org/abs/2206.12950>`_
using some of the most advanced libraries in the world such as
`CUDA <https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html>`_.
Especially `machine learning <https://ai.google.dev/gemma/docs/jax_inference>`_
and other scientific computation tasks are particularly well supported.
To get a better understanding of the syntax and semantics of Jaxpr (and with
that also Jaspr) please check `this link <https://jax.readthedocs.io/en/latest/jaxpr.html>`__.
Similar to Jaxpr, Jaspr objects represent (hybrid) quantum
algorithms in the form of a `functional programming language <https://en.wikipedia.org/wiki/Functional_programming>`_
in `SSA-form <https://en.wikipedia.org/wiki/Static_single-assignment_form>`_.
It is possible to compile Jaspr objects into QIR, which is facilitated by the
`Catalyst framework <https://docs.pennylane.ai/projects/catalyst/en/stable/index.html>`__
(check :meth:`qrisp.jasp.jaspr.to_qir` for more details).
Qrisp scripts can be turned into Jaspr objects by
calling the ``make_jaspr`` function, which has similar semantics as
`jax.make_jaxpr <https://jax.readthedocs.io/en/latest/_autosummary/jax.make_jaxpr.html>`_.
::
from qrisp import *
from qrisp.jasp import make_jaspr
def test_fun(i):
qv = QuantumFloat(i, -1)
x(qv[0])
cx(qv[0], qv[i-1])
meas_res = measure(qv)
meas_res += 1
return meas_res
jaspr = make_jaspr(test_fun)(4)
print(jaspr)
This will give you the following output:
.. code-block::
{ lambda ; a:QuantumCircuit b:i32[]. let
c:QuantumCircuit d:QubitArray = create_qubits a b
e:Qubit = get_qubit d 0
f:QuantumCircuit = x c e
g:i32[] = sub b 1
h:Qubit = get_qubit d g
i:QuantumCircuit = cx f e h
j:QuantumCircuit k:i32[] = measure i d
l:f32[] = convert_element_type[new_dtype=float64 weak_type=True] k
m:f32[] = mul l 0.5
n:f32[] = add m 1.0
in (j, n) }
A defining feature of the Jaspr class is that the first input and the
first output are always of QuantumCircuit type. Therefore, Jaspr objects always
represent some (hybrid) quantum operation.
Qrisp comes with a built-in Jaspr interpreter. For that you simply have to
call the object like a function:
>>> print(jaspr(2))
2.5
>>> print(jaspr(4))
5.5
"""
__slots__ = "permeability", "isqfree", "hashvalue", "ctrl_jaspr", "envs_flattened", "consts"
def __init__(self, *args, permeability = None, isqfree = None, ctrl_jaspr = None, **kwargs):
    """
    Builds a Jaspr either from an existing Jaxpr or from raw Jaxpr fields.

    Parameters
    ----------
    *args : tuple
        If exactly one positional argument is given, it is treated as the
        ``jaxpr`` keyword argument.
    permeability : dict, optional
        Maps variables to their permeability. Entries are looked up for every
        non-literal constvar, invar and outvar; missing entries become None.
    isqfree : optional
        Stored unchanged on the ``isqfree`` attribute. The default is None.
    ctrl_jaspr : optional
        Stored unchanged on the ``ctrl_jaspr`` attribute. The default is None.
    **kwargs
        Either ``jaxpr`` (an object whose fields are copied) or the keyword
        arguments of ``Jaxpr.__init__``.

    Raises
    ------
    Exception
        If the first input or the first output is not an AbstractQuantumCircuit.
    """
    # A single positional argument is shorthand for the ``jaxpr`` keyword.
    if len(args) == 1:
        kwargs["jaxpr"] = args[0]

    if "jaxpr" not in kwargs:
        # Built from raw fields: the hash falls back to the object identity.
        self.hashvalue = id(self)
        Jaxpr.__init__(self, **kwargs)
    else:
        # Built from an existing Jaxpr: reuse its hash and copy its fields.
        template = kwargs["jaxpr"]
        self.hashvalue = hash(template)
        Jaxpr.__init__(
            self,
            constvars = template.constvars,
            invars = template.invars,
            outvars = template.outvars,
            eqns = template.eqns,
            effects = template.effects,
            debug_info = template.debug_info,
        )

    # Literals are skipped; every remaining boundary variable gets an entry.
    known_permeability = {} if permeability is None else permeability
    boundary_vars = self.constvars + self.invars + self.outvars
    self.permeability = {
        var: known_permeability.get(var, None)
        for var in boundary_vars
        if not isinstance(var, Literal)
    }

    self.isqfree = isqfree
    self.ctrl_jaspr = ctrl_jaspr
    self.envs_flattened = False
    self.consts = []

    # The first input and the first output must both be a QuantumCircuit.
    if not isinstance(self.invars[0].aval, AbstractQuantumCircuit):
        raise Exception(f"Tried to create a Jaspr from data that doesn't have a QuantumCircuit as first argument (got {type(self.invars[0].aval)} instead)")

    if not isinstance(self.outvars[0].aval, AbstractQuantumCircuit):
        raise Exception(f"Tried to create a Jaspr from data that doesn't have a QuantumCircuit as first entry of return type (got {type(self.outvars[0].aval)} instead)")
def __hash__(self):
    # Returns the value stored by __init__: the hash of the wrapped jaxpr
    # if one was given, otherwise the id of this object.
    return self.hashvalue
def __eq__(self, other):
    # Equality is identity-based: two distinct objects never compare equal,
    # even if they hold the same equations.
    if not isinstance(other, Jaxpr):
        return False
    return id(self) == id(other)
def copy(self):
    """
    Creates a copy of this Jaspr.

    The variable and equation lists are copied shallowly (new lists holding
    the same entries). A present ``ctrl_jaspr`` is copied recursively.

    Returns
    -------
    Jaspr
        The copied Jaspr.
    """
    ctrl_jaspr = None if self.ctrl_jaspr is None else self.ctrl_jaspr.copy()

    res = Jaspr(permeability = self.permeability,
                isqfree = self.isqfree,
                ctrl_jaspr = ctrl_jaspr,
                constvars = list(self.constvars),
                invars = list(self.invars),
                outvars = list(self.outvars),
                eqns = list(self.eqns),
                effects = self.effects,
                debug_info = self.debug_info)

    res.envs_flattened = self.envs_flattened
    # The constvars are carried over above, so the matching constant values
    # have to follow; otherwise the copy would start with an empty list.
    res.consts = list(self.consts)

    return res
[docs]
def inverse(self):
    """
    Returns the inverse Jaspr (if applicable). For Jasprs that contain realtime
    computations or measurements, the inverse does not exist.

    The work is delegated to ``invert_jaspr``.

    Returns
    -------
    Jaspr
        The daggered Jaspr.

    Examples
    --------

    We create a simple script and inspect the daggered version:

    ::

        from qrisp import *
        from qrisp.jasp import make_jaspr

        def example_function(i):

            qv = QuantumVariable(i)
            cx(qv[0], qv[1])
            t(qv[1])
            return qv

        jaspr = make_jaspr(example_function)(2)

        print(jaspr.inverse())
        # Yields
        # { lambda ; a:QuantumCircuit b:i32[]. let
        #     c:QuantumCircuit d:QubitArray = create_qubits a b
        #     e:Qubit = get_qubit d 0
        #     f:Qubit = get_qubit d 1
        #     g:QuantumCircuit = t_dg c f
        #     h:QuantumCircuit = cx g e f
        #   in (h, d) }
    """
    return invert_jaspr(self)
[docs]
def control(self, num_ctrl, ctrl_state = -1):
    """
    Returns the controlled version of the Jaspr. The control qubits are added
    to the signature of the Jaspr as the arguments after the QuantumCircuit.

    Parameters
    ----------
    num_ctrl : int
        The amount of controls to be added.
    ctrl_state : int or str, optional
        The control state on which to activate. Negative integers are
        interpreted modulo ``2**num_ctrl``. The default is -1.

    Returns
    -------
    Jaspr
        The controlled Jaspr.

    Examples
    --------

    We create a simple script and inspect the controlled version:

    ::

        from qrisp import *
        from qrisp.jasp import make_jaspr

        def example_function(i):

            qv = QuantumVariable(i)
            cx(qv[0], qv[1])
            t(qv[1])
            return qv

        jaspr = make_jaspr(example_function)(2)

        print(jaspr.control(2))
        # Yields
        # { lambda ; a:QuantumCircuit b:Qubit c:Qubit d:i32[]. let
        #     e:QuantumCircuit f:QubitArray = create_qubits a 1
        #     g:Qubit = get_qubit f 0
        #     h:QuantumCircuit = ccx e b c g
        #     i:QuantumCircuit j:QubitArray = create_qubits h d
        #     k:Qubit = get_qubit j 0
        #     l:Qubit = get_qubit j 1
        #     m:QuantumCircuit = ccx i g k l
        #     n:QuantumCircuit = ct m g l
        #     o:QuantumCircuit = ccx n b c g
        #   in (o, j) }

    We see that the control qubits are part of the function signature
    (``b`` and ``c``)
    """
    # A stored single-control version is reused for the default control state.
    stored = self.ctrl_jaspr
    if stored is not None and num_ctrl == 1 and ctrl_state == -1:
        return stored

    from qrisp.jasp import ControlledJaspr

    if not isinstance(ctrl_state, int):
        state_string = str(ctrl_state)
    else:
        # Negative values wrap around, so -1 selects the all-ones state.
        state_value = ctrl_state + 2**num_ctrl if ctrl_state < 0 else ctrl_state
        state_string = bin(state_value)[2:].zfill(num_ctrl)

    return ControlledJaspr.from_cache(self, state_string)
[docs]
def to_qc(self, *args):
    """
    Converts the Jaspr into a :ref:`QuantumCircuit` if applicable. Circuit
    conversion of algorithms involving realtime computations is not possible.

    Parameters
    ----------
    *args : tuple
        The arguments to call the Jaspr with.

    Returns
    -------
    :ref:`QuantumCircuit`
        The resulting QuantumCircuit.
    return_values : tuple
        The return values of the Jaspr. QuantumVariable return types are
        returned as lists of Qubits.

    Examples
    --------

    We create a simple script and inspect the QuantumCircuit:

    ::

        from qrisp import *
        from qrisp.jasp import make_jaspr

        def example_function(i):

            qv = QuantumVariable(i)
            cx(qv[0], qv[1])
            t(qv[1])
            return qv

        jaspr = make_jaspr(example_function)(2)

        qc, qb_list = jaspr.to_qc(2)
        print(qc)
        # Yields
        # qb_0: ──■───────
        #       ┌─┴─┐┌───┐
        # qb_1: ┤ X ├┤ T ├
        #       └───┘└───┘
    """
    from qrisp import QuantumCircuit, Clbit

    def eqn_evaluator(eqn, context_dic):
        primitive_name = eqn.primitive.name

        # Nested Jasprs and conditionals get dedicated handlers.
        if primitive_name == "pjit" and isinstance(eqn.params["jaxpr"].jaxpr, Jaspr):
            return pjit_to_gate(eqn, context_dic, eqn_evaluator)

        if primitive_name == "cond":
            return cond_to_cl_control(eqn, context_dic, eqn_evaluator)

        if primitive_name == "convert_element_type":
            # A Clbit is passed through the type conversion unchanged.
            if isinstance(context_dic[eqn.invars[0]], Clbit):
                context_dic[eqn.outvars[0]] = context_dic[eqn.invars[0]]
                return None

        # Everything else is reported back to eval_jaxpr with True.
        return True

    call_args = [QuantumCircuit(), *args]
    return eval_jaxpr(self, eqn_evaluator = eqn_evaluator)(*call_args)
def eval(self, *args, eqn_evaluator = lambda x, y : True):
    """
    Evaluates the Jaspr by handing it to ``eval_jaxpr``.

    Parameters
    ----------
    *args : tuple
        The arguments for the evaluation. They are forwarded unchanged.
    eqn_evaluator : callable, optional
        Called as ``eqn_evaluator(eqn, context_dic)`` and forwarded to
        ``eval_jaxpr``. The default returns True for every equation
        (presumably selecting the default handling - verify against
        ``eval_jaxpr``).

    Returns
    -------
    The result of the ``eval_jaxpr`` call.
    """
    return eval_jaxpr(self, eqn_evaluator = eqn_evaluator)(*args)
def flatten_environments(self):
    """
    Flattens all environments by applying the corresponding compilation
    routines such that no more ``q_env`` primitives are left.

    If a ``ctrl_jaspr`` is attached, it is flattened as well and attached to
    the result.

    Returns
    -------
    Jaspr
        The Jaspr with flattened environments.

    Examples
    --------

    We create a Jaspr containing an :ref:`InversionEnvironment` and flatten:

    ::

        def test_function(i):
            qv = QuantumVariable(i)

            with invert():
                t(qv[0])
                cx(qv[0], qv[1])

            return qv

        jaspr = make_jaspr(test_function)(2)
        print(jaspr)

    ::

        { lambda ; a:QuantumCircuit b:i32[]. let
            c:QuantumCircuit d:QubitArray = create_qubits a b
            e:QuantumCircuit = q_env[
              jaspr={ lambda ; f:QuantumCircuit d:QubitArray. let
                  g:Qubit = get_qubit d 0
                  h:QuantumCircuit = t f g
                  i:Qubit = get_qubit d 1
                  j:QuantumCircuit = cx h g i
                in (j,) }
              type=InversionEnvironment
            ] c d
          in (e, d) }

    You can see how the body of the :ref:`InversionEnvironment` is *collected*
    into another Jaspr. This reflects the fact that at their core,
    :ref:`QuantumEnvironment <QuantumEnvironment>` describe `higher-order
    quantum functions <https://en.wikipedia.org/wiki/Higher-order_function>`_
    (i.e. functions that operate on functions). In order to apply the
    transformations induced by the QuantumEnvironment, we can call
    ``jaspr.flatten_environments()``:

    >>> print(jaspr.flatten_environments())
    { lambda ; a:QuantumCircuit b:i32[]. let
        c:QuantumCircuit d:QubitArray = create_qubits a b
        e:Qubit = get_qubit d 0
        f:Qubit = get_qubit d 1
        g:QuantumCircuit = cx c e f
        h:QuantumCircuit = t_dg g e
      in (h, d) }

    We see that as expected, the order of the ``cx`` and the ``t`` gate has been switched and the ``t`` gate has been turned into a ``t_dg``.
    """
    # Inside the method, this name refers to the module-level function.
    res = flatten_environments(self)
    if self.ctrl_jaspr is not None:
        res.ctrl_jaspr = self.ctrl_jaspr.flatten_environments()
    return res
def __call__(self, *args):
    """
    Runs the Jaspr with the built-in interpreter.

    The call is forwarded to ``simulate_jaspr``. The code that used to follow
    the return statement could never execute and has been removed.

    Parameters
    ----------
    *args : tuple
        The arguments to call the Jaspr with.

    Returns
    -------
    The result of ``simulate_jaspr(self, *args)``.
    """
    from qrisp.jasp.evaluation_tools.jaspification import simulate_jaspr
    return simulate_jaspr(self, *args)
def inline(self, *args):
    """
    Evaluates this Jaspr on the circuit of the current tracing session.

    The ``abs_qc`` of the ``TracingQuantumSession`` instance is used as the
    first argument and is replaced by the circuit the evaluation returns.

    Parameters
    ----------
    *args : tuple
        The arguments following the QuantumCircuit.

    Returns
    -------
    tuple or None
        The results after the QuantumCircuit, or None if the evaluation
        returned only a single (non-tuple) value.
    """
    from qrisp.jasp import TracingQuantumSession

    session = TracingQuantumSession.get_instance()
    outcome = eval_jaxpr(self)(session.abs_qc, *args)

    # A tuple result carries the updated circuit in its first entry.
    if isinstance(outcome, tuple):
        updated_qc, remaining = outcome[0], outcome[1:]
    else:
        updated_qc, remaining = outcome, None

    session.abs_qc = updated_qc
    return remaining
def embedd(self, *args, name = None, inline = False):
    """
    Embeds this Jaspr into the current tracing session.

    Parameters
    ----------
    *args : tuple
        The arguments following the QuantumCircuit.
    name : str, optional
        If given (and ``inline`` is False), written to the ``name`` parameter
        of the recorded equation. The default is None.
    inline : bool, optional
        If True, the Jaspr is evaluated directly. Otherwise it is evaluated
        through ``jax.jit`` and the most recently recorded equation is
        patched to reference this Jaspr. The default is False.

    Returns
    -------
    tuple or None
        The results after the QuantumCircuit, or None if the evaluation
        returned only a single (non-tuple) value.
    """
    from qrisp.jasp import TracingQuantumSession

    session = TracingQuantumSession.get_instance()
    call_args = [session.abs_qc, *args]

    if inline:
        outcome = eval_jaxpr(self)(*call_args)
    else:
        outcome = jax.jit(eval_jaxpr(self))(*call_args)

        # Fetch the equation recorded last via Jax internals and make it
        # point at this Jaspr, keeping the constants of the traced jaxpr.
        recorded_eqn = jax._src.core.thread_local_state.trace_state.trace_stack.dynamic.jaxpr_stack[0].eqns[-1]
        recorded_eqn.params["jaxpr"] = jax.core.ClosedJaxpr(self, recorded_eqn.params["jaxpr"].consts)
        if name is not None:
            recorded_eqn.params["name"] = name

    # A tuple result carries the updated circuit in its first entry.
    if isinstance(outcome, tuple):
        updated_qc, remaining = outcome[0], outcome[1:]
    else:
        updated_qc, remaining = outcome, None

    session.abs_qc = updated_qc
    return remaining
[docs]
def qjit(self, *args, function_name = "jaspr_function"):
    """
    Leverages the Catalyst pipeline to compile a QIR representation of
    this function and executes that function using the Catalyst QIR runtime.

    Parameters
    ----------
    *args : iterable
        The arguments to call the function with.
    function_name : str, optional
        The name handed to ``jaspr_to_catalyst_qjit``. The default is
        "jaspr_function".

    Returns
    -------
        The values returned by the compiled, executed function. A tuple or
        list holding exactly one entry is unpacked to that entry.
    """
    from qrisp.jasp.evaluation_tools.catalyst_interface import jaspr_to_catalyst_qjit

    qjit_obj = jaspr_to_catalyst_qjit(self, function_name = function_name)
    outcome = qjit_obj.compiled_function(*args)

    if isinstance(outcome, (tuple, list)) and len(outcome) == 1:
        return outcome[0]
    return outcome
@classmethod
@lru_cache(maxsize = int(1E5))
def from_cache(cls, jaxpr):
    """
    Returns the Jaspr wrapping ``jaxpr``, memoized by an LRU cache with up to
    100000 entries.

    Note that the result is always built with ``Jaspr`` itself, regardless of
    ``cls``.

    Parameters
    ----------
    jaxpr : jax.core.Jaxpr
        The (hashable) Jaxpr to wrap.

    Returns
    -------
    Jaspr
        The cached Jaspr for this Jaxpr.
    """
    return Jaspr(jaxpr = jaxpr)
def update_eqns(self, eqns):
    """
    Creates a new Jaspr with the same variables but a different equation list.

    Only ``constvars``, ``invars``, ``outvars`` and the new equations are
    handed to the constructor; all other attributes take their defaults.

    Parameters
    ----------
    eqns : iterable
        The equations of the new Jaspr.

    Returns
    -------
    Jaspr
        The Jaspr holding the given equations.
    """
    fields = {
        "constvars": [*self.constvars],
        "invars": [*self.invars],
        "outvars": [*self.outvars],
        "eqns": [*eqns],
    }
    return Jaspr(**fields)
[docs]
def to_qir(self):
    """
    Compiles the Jaspr to QIR using the `Catalyst framework <https://docs.pennylane.ai/projects/catalyst/en/stable/index.html>`__.

    The environments are flattened before the Jaspr is handed to
    ``jaspr_to_qir``.

    Parameters
    ----------
    None

    Returns
    -------
    str
        The QIR string.

    Examples
    --------

    We create a simple script and inspect the QIR string:

    ::

        from qrisp import *
        from qrisp.jasp import make_jaspr

        def example_function(i):

            qv = QuantumFloat(i)
            cx(qv[0], qv[1])
            t(qv[1])
            meas_res = measure(qv)
            meas_res += 1
            return meas_res

        jaspr = make_jaspr(example_function)(2)
        print(jaspr.to_qir())

    Yields (abbreviated, elided parts are marked with ``; [...]``):

    ::

        ; ModuleID = 'LLVMDialectModule'
        source_filename = "LLVMDialectModule"

        ; [...]

        declare void @__catalyst__rt__finalize() local_unnamed_addr

        declare void @__catalyst__rt__initialize() local_unnamed_addr

        declare ptr @__catalyst__qis__Measure(ptr, i32) local_unnamed_addr

        declare void @__catalyst__qis__T(ptr, ptr) local_unnamed_addr

        declare void @__catalyst__qis__CNOT(ptr, ptr, ptr) local_unnamed_addr

        declare ptr @__catalyst__rt__array_get_element_ptr_1d(ptr, i64) local_unnamed_addr

        declare ptr @__catalyst__rt__qubit_allocate_array(i64) local_unnamed_addr

        declare void @__catalyst__rt__device_init(ptr, ptr, ptr) local_unnamed_addr

        declare void @_mlir_memref_to_llvm_free(ptr) local_unnamed_addr

        declare ptr @_mlir_memref_to_llvm_alloc(i64) local_unnamed_addr

        define { ptr, ptr, i64 } @jit_jaspr_function(ptr nocapture readnone %0, ptr nocapture readonly %1, i64 %2) local_unnamed_addr {
          ; [...]
          %4 = tail call ptr @__catalyst__rt__qubit_allocate_array(i64 20)
          %5 = tail call ptr @__catalyst__rt__array_get_element_ptr_1d(ptr %4, i64 0)
          %6 = load ptr, ptr %5, align 8
          %7 = tail call ptr @__catalyst__rt__array_get_element_ptr_1d(ptr %4, i64 1)
          %8 = load ptr, ptr %7, align 8
          tail call void @__catalyst__qis__CNOT(ptr %6, ptr %8, ptr null)
          %9 = tail call ptr @__catalyst__rt__array_get_element_ptr_1d(ptr %4, i64 1)
          %10 = load ptr, ptr %9, align 8
          tail call void @__catalyst__qis__T(ptr %10, ptr null)
          ; [...]
        }

        ; [...]

        define void @setup() local_unnamed_addr {
          tail call void @__catalyst__rt__initialize()
          ret void
        }

        define void @teardown() local_unnamed_addr {
          tail call void @__catalyst__rt__finalize()
          ret void
        }

        !llvm.module.flags = !{!0}

        !0 = !{i32 2, !"Debug Info Version", i32 3}
    """
    from qrisp.jasp.evaluation_tools.catalyst_interface import jaspr_to_qir

    flattened = self.flatten_environments()
    return jaspr_to_qir(flattened)
[docs]
def to_mlir(self):
    """
    Compiles the Jaspr to MLIR using the `Catalyst dialect <https://docs.pennylane.ai/projects/catalyst/en/stable/index.html>`__.

    The environments are flattened before the Jaspr is handed to
    ``jaspr_to_mlir``.

    Parameters
    ----------
    None

    Returns
    -------
    str
        The MLIR string.

    Examples
    --------

    We create a simple script and inspect the MLIR string:

    ::

        from qrisp import *
        from qrisp.jasp import make_jaspr

        def example_function(i):

            qv = QuantumFloat(i)
            cx(qv[0], qv[1])
            t(qv[1])
            meas_res = measure(qv)
            meas_res += 1
            return meas_res

        jaspr = make_jaspr(example_function)(2)
        print(jaspr.to_mlir())

    ::

        module @jaspr_function {
          func.func public @jit_jaspr_function(%arg0: tensor<i64>) -> tensor<i32> attributes {llvm.emit_c_interface} {
            %0 = stablehlo.constant dense<1> : tensor<i32>
            %1 = stablehlo.constant dense<2> : tensor<i64>
            %2 = stablehlo.constant dense<1> : tensor<i64>
            %3 = stablehlo.constant dense<0> : tensor<i64>
            quantum.device["/home/positr0nium/miniconda3/envs/qrisp/lib/python3.10/site-packages/catalyst/utils/../lib/librtd_lightning.so", "lightning.qubit", "{'shots': 0, 'mcmc': False, 'num_burnin': 0, 'kernel_name': None}"]
            %4 = quantum.alloc( 20) : !quantum.reg
            %5 = quantum.extract %4[ 0] : !quantum.reg -> !quantum.bit
            %6 = quantum.extract %4[ 1] : !quantum.reg -> !quantum.bit
            %out_qubits:2 = quantum.custom "CNOT"() %5, %6 : !quantum.bit, !quantum.bit
            %7 = quantum.insert %4[ 0], %out_qubits#0 : !quantum.reg, !quantum.bit
            %8 = quantum.insert %7[ 1], %out_qubits#1 : !quantum.reg, !quantum.bit
            %9 = quantum.extract %8[ 1] : !quantum.reg -> !quantum.bit
            %out_qubits_0 = quantum.custom "T"() %9 : !quantum.bit
            %10 = quantum.insert %8[ 1], %out_qubits_0 : !quantum.reg, !quantum.bit
            %11 = stablehlo.add %3, %arg0 : tensor<i64>
            %12:3 = scf.while (%arg1 = %3, %arg2 = %3, %arg3 = %10) : (tensor<i64>, tensor<i64>, !quantum.reg) -> (tensor<i64>, tensor<i64>, !quantum.reg) {
              %16 = stablehlo.compare GE, %arg1, %11, SIGNED : (tensor<i64>, tensor<i64>) -> tensor<i1>
              %extracted = tensor.extract %16[] : tensor<i1>
              scf.condition(%extracted) %arg1, %arg2, %arg3 : tensor<i64>, tensor<i64>, !quantum.reg
            } do {
            ^bb0(%arg1: tensor<i64>, %arg2: tensor<i64>, %arg3: !quantum.reg):
              %extracted = tensor.extract %arg1[] : tensor<i64>
              %16 = quantum.extract %arg3[%extracted] : !quantum.reg -> !quantum.bit
              %mres, %out_qubit = quantum.measure %16 : i1, !quantum.bit
              %from_elements = tensor.from_elements %mres : tensor<i1>
              %extracted_1 = tensor.extract %arg1[] : tensor<i64>
              %17 = quantum.insert %arg3[%extracted_1], %out_qubit : !quantum.reg, !quantum.bit
              %18 = stablehlo.subtract %arg1, %3 : tensor<i64>
              %19 = stablehlo.shift_left %1, %18 : tensor<i64>
              %20 = stablehlo.convert %from_elements : (tensor<i1>) -> tensor<i64>
              %21 = stablehlo.multiply %19, %20 : tensor<i64>
              %22 = stablehlo.add %arg2, %21 : tensor<i64>
              %23 = stablehlo.add %arg1, %2 : tensor<i64>
              scf.yield %23, %22, %17 : tensor<i64>, tensor<i64>, !quantum.reg
            }
            %13 = stablehlo.convert %12#1 : (tensor<i64>) -> tensor<i32>
            %14 = stablehlo.multiply %13, %0 : tensor<i32>
            %15 = stablehlo.add %14, %0 : tensor<i32>
            return %15 : tensor<i32>
          }
          func.func @setup() {
            quantum.init
            return
          }
          func.func @teardown() {
            quantum.finalize
            return
          }
        }
    """
    from qrisp.jasp.evaluation_tools.catalyst_interface import jaspr_to_mlir

    flattened = self.flatten_environments()
    return jaspr_to_mlir(flattened)
[docs]
def to_qasm(self, *args):
    """
    Compiles the Jaspr into an OpenQASM 2 string. Real-time control is possible
    as long as no computations on the measurement results are performed.

    Parameters
    ----------
    *args : list
        The arguments to call the :ref:`QuantumCircuit` evaluation with.

    Returns
    -------
    str
        The OpenQASM 2 string.

    Examples
    --------

    We create a simple script and inspect the QASM 2 string:

    ::

        from qrisp import *
        from qrisp.jasp import make_jaspr

        def main(i):

            qv = QuantumVariable(i)
            cx(qv[0], qv[1])
            t(qv[1])
            return qv

        jaspr = make_jaspr(main)(2)

        qasm_str = jaspr.to_qasm(2)
        print(qasm_str)
        # Yields
        # OPENQASM 2.0;
        # include "qelib1.inc";
        # qreg qb_59[1];
        # qreg qb_60[1];
        # cx qb_59[0],qb_60[0];
        # t qb_60[0];

    It is also possible to compile simple real-time control features:

    ::

        def main(phi):

            qf = QuantumFloat(5)
            h(qf)
            bl = measure(qf[0])

            with control(bl):
                rz(phi, qf[1])
                x(qf[1])

            return

        jaspr = make_jaspr(main)(0.5)

        print(jaspr.to_qasm(0.5))

    This gives:

    ::

        OPENQASM 2.0;
        include "qelib1.inc";
        qreg qb_59[1];
        qreg qb_60[1];
        qreg qb_61[1];
        qreg qb_62[1];
        qreg qb_63[1];
        creg cb_0[1];
        h qb_59[0];
        h qb_60[0];
        h qb_61[0];
        reset qb_61[0];
        h qb_62[0];
        reset qb_62[0];
        h qb_63[0];
        reset qb_63[0];
        measure qb_59[0] -> cb_0[0];
        reset qb_59[0];
        if(cb_0==1) rz(0.5) qb_60[0];
        if(cb_0==1) x qb_60[0];
        reset qb_60[0];
    """
    conversion_result = self.to_qc(*args)

    # With a single outvar, to_qc yields the circuit itself; otherwise the
    # circuit is the first entry of the returned sequence.
    if len(self.outvars) == 1:
        qrisp_qc = conversion_result
    else:
        qrisp_qc = conversion_result[0]

    return qrisp_qc.qasm()
[docs]
def to_catalyst_jaxpr(self):
    """
    Compiles the jaspr to the corresponding `Catalyst jaxpr <https://docs.pennylane.ai/projects/catalyst/en/stable/index.html>`__.

    The environments are flattened before the Jaspr is handed to
    ``jaspr_to_catalyst_jaxpr``.

    Parameters
    ----------
    None

    Returns
    -------
    jax.core.Jaxpr
        The Jaxpr using Catalyst primitives.

    Examples
    --------

    We create a simple script and inspect the Catalyst Jaxpr:

    ::

        from qrisp import *
        from qrisp.jasp import make_jaspr

        def example_function(i):

            qv = QuantumFloat(i)
            cx(qv[0], qv[1])
            t(qv[1])
            meas_res = measure(qv)
            meas_res += 1
            return meas_res

        jaspr = make_jaspr(example_function)(2)

        print(jaspr.to_catalyst_jaxpr())
        # Yields
        # { lambda ; a:AbstractQreg() b:i64[] c:i32[]. let
        #     d:i64[] = convert_element_type[new_dtype=int64 weak_type=True] c
        #     e:i64[] = add b d
        #     f:i64[] = add b 0
        #     g:i64[] = add b 1
        #     h:AbstractQbit() = qextract a f
        #     i:AbstractQbit() = qextract a g
        #     j:AbstractQbit() k:AbstractQbit() = qinst[op=CNOT qubits_len=2] h i
        #     l:AbstractQreg() = qinsert a f j
        #     m:AbstractQreg() = qinsert l g k
        #     n:AbstractQbit() = qextract m g
        #     o:AbstractQbit() = qinst[op=T qubits_len=1] n
        #     p:AbstractQreg() = qinsert m g o
        #     q:i64[] = convert_element_type[new_dtype=int64 weak_type=True] c
        #     r:i64[] = add b q
        #     _:i64[] s:i64[] t:AbstractQreg() _:i64[] _:i64[] = while_loop[
        #       body_jaxpr={ lambda ; u:i64[] v:i64[] w:AbstractQreg() x:i64[] y:i64[]. let
        #           z:AbstractQbit() = qextract w u
        #           ba:bool[] bb:AbstractQbit() = qmeasure z
        #           bc:AbstractQreg() = qinsert w u bb
        #           bd:i64[] = sub u x
        #           be:i64[] = shift_left 2 bd
        #           bf:i64[] = convert_element_type[new_dtype=int64 weak_type=True] ba
        #           bg:i64[] = mul be bf
        #           bh:i64[] = add v bg
        #           bi:i64[] = add u 1
        #         in (bi, bh, bc, x, y) }
        #       body_nconsts=0
        #       cond_jaxpr={ lambda ; bj:i64[] bk:i64[] bl:AbstractQreg() bm:i64[] bn:i64[]. let
        #           bo:bool[] = ge bj bn
        #         in (bo,) }
        #       cond_nconsts=0
        #       nimplicit=0
        #       preserve_dimensions=True
        #     ] b 0 p b r
        #     bp:i32[] = convert_element_type[new_dtype=int64 weak_type=False] s
        #     bq:i32[] = mul bp 1
        #     br:i32[] = add bq 1
        #   in (t, e, br) }
    """
    # The other Catalyst-related methods of this class import from
    # ``qrisp.jasp.evaluation_tools.catalyst_interface``. Use that location
    # first and only fall back to the path the original code referenced.
    try:
        from qrisp.jasp.evaluation_tools.catalyst_interface import jaspr_to_catalyst_jaxpr
    except ImportError:
        from qrisp.jasp.catalyst_interface import jaspr_to_catalyst_jaxpr

    return jaspr_to_catalyst_jaxpr(self.flatten_environments())
def make_jaspr(fun, garbage_collection = "auto", flatten_envs = True, **jax_kwargs):
    """
    Returns a function that traces ``fun`` into a Jaspr, with semantics similar
    to ``jax.make_jaxpr``.

    Parameters
    ----------
    fun : callable
        The function to trace.
    garbage_collection : str, optional
        Forwarded to ``TracingQuantumSession.start_tracing``. The default is
        "auto".
    flatten_envs : bool, optional
        If True, ``flatten_environments`` is called on the traced Jaspr.
        The default is True.
    **jax_kwargs
        Keyword arguments forwarded to ``jax.make_jaxpr``. ``static_argnums``
        may be an int, a list or a tuple and refers to the arguments of
        ``fun``.

    Returns
    -------
    callable
        A function that takes the arguments of ``fun`` and returns the Jaspr
        obtained from tracing ``fun`` with them.
    """
    # The traced function receives the AbstractQuantumCircuit as an extra
    # leading argument, so every static_argnums index is moved by one.
    # Tuples are supported in addition to lists and plain integers.
    if "static_argnums" in jax_kwargs:
        jax_kwargs = dict(jax_kwargs)
        static_argnums = jax_kwargs["static_argnums"]
        if isinstance(static_argnums, (list, tuple)):
            shifted = [index + 1 for index in static_argnums]
            if isinstance(static_argnums, tuple):
                shifted = tuple(shifted)
            jax_kwargs["static_argnums"] = shifted
        else:
            jax_kwargs["static_argnums"] = static_argnums + 1

    def jaspr_creator(*args, **kwargs):
        # The Qrisp modules are imported at call time rather than at module
        # level (presumably to avoid circular imports).
        from qrisp.jasp import AbstractQuantumCircuit, TracingQuantumSession, check_for_tracing_mode
        from qrisp.core import recursive_qv_search

        qs = TracingQuantumSession.get_instance()

        # Close any tracing quantum sessions that might have not been
        # properly closed due to whatever reason.
        if not check_for_tracing_mode():
            while qs.abs_qc is not None:
                qs.conclude_tracing()

        # This function will be traced by Jax.
        # Note that we add the abs_qc parameter as the tracing quantum circuit.
        def ammended_function(abs_qc, *args, **kwargs):

            qs.start_tracing(abs_qc, garbage_collection)

            # If the signature contains QuantumVariables, these QuantumVariables went
            # through a flattening/unflattening procedure. The unflattening creates
            # a copy of the QuantumVariable object, which is however not yet registered in any
            # QuantumSession. We register these QuantumVariables in the current QuantumSession.
            arg_qvs = recursive_qv_search(args)
            for qv in arg_qvs:
                qs.register_qv(qv, None)

            try:
                res = fun(*args, **kwargs)
            except Exception:
                # Leave the session in a clean state, then re-raise with the
                # original traceback.
                qs.conclude_tracing()
                raise

            res_qvs = recursive_qv_search(res)

            qs.garbage_collection(spare_qv_list = arg_qvs + res_qvs)

            res_qc = qs.conclude_tracing()

            return res_qc, res

        closed_jaxpr = make_jaxpr(ammended_function, **jax_kwargs)(AbstractQuantumCircuit(), *args, **kwargs)
        jaxpr = closed_jaxpr.jaxpr

        # Collect the environments.
        # This means that the quantum environments no longer appear as
        # enter/exit primitives but as primitives that "call" a certain Jaspr.
        res = Jaspr.from_cache(collect_environments(jaxpr))
        if flatten_envs:
            res = res.flatten_environments()
        res.consts = closed_jaxpr.consts
        return res

    return jaspr_creator
def check_aval_equivalence(invars_1, invars_2):
    """
    Checks whether two sequences of variables have abstract values of
    pairwise identical type.

    Parameters
    ----------
    invars_1 : iterable
        Objects exposing an ``aval`` attribute.
    invars_2 : iterable
        Objects exposing an ``aval`` attribute.

    Returns
    -------
    bool
        True if both sequences have the same length and the types of the
        ``aval`` attributes agree position by position, False otherwise.
    """
    aval_types_1 = [type(invar.aval) for invar in invars_1]
    aval_types_2 = [type(invar.aval) for invar in invars_2]

    # Sequences of different length can not be equivalent. Checking this first
    # avoids an IndexError for a shorter second sequence and prevents surplus
    # entries of a longer one from being ignored.
    if len(aval_types_1) != len(aval_types_2):
        return False

    return all(type_1 is type_2 for type_1, type_2 in zip(aval_types_1, aval_types_2))