qml.capture.subroutine

subroutine(func, static_argnums=None, static_argnames=None)[source]

Marks a function to be captured as a separate, reusable function in the intermediate representation, rather than being inlined at each call site.

This can reduce compilation time. Without the decorator, the function is inlined, so its body is compiled again at every call site. With it, a single version of the function is compiled and then called from each call site.
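The saving can be illustrated with a toy model that contains nothing PennyLane-specific. A hypothetical `Module` holds named function definitions and an entry-point instruction list, and `total_ops` counts how many instructions a compiler would have to process. Inlining copies the body into every call site, so the count grows as (body size × number of calls). A subroutine is defined once and each call site adds one call instruction, so the count grows as (body size + number of calls). All names here are illustrative assumptions, not PennyLane internals.

```python
from dataclasses import dataclass, field


@dataclass
class Module:
    """Toy IR module (hypothetical): named function bodies plus an entry point."""
    functions: dict = field(default_factory=dict)  # name -> list of op templates
    main: list = field(default_factory=list)       # instructions in the entry point


def emit_inlined(module, name, body, args):
    # Inlining: copy every op of the body into the entry point, substituting the args.
    module.main.extend(op.format(*args) for op in body)


def emit_subroutine(module, name, body, args):
    # Subroutine: store the body once, then emit only a call at each call site.
    if name not in module.functions:
        module.functions[name] = list(body)
    module.main.append(f"call @{name}({', '.join(map(str, args))})")


def total_ops(module):
    # Instructions to process: the entry point plus every stored definition.
    return len(module.main) + sum(len(ops) for ops in module.functions.values())


if __name__ == "__main__":
    body = ["RX({0}, wires={1})", "RY({0}, wires={1})"]
    inlined, shared = Module(), Module()
    for wire in range(5):
        emit_inlined(inlined, "f", body, (0.5, wire))
        emit_subroutine(shared, "f", body, (0.5, wire))
    print(total_ops(inlined), total_ops(shared))  # 10 7
```

With a two-op body and five call sites, the inlined module holds 10 instructions while the subroutine version holds 7 (five calls plus one two-op definition). The gap widens as the body grows.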

Note

Subroutines are only available when using the PLxPR program capture interface.

Parameters:
  • func (Callable) – the function to capture as a subroutine

  • static_argnums (None | int | Sequence[int]) – the indices of the static arguments

  • static_argnames (None | str | Sequence[str]) – the names of static arguments. May be provided instead of static_argnums for readability.
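The role of static arguments can be sketched with a small pure-Python decorator. `toy_subroutine` is a hypothetical stand-in written for illustration, not PennyLane's implementation. By analogy with `jax.jit`, it assumes that a static argument enters the cache key by value, so each distinct value produces a separate compiled version, while a dynamic argument contributes only its type. It also shows one way `static_argnames` and `static_argnums` can be normalized to the same set of parameters.

```python
import inspect
from functools import wraps


def _as_tuple(value):
    # Normalize None, a single int/str, or a sequence, mirroring the documented types.
    if value is None:
        return ()
    if isinstance(value, (int, str)):
        return (value,)
    return tuple(value)


def toy_subroutine(func, static_argnums=None, static_argnames=None):
    """Hypothetical stand-in: record one 'compiled' version of ``func`` per
    combination of static argument values and dynamic argument types."""
    sig = inspect.signature(func)
    names = list(sig.parameters)
    # Convert positions to names so both ways of marking static arguments agree.
    static = {names[i] for i in _as_tuple(static_argnums)} | set(_as_tuple(static_argnames))
    cache = {}

    @wraps(func)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        # Static arguments enter the key by value (so they must be hashable here);
        # dynamic ones only by type, standing in for the abstract value a tracer sees.
        key = tuple(
            (name, value if name in static else type(value))
            for name, value in bound.arguments.items()
        )
        if key not in cache:
            cache[key] = len(cache)  # "compile" a new version
        return func(*args, **kwargs)

    wrapper.cache = cache
    return wrapper


if __name__ == "__main__":
    def layer(x, wires, n_layers):
        return [("RX", x, wires)] * n_layers

    layer = toy_subroutine(layer, static_argnames="n_layers")
    layer(0.1, 0, 2)
    layer(0.7, 1, n_layers=2)  # same static value: reuses the first version
    layer(0.3, 0, 3)           # new static value: a second version
    print(len(layer.cache))    # 2
```

In this sketch `wires` is left dynamic, so calls on different wires share one version. This is consistent with the example below, where the wire labels are passed to `f` as traced `i64` arguments.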

Example

import jax
import pennylane as qml

qml.capture.enable()

@qml.capture.subroutine
def f(x, wires):
    qml.RX(x, wires)

@qml.qnode(qml.device('lightning.qubit', wires=5))
def c(x: float):
    f(x, 0)
    f(x, 1)
    return qml.state()

print(jax.make_jaxpr(c)(0.5))
let f = { lambda ; a:f64[] b:i64[]. let
    _:AbstractOperator() = RX[n_wires=1] a b
in () } in
{ lambda ; c:f64[]. let
    d:c128[32] = qnode[
    device=<lightning.qubit device (wires=5) at 0x12aac1c40>
    execution_config=ExecutionConfig(grad_on_execution=False, use_device_gradient=None, use_device_jacobian_product=False, gradient_method='best', gradient_keyword_arguments={}, device_options={}, interface=<Interface.JAX: 'jax'>, derivative_order=1, mcm_config=MCMConfig(mcm_method=None, postselect_mode=None), convert_to_numpy=True, executor_backend=<class 'pennylane.concurrency.executors.native.multiproc.MPPoolExec'>)
    n_consts=0
    qfunc_jaxpr={ lambda ; e:f64[]. let
        quantum_subroutine_p[
            compiler_options_kvs=()
            ctx_mesh=Mesh(, axis_types=())
            donated_invars=(False, False)
            in_layouts=(None, None)
            in_shardings=(UnspecifiedValue, UnspecifiedValue)
            inline=False
            jaxpr=f
            keep_unused=False
            name=f
            out_layouts=()
            out_shardings=()
        ] e 0:i64[]
        quantum_subroutine_p[
            compiler_options_kvs=()
            ctx_mesh=Mesh(, axis_types=())
            donated_invars=(False, False)
            in_layouts=(None, None)
            in_shardings=(UnspecifiedValue, UnspecifiedValue)
            inline=False
            jaxpr=f
            keep_unused=False
            name=f
            out_layouts=()
            out_shardings=()
        ] e 1:i64[]
        g:AbstractMeasurement(n_wires=0) = state_wires
        in (g,) }
    qnode=<QNode: device='<lightning.qubit device (wires=5) at 0x12aac1c40>', interface='jax', diff_method='best', shots='Shots(total=None)'>
    shots_len=0
    ] c
in (d,)

If we create a qjit-compiled version of the QNode (this requires Catalyst), we can inspect its MLIR and see a single FuncOp, @f, that is called from both call sites:

>>> qjit_c = qml.qjit(c)
>>> print(qjit_c.mlir[1010:1300]) 
%0 = quantum.alloc( 5) : !quantum.reg
%1 = call @f(%0, %arg0, %c_0) : (!quantum.reg, tensor<f64>, tensor<i64>) -> !quantum.reg
%2 = call @f(%1, %arg0, %c) : (!quantum.reg, tensor<f64>, tensor<i64>) -> !quantum.reg
%3 = quantum.compbasis qreg %2 : !quantum.obs
>>> print(qjit_c.mlir[1465:2070]) 
func.func private @f(%arg0: !quantum.reg, %arg1: tensor<f64>, %arg2: tensor<i64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage<internal>} {
    %extracted = tensor.extract %arg2[] : tensor<i64>
    %0 = quantum.extract %arg0[%extracted] : !quantum.reg -> !quantum.bit
    %extracted_0 = tensor.extract %arg1[] : tensor<f64>
    %out_qubits = quantum.custom "RX"(%extracted_0) %0 : !quantum.bit
    %extracted_1 = tensor.extract %arg2[] : tensor<i64>
    %1 = quantum.insert %arg0[%extracted_1], %out_qubits : !quantum.reg, !quantum.bit
    return %1 : !quantum.reg
    }
}
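Slicing `qjit_c.mlir` at fixed character offsets is fragile, because the offsets shift whenever the generated module changes. A small text-based check can instead count the definitions and call sites of a function symbol. This is a hypothetical helper built on regular expressions, not a Catalyst or PennyLane API, and it only recognizes the `func.func` and `call` forms shown above.

```python
import re


def mlir_call_summary(mlir: str, name: str) -> tuple[int, int]:
    """Return (definitions, call sites) of the func-dialect function ``@name``."""
    sym = re.escape(name)
    # A definition looks like `func.func private @f(...)`; the visibility keyword is optional.
    definitions = re.findall(rf"func\.func\s+(?:\w+\s+)?@{sym}\(", mlir)
    # A call site looks like `%1 = call @f(...)`.
    calls = re.findall(rf"\bcall\s+@{sym}\(", mlir)
    return len(definitions), len(calls)
```

On the excerpts shown above, `mlir_call_summary(text, "f")` gives one definition and two call sites. Applied to the full `qjit_c.mlir`, the result depends on the module Catalyst actually generates.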
