qml.capture.subroutine
- subroutine(func, static_argnums=None, static_argnames=None)
Captures the decorated function as a separate, reusable function in the intermediate representation (IR) instead of inlining it.
This can reduce compilation time. Without the decorator, the function body is inlined and compiled again at every call site. With it, a single version of the function is compiled and then called from each call site.
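The compile-time saving can be pictured with a toy model. The sketch below is not PennyLane's implementation: trace_inlined, trace_outlined and rx_layer are hypothetical names, an "IR" is just a list of strings, and it assumes compile cost grows roughly with the number of IR operations.

```python
from typing import Callable, List

# Toy model only: an "IR" is a flat list of operation strings.
# A Body appends its operations to the IR list it is given.
Body = Callable[[List[str], str], None]


def trace_inlined(body: Body, wires: List[int]) -> List[str]:
    """Inline the body at every call site: IR size grows as calls x body size."""
    ir: List[str] = []
    for w in wires:
        body(ir, f"wire={w}")
    return ir


def trace_outlined(body: Body, wires: List[int]) -> List[str]:
    """Trace the body once into a definition, then emit one call per call site."""
    definition: List[str] = []
    body(definition, "wire=%arg")
    ir = ["func @f {"] + ["  " + op for op in definition] + ["}"]
    ir += [f"call @f(wire={w})" for w in wires]
    return ir


def rx_layer(ir: List[str], wire: str) -> None:
    """Stand-in for a subroutine body containing three gates."""
    ir.append(f"RX {wire}")
    ir.append(f"RZ {wire}")
    ir.append(f"RX {wire}")


wires = list(range(10))
print(len(trace_inlined(rx_layer, wires)))   # 10 calls x 3 ops = 30
print(len(trace_outlined(rx_layer, wires)))  # 3 ops + 2 delimiters + 10 calls = 15
```

The inlined IR grows with the number of call sites times the body size, while the outlined IR contains the body only once. The same shape is visible in the jaxpr of the example, where f is defined once and referenced by two calls.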
Note
Subroutines are only available when using the PLxPR program capture interface, which is enabled with qml.capture.enable() as in the example below.
- Parameters:
func (Callable) – the function to capture as a subroutine
static_argnums (None | int | Sequence[int]) – the indices of the static arguments
static_argnames (None | str | Sequence[str]) – the names of the static arguments. May be provided instead of static_argnums for readability.
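Because static_argnames is offered as a more readable alternative to static_argnums, each name should denote the argument at the corresponding position in the function's signature, as with the identically named options of jax.jit. The following sketch assumes that correspondence; resolve_static_argnums is a hypothetical helper, not part of PennyLane, and only shows how the two forms can identify the same argument.

```python
import inspect
from typing import Callable, Sequence, Set, Tuple, Union


def _as_tuple(value) -> Tuple:
    """Normalize None, a single item, or a sequence into a tuple."""
    if value is None:
        return ()
    if isinstance(value, (int, str)):
        return (value,)
    return tuple(value)


def resolve_static_argnums(
    func: Callable,
    static_argnums: Union[None, int, Sequence[int]] = None,
    static_argnames: Union[None, str, Sequence[str]] = None,
) -> Set[int]:
    """Hypothetical helper: merge static_argnums and static_argnames into one index set."""
    params = list(inspect.signature(func).parameters)
    resolved = set(_as_tuple(static_argnums))
    for name in _as_tuple(static_argnames):
        if name not in params:
            raise ValueError(f"{name!r} is not a parameter of {func.__name__}")
        # A name refers to the same argument as its position in the signature.
        resolved.add(params.index(name))
    return resolved


def f(x, wires):
    """Same signature as the subroutine f in the example below."""


print(resolve_static_argnums(f, static_argnums=1))        # {1}
print(resolve_static_argnums(f, static_argnames="wires"))  # {1}
```

Under this assumption, marking wires as static by name or by index 1 selects the same argument.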
Example
```python
import jax
import pennylane as qml

qml.capture.enable()

@qml.capture.subroutine
def f(x, wires):
    qml.RX(x, wires)

@qml.qnode(qml.device('lightning.qubit', wires=5))
def c(x: float):
    # Two calls to the same subroutine on different wires.
    f(x, 0)
    f(x, 1)
    return qml.state()

print(jax.make_jaxpr(c)(0.5))
```
```
let f = { lambda ; a:f64[] b:i64[]. let
    _:AbstractOperator() = RX[n_wires=1] a b
  in () } in
{ lambda ; c:f64[]. let
    d:c128[32] = qnode[
      device=<lightning.qubit device (wires=5) at 0x12aac1c40>
      execution_config=ExecutionConfig(grad_on_execution=False, use_device_gradient=None, use_device_jacobian_product=False, gradient_method='best', gradient_keyword_arguments={}, device_options={}, interface=<Interface.JAX: 'jax'>, derivative_order=1, mcm_config=MCMConfig(mcm_method=None, postselect_mode=None), convert_to_numpy=True, executor_backend=<class 'pennylane.concurrency.executors.native.multiproc.MPPoolExec'>)
      n_consts=0
      qfunc_jaxpr={ lambda ; e:f64[]. let
          quantum_subroutine_p[
            compiler_options_kvs=()
            ctx_mesh=Mesh(, axis_types=())
            donated_invars=(False, False)
            in_layouts=(None, None)
            in_shardings=(UnspecifiedValue, UnspecifiedValue)
            inline=False
            jaxpr=f
            keep_unused=False
            name=f
            out_layouts=()
            out_shardings=()
          ] e 0:i64[]
          quantum_subroutine_p[
            compiler_options_kvs=()
            ctx_mesh=Mesh(, axis_types=())
            donated_invars=(False, False)
            in_layouts=(None, None)
            in_shardings=(UnspecifiedValue, UnspecifiedValue)
            inline=False
            jaxpr=f
            keep_unused=False
            name=f
            out_layouts=()
            out_shardings=()
          ] e 1:i64[]
          g:AbstractMeasurement(n_wires=0) = state_wires
        in (g,) }
      qnode=<QNode: device='<lightning.qubit device (wires=5) at 0x12aac1c40>', interface='jax', diff_method='best', shots='Shots(total=None)'>
      shots_len=0
    ] c
  in (d,) }
```

If we create a qjit version of the QNode, we can inspect the MLIR and see a single FuncOp, @f, that is reused for both calls:

```
>>> qjit_c = qml.qjit(c)
>>> print(qjit_c.mlir[1010:1300])
%0 = quantum.alloc( 5) : !quantum.reg
%1 = call @f(%0, %arg0, %c_0) : (!quantum.reg, tensor<f64>, tensor<i64>) -> !quantum.reg
%2 = call @f(%1, %arg0, %c) : (!quantum.reg, tensor<f64>, tensor<i64>) -> !quantum.reg
%3 = quantum.compbasis qreg %2 : !quantum.obs
>>> print(qjit_c.mlir[1465:2070])
func.func private @f(%arg0: !quantum.reg, %arg1: tensor<f64>, %arg2: tensor<i64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage<internal>} {
  %extracted = tensor.extract %arg2[] : tensor<i64>
  %0 = quantum.extract %arg0[%extracted] : !quantum.reg -> !quantum.bit
  %extracted_0 = tensor.extract %arg1[] : tensor<f64>
  %out_qubits = quantum.custom "RX"(%extracted_0) %0 : !quantum.bit
  %extracted_1 = tensor.extract %arg2[] : tensor<i64>
  %1 = quantum.insert %arg0[%extracted_1], %out_qubits : !quantum.reg, !quantum.bit
  return %1 : !quantum.reg
}
}
```
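Printing fixed character ranges such as mlir[1010:1300] depends on the exact layout of the module, which may shift between versions. As an alternative, the following sketch counts the definitions and call sites of a function symbol with regular expressions. subroutine_usage is a hypothetical helper, and it assumes only the textual forms "func.func ... @f(" and "call @f(" that appear in the printed MLIR.

```python
import re
from typing import Tuple


def subroutine_usage(mlir: str, name: str) -> Tuple[int, int]:
    """Hypothetical helper: count (definitions, call sites) of function @name in MLIR text."""
    sym = re.escape(name)
    # Definitions print as e.g. "func.func private @f(".
    definitions = re.findall(rf"func\.func\s+(?:\w+\s+)*@{sym}\(", mlir)
    # Call sites print as e.g. "%1 = call @f(".
    calls = re.findall(rf"\bcall\s+@{sym}\(", mlir)
    return len(definitions), len(calls)


# Excerpt copied from the printed MLIR of the qjit example.
SAMPLE_MLIR = """
%1 = call @f(%0, %arg0, %c_0) : (!quantum.reg, tensor<f64>, tensor<i64>) -> !quantum.reg
%2 = call @f(%1, %arg0, %c) : (!quantum.reg, tensor<f64>, tensor<i64>) -> !quantum.reg
func.func private @f(%arg0: !quantum.reg, %arg1: tensor<f64>, %arg2: tensor<i64>) -> !quantum.reg attributes {llvm.linkage = #llvm.linkage<internal>} {
  %extracted = tensor.extract %arg2[] : tensor<i64>
  %0 = quantum.extract %arg0[%extracted] : !quantum.reg -> !quantum.bit
  %extracted_0 = tensor.extract %arg1[] : tensor<f64>
  %out_qubits = quantum.custom "RX"(%extracted_0) %0 : !quantum.bit
  %extracted_1 = tensor.extract %arg2[] : tensor<i64>
  %1 = quantum.insert %arg0[%extracted_1], %out_qubits : !quantum.reg, !quantum.bit
  return %1 : !quantum.reg
}
"""

print(subroutine_usage(SAMPLE_MLIR, "f"))  # (1, 2)
```

Assuming qjit_c.mlir holds the module as a string (the example slices it like one), subroutine_usage(qjit_c.mlir, "f") would be expected to report one definition and two calls without relying on hard-coded offsets.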