Source code for pennylane.capture.subroutine
# Copyright 2025 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Here we define a mechanism for capturing subroutines by patching the pjit primitive.
While we need to come back and develop custom handling that does not involve patching
jax internals, this will let us build on it for the time being.
We could also just develop a custom higher order primitive like all our other higher order
primitives, but we currently want to be able to cache the jaxpr and the lowering and to
be able to avoid promoting constants to the outer scope. Solving these would take
time we don't have.
We also can't just use the normal ``jit`` primitive, because we currently need to know
which higher order primitive needs to have QRegs added to its inputs and removed from
its outputs in Catalyst's ``from_plxpr``.
The steps involved in lowering a subroutine include adding a quantum register input and output, and translating the inside code from plxpr to catalyst jaxpr. The registers are also handled in the aforementioned Catalyst frontend.
Note that this explanation will probably get out of date.
"""
import copy
from .autograph import wraps
from .patching import Patcher
from .switches import enabled
# Records whether the optional JAX dependency could be imported.
has_jax = True
try:
    import jax
    from jax._src.pjit import jit_p as pjit_p

    # Independent deep copy of JAX's ``jit_p`` primitive carrying its own name.
    # ``subroutine`` temporarily patches this copy in place of ``jit_p`` so that the
    # captured call is recorded under a distinct primitive instead of plain ``jit``.
    quantum_subroutine_prim = copy.deepcopy(pjit_p)
    quantum_subroutine_prim.name = "quantum_subroutine_prim"
except ImportError:  # pragma: no cover
    # JAX is unavailable: there is no primitive to copy, and ``subroutine`` falls back
    # to returning the decorated function unchanged.
    has_jax = False
    quantum_subroutine_prim = None
[docs]
def subroutine(func=None, static_argnums=None, static_argnames=None):
    """
    Denotes the creation of a function in the intermediate representation.

    May be used to reduce compilation times. Instead of repeatedly compiling
    inlined versions of the function passed as a parameter, when functions
    are annotated with a subroutine, a single version of the function
    will be compiled and called from potentially multiple callsites.

    .. note::

        Subroutines are only available when using the PLxPR program capture
        interface.

    Args:
        func (Callable | None): the function. If omitted, a decorator is returned so
            that the static-argument options can be supplied in decorator syntax,
            e.g. ``@subroutine(static_argnums=1)``.
        static_argnums (None | int | Sequence[int]): the indices of the static arguments
        static_argnames (None | str | Sequence[str]): the names of static arguments. May be
            provided instead of ``static_argnums`` for readability.

    Returns:
        Callable: a wrapper around ``func``. While program capture is enabled, calling it
        goes through ``jax.jit`` with ``jit_p`` patched to ``quantum_subroutine_prim``;
        while capture is disabled it calls ``func`` directly. If JAX is not installed,
        ``func`` itself is returned. If ``func`` is ``None``, a decorator that accepts the
        function is returned instead.

    **Example**

    .. code-block:: python

        qml.capture.enable()

        @qml.capture.subroutine
        def f(x, wires):
            qml.RX(x, wires)

        @qml.qnode(qml.device('lightning.qubit', wires=5))
        def c(x : float):
            f(x, 0)
            f(x, 1)
            return qml.state()

        print(jax.make_jaxpr(c)(0.5))

    .. code-block::

        let f = { lambda ; a:f64[] b:i64[]. let
            _:AbstractOperator() = RX[n_wires=1] a b
          in () } in
        { lambda ; c:f64[]. let
            d:c128[32] = qnode[
              device=<lightning.qubit device (wires=5) at 0x12aac1c40>
              execution_config=ExecutionConfig(grad_on_execution=False, use_device_gradient=None, use_device_jacobian_product=False, gradient_method='best', gradient_keyword_arguments={}, device_options={}, interface=<Interface.JAX: 'jax'>, derivative_order=1, mcm_config=MCMConfig(mcm_method=None, postselect_mode=None), convert_to_numpy=True, executor_backend=<class 'pennylane.concurrency.executors.native.multiproc.MPPoolExec'>)
              n_consts=0
              qfunc_jaxpr={ lambda ; e:f64[]. let
                  quantum_subroutine_prim[
                    compiler_options_kvs=()
                    ctx_mesh=Mesh(, axis_types=())
                    donated_invars=(False, False)
                    in_layouts=(None, None)
                    in_shardings=(UnspecifiedValue, UnspecifiedValue)
                    inline=False
                    jaxpr=f
                    keep_unused=False
                    name=f
                    out_layouts=()
                    out_shardings=()
                  ] e 0:i64[]
                  quantum_subroutine_prim[
                    compiler_options_kvs=()
                    ctx_mesh=Mesh(, axis_types=())
                    donated_invars=(False, False)
                    in_layouts=(None, None)
                    in_shardings=(UnspecifiedValue, UnspecifiedValue)
                    inline=False
                    jaxpr=f
                    keep_unused=False
                    name=f
                    out_layouts=()
                    out_shardings=()
                  ] e 1:i64[]
                  g:AbstractMeasurement(n_wires=0) = state_wires
                in (g,) }
              qnode=<QNode: device='<lightning.qubit device (wires=5) at 0x12aac1c40>', interface='jax', diff_method='best', shots='Shots(total=None)'>
              shots_len=0
            ] c
          in (d,) }

    If we create a ``qjit`` version of the QNode, we can inspect the mlir and see a ``FuncOp`` that is
    reused for both calls:

    >>> qjit_c = qml.qjit(c)
    >>> print(qjit_c.mlir[1010:1300]) # doctest: +SKIP
        %0 = quantum.alloc( 5) : !quantum.reg
        %1 = call @f(%0, %arg0, %c_0) : (!quantum.reg, tensor<f64>, tensor<i64>) -> !quantum.reg
        %2 = call @f(%1, %arg0, %c) : (!quantum.reg, tensor<f64>, tensor<i64>) -> !quantum.reg
        %3 = quantum.compbasis qreg %2 : !quantum.obs
    >>> print(qjit_c.mlir[1465:2070]) # doctest: +SKIP
        func.func private @f(%arg0: !quantum.reg, %arg1: tensor<f64>, %arg2: tensor<i64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage<internal>} {
            %extracted = tensor.extract %arg2[] : tensor<i64>
            %0 = quantum.extract %arg0[%extracted] : !quantum.reg -> !quantum.bit
            %extracted_0 = tensor.extract %arg1[] : tensor<f64>
            %out_qubits = quantum.custom "RX"(%extracted_0) %0 : !quantum.bit
            %extracted_1 = tensor.extract %arg2[] : tensor<i64>
            %1 = quantum.insert %arg0[%extracted_1], %out_qubits : !quantum.reg, !quantum.bit
            return %1 : !quantum.reg
          }
        }

    """
    if func is None:
        # Decorator-factory usage: ``@subroutine(static_argnums=...)``. Defer the real
        # work until the function to wrap is supplied.
        def decorator(fn):
            return subroutine(
                fn,
                static_argnums=static_argnums,
                static_argnames=static_argnames,
            )

        return decorator

    if not has_jax:
        # Without JAX there is nothing to capture; leave the function untouched.
        return func

    # Remember the primitive that is installed as ``jit_p`` right now, so it can be
    # put back while the body of the subroutine is being traced.
    old_pjit = jax._src.pjit.jit_p  # pylint: disable=protected-access

    @wraps(func)
    def inside(*args, **kwargs):
        # Inside our "quantum subroutine", we want to be able to do normal jit on classical
        # subroutines with the normal jit pipeline. That is why ``jit_p`` is patched back to
        # the originally installed primitive for the duration of the call to ``func``.
        with Patcher(
            (
                jax._src.pjit,  # pylint: disable=protected-access
                "jit_p",
                old_pjit,
            ),
        ):
            return func(*args, **kwargs)

    @wraps(inside)
    def wrapper(*args, **kwargs):
        if not enabled():
            # Program capture is off: behave exactly like the undecorated function.
            return func(*args, **kwargs)

        # We want jit_p to be turned into quantum_subroutine_prim just for the capturing of
        # this particular function as a higher order primitive.
        with Patcher(
            (
                jax._src.pjit,  # pylint: disable=protected-access
                "jit_p",
                quantum_subroutine_prim,
            ),
        ):
            return jax.jit(
                inside,
                static_argnames=static_argnames,
                static_argnums=static_argnums,
            )(*args, **kwargs)

    return wrapper
_modules/pennylane/capture/subroutine
Download Python script
Download Notebook
View on GitHub