qml.vjp

vjp(f, params, cotangents, method=None, h=None, argnums=None)

A qjit()-compatible vector-Jacobian product (VJP) of PennyLane programs.

This function allows the vector-Jacobian product of a hybrid quantum-classical function to be computed within the compiled program.
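For reference, the quantity computed is the standard VJP. Consider a function f: R^n -> R^m with Jacobian J_f(x) of shape m x n, and a cotangent v with the same shape as the output of f. The VJP contracts v against the output dimension of the Jacobian. This is the textbook definition. It does not describe how Catalyst evaluates it internally or how pytree inputs and outputs are flattened.

```latex
\mathrm{VJP}(f, x, v) = v^{\top} J_f(x),
\qquad
\bigl(v^{\top} J_f(x)\bigr)_j = \sum_{i=1}^{m} v_i \,\frac{\partial f_i(x)}{\partial x_j},
\quad j = 1, \dots, n.
```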

Warning

vjp is intended to be used with qjit() only.

Note

When used with qjit(), this function only supports the Catalyst compiler. See catalyst.vjp() for more details.

Please see the Catalyst quickstart guide, as well as the sharp bits and debugging tips page for an overview of the differences between Catalyst and PennyLane.

Parameters:
  • f (Callable) – Function-like object for which to compute the VJP.

  • params (Sequence[Pytree[Array]]) – List (or tuple) of arguments for f specifying the point at which to compute the VJP. A subset of these parameters is declared as differentiable by listing their indices in the argnums parameter.

  • cotangents (Pytree[Array]) – Cotangent values to use in the VJP. They should match the pytree structure of the function's output.

  • method (str) – Differentiation method to use, same as in grad().

  • h (float) – The step-size value for the finite-difference ("fd") method.

  • argnums (Union[int, List[int]]) – Indices of the entries of params to differentiate with respect to. If None (the default), only the first argument is differentiated.
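To make the roles of method="fd", h and cotangents concrete, below is a minimal NumPy sketch of a forward-difference VJP with respect to a single array argument, analogous to argnums=0. It is not Catalyst's implementation. The helper name fd_vjp and the forward-difference scheme are assumptions chosen for illustration.

```python
import numpy as np


def fd_vjp(f, x, cotangent, h=1e-7):
    """Hypothetical helper (not a PennyLane/Catalyst API): forward-difference VJP of f at x."""
    x = np.asarray(x, dtype=float)
    fx = np.asarray(f(x), dtype=float)
    # The cotangent must have the same shape as the output of f
    cotangent = np.asarray(cotangent, dtype=float)
    result = np.zeros_like(x)
    for j in range(x.size):
        # Perturb the j-th input entry by the step size h
        x_shift = x.copy()
        x_shift.flat[j] += h
        # Column j of the Jacobian, approximated by a forward difference
        column = (np.asarray(f(x_shift), dtype=float) - fx) / h
        # Contract with the cotangent: (v^T J)_j = sum_i v_i * J_ij
        result.flat[j] = np.sum(cotangent * column)
    # Mirror the "(results, dparams)" pairing used by qml.vjp
    return fx, result


y = np.array([2.0, 3.0])
out, vjp_x = fd_vjp(lambda x: x * y, np.array([1.0, 2.0]), np.array([10.0, 20.0]), h=1e-6)
```

Here f is linear in x, so the finite-difference estimate matches the exact VJP up to rounding. For nonlinear functions, the choice of h trades truncation error against floating-point cancellation.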

Returns:

Return values of f paired with the VJP values.

Return type:

Tuple[Array]

See also

grad(), jvp(), jacobian()

Note

Unlike jax.vjp, which has no argnums parameter and treats all params as trainable, this function treats only the first argument as trainable by default.
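The sketch below illustrates this difference using plain JAX outside of qjit. It does not call qml.vjp. jax.vjp returns one cotangent per positional argument. Closing over the remaining arguments reproduces the first-argument-only default described above.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    return x * y


x = jnp.array([1.0, 2.0])
y = jnp.array([2.0, 3.0])
dy = jnp.array([10.0, 20.0])

# jax.vjp differentiates with respect to every positional argument
out, vjp_fn = jax.vjp(f, x, y)
grad_x, grad_y = vjp_fn(dy)  # one cotangent per argument: dy * y and dy * x

# Restrict to the first argument by closing over y
out_x, vjp_fn_x = jax.vjp(lambda x_: f(x_, y), x)
(grad_x_only,) = vjp_fn_x(dy)  # the returned function always yields a tuple
```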

Example

import jax.numpy as jnp
import pennylane as qml

@qml.qjit(static_argnames="argnums")
def calculate_vjp_qjit(x, y, cotangent, argnums):
    def f(x, y):
        return x * y

    return qml.vjp(f, (x, y), cotangent, argnums=argnums)

>>> params = (jnp.array([1.0, 2.0]), jnp.array([2.0, 3.0]))
>>> dy = jnp.array([10.0, 20.0])
>>> results, dparams = calculate_vjp_qjit(*params, dy, 0)
>>> results
Array([2., 6.], dtype=float64)
>>> dparams
Array([20., 60.], dtype=float64)
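These values can be checked by hand. f multiplies its inputs elementwise and argnums=0 selects x, so the Jacobian with respect to x is diagonal. The results are x ⊙ y = (1·2, 2·3) = (2, 6), and the VJP is:

```latex
J_x = \frac{\partial (x \odot y)}{\partial x} = \operatorname{diag}(y) = \operatorname{diag}(2, 3),
\qquad
dy^{\top} J_x = (10 \cdot 2,\; 20 \cdot 3) = (20,\; 60).
```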

As with grad and jacobian, if argnums is a sequence (for example, a tuple), dparams is returned as a tuple with one entry per listed argument. When argnums is an integer, this extra level of nesting is squeezed out:

>>> calculate_vjp_qjit(*params, dy, (0,))[1]
(Array([20., 60.], dtype=float64),)