# Source code for pennylane._grad.jvp

# Copyright 2018-2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Defines qml.jvp
"""
from collections.abc import Sequence
from functools import lru_cache
from importlib.util import find_spec

from pennylane import capture
from pennylane.compiler import compiler
from pennylane.exceptions import CompileError

from .grad import _args_and_argnums, _setup_h, _setup_method

has_jax = find_spec("jax") is not None


def _get_shape(x):
    import jax  # pylint: disable=import-outside-toplevel

    return getattr(x, "shape", jax.numpy.shape(x))


# pylint: disable=unused-argument
@lru_cache
def _get_jvp_prim():
    """Build the ``jvp`` capture primitive; the result is memoised by ``lru_cache``.

    Returns:
        The configured ``capture.QmlPrimitive`` named ``"jvp"``, or ``None`` when jax
        is not installed.
    """
    if not has_jax:  # pragma: no cover
        return None

    import jax  # pylint: disable=import-outside-toplevel

    primitive = capture.QmlPrimitive("jvp")
    primitive.multiple_results = True
    primitive.prim_type = "higher_order"

    @primitive.def_impl
    def _jvp_impl(*args, jaxpr, fn, method, h, argnums):
        # ``args`` holds one primal per jaxpr input, followed by the tangents of the
        # inputs whose positions are listed in ``argnums``.
        num_primals = len(jaxpr.invars)
        primals = list(args[:num_primals])
        tangents = list(args[num_primals:])

        # Give every non-differentiated input a zero tangent at its own position so
        # that primals and tangents line up one-to-one.
        for position, primal in enumerate(primals):
            if position not in argnums:
                tangents.insert(position, 0 * primal)

        def evaluate(*inner_args):
            return jax.core.eval_jaxpr(jaxpr, [], *inner_args)

        primal_out, tangent_out = jax.jvp(evaluate, primals, tangents)
        return tuple(primal_out) + tuple(tangent_out)

    @primitive.def_abstract_eval
    def _jvp_abstract_eval(*args, jaxpr, fn, method, h, argnums):
        # The outputs are the primal results followed by equally typed tangent results.
        out_avals = [outvar.aval for outvar in jaxpr.outvars]
        return out_avals + out_avals

    return primitive


def _validate_tangents(params, dparams, argnums):
    from jax._src.api import _dtype  # pylint: disable=import-outside-toplevel

    if len(dparams) != len(argnums):
        raise TypeError(
            "number of tangents and number of differentiable parameters in qml.jvp do not "
            "match; the number of parameters must be equal. "
            f"Got {len(argnums)} differentiable parameters and so expected "
            f"as many tangents, but got {len(dparams)} instead."
        )

    for i, dx in zip(argnums, dparams):
        x = params[i]
        if _dtype(x) != _dtype(dx):
            raise TypeError(
                "function params and tangents arguments to qml.jvp do not match; "
                "dtypes must be equal. "
                f"Got function params dtype {_dtype(x)} and expected tangent dtype "
                f"to match, but got tangent dtype {_dtype(dx)} instead."
            )

        if _get_shape(x) != _get_shape(dx):
            raise ValueError(
                "qml.jvp called with different function params and tangent "
                f"shapes; got function params shape {_get_shape(x)} and tangent shape "
                f"{_get_shape(dx)}"
            )


# pylint: disable=too-many-arguments
def _capture_jvp(func, params, dparams, *, argnums=None, method=None, h=None):
    import jax  # pylint: disable=import-outside-toplevel
    from jax.tree_util import tree_leaves, tree_unflatten  # pylint: disable=import-outside-toplevel

    if not isinstance(params, Sequence):
        raise ValueError(f"params must be a Sequence in qml.jvp. Got type {type(params)}.")
    if not isinstance(dparams, Sequence):
        raise ValueError(f"tangents must be a Sequence in qml.jvp. Got type {type(params)}.")

    h = _setup_h(h)
    method = _setup_method(method)
    flat_args, flat_argnums, _, _ = _args_and_argnums(params, argnums)
    flat_dargs = tree_leaves(dparams)

    _validate_tangents(flat_args, flat_dargs, flat_argnums)

    flat_fn = capture.FlatFn(func)
    jaxpr = jax.make_jaxpr(flat_fn)(*params)
    j = jaxpr.jaxpr
    no_consts_jaxpr = j.replace(constvars=(), invars=j.constvars + j.invars)
    shifted_argnums = tuple(i + len(jaxpr.consts) for i in flat_argnums)

    prim_kwargs = {
        "fn": func,
        "method": method,
        "h": h,
        "argnums": shifted_argnums,
        "jaxpr": no_consts_jaxpr,
    }
    out_flat = _get_jvp_prim().bind(*jaxpr.consts, *flat_args, *flat_dargs, **prim_kwargs)
    flat_results, flat_dresults = out_flat[: len(j.outvars)], out_flat[len(j.outvars) :]

    results = tree_unflatten(flat_fn.out_tree, flat_results)
    dresults = tree_unflatten(flat_fn.out_tree, flat_dresults)
    return results, dresults


# pylint: disable=too-many-arguments, too-many-positional-arguments
def jvp(f, params, tangents, method=None, h=None, argnums=None):
    """A :func:`~.qjit` compatible Jacobian-vector product of PennyLane programs.

    This function allows the Jacobian-vector Product of a hybrid quantum-classical function to be
    computed within the compiled program.

    .. warning::

        ``jvp`` is intended to be used with :func:`~.qjit` only.

    .. note::

        When used with :func:`~.qjit`, this function only supports the Catalyst compiler;
        see :func:`catalyst.jvp` for more details.

        Please see the Catalyst :doc:`quickstart guide <catalyst:dev/quick_start>`,
        as well as the :doc:`sharp bits and debugging tips <catalyst:dev/sharp_bits>`
        page for an overview of the differences between Catalyst and PennyLane.

    Args:
        f (Callable): Function-like object to calculate JVP for
        params (List[Array]): List (or a tuple) of the function arguments specifying the point
            to calculate JVP at. A subset of these parameters are declared as
            differentiable by listing their indices in the ``argnums`` parameter.
        tangents (List[Array]): List (or a tuple) of tangent values to use in JVP. The list size
            and shapes must match the ones of differentiable params.
        method (str): Differentiation method to use, same as in :func:`~.grad`.
        h (float): the step-size value for the finite-difference (``"fd"``) method
        argnums (Union[int, List[int]]): the params' indices to differentiate.

    Returns:
        Tuple[Array]: Return values of ``f`` paired with the JVP values.

    Raises:
        TypeError: invalid parameter types
        ValueError: invalid parameter values

    .. seealso:: :func:`~.grad`, :func:`~.vjp`, :func:`~.jacobian`

    **Example 1 (basic usage)**

    .. code-block:: python

        @qml.qjit
        def jvp(params, tangent):
            def f(x):
                y = [jnp.sin(x[0]), x[1] ** 2, x[0] * x[1]]
                return jnp.stack(y)

            return qml.jvp(f, [params], [tangent])

    >>> x = jnp.array([0.1, 0.2])
    >>> tangent = jnp.array([0.3, 0.6])
    >>> jvp(x, tangent)
    (Array([0.09983342, 0.04      , 0.02      ], dtype=float64),
     Array([0.29850125, 0.24      , 0.12      ], dtype=float64))

    **Example 2 (argnums usage)**

    Here we show how to use ``argnums`` to ignore the non-differentiable parameter ``n`` of the
    target function. Note that the length and shapes of tangents must match the length and shape
    of primal parameters, which we mark as differentiable by passing their indices to ``argnums``.

    .. code-block:: python

        @qml.qjit
        @qml.qnode(qml.device("lightning.qubit", wires=2))
        def circuit(n, params):
            qml.RX(params[n, 0], wires=n)
            qml.RY(params[n, 1], wires=n)
            return qml.expval(qml.Z(1))

        @qml.qjit
        def workflow(primals, tangents):
            return qml.jvp(circuit, [1, primals], [tangents], argnums=[1])

    >>> params = jnp.array([[0.54, 0.3154], [0.654, 0.123]])
    >>> dy = jnp.array([[1.0, 1.0], [1.0, 1.0]])
    >>> workflow(params, dy)
    (Array(0.78766064, dtype=float64), Array(-0.70114352, dtype=float64))
    """
    # Program capture takes precedence: bind the jvp primitive directly.
    if capture.enabled():
        return _capture_jvp(f, params, tangents, method=method, h=h, argnums=argnums)

    # Otherwise dispatch to the ``jvp`` exposed by the active compiler's "ops" entry point.
    if active_jit := compiler.active_compiler():
        available_eps = compiler.AvailableCompilers.names_entrypoints
        ops_loader = available_eps[active_jit]["ops"].load()
        return ops_loader.jvp(f, params, tangents, method=method, h=h, argnums=argnums)

    raise CompileError("PennyLane does not support the JVP function without QJIT.")