# Source code for pennylane.transforms.core.compile_pipeline

# Copyright 2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the ``CompilePipeline`` class.
"""

from __future__ import annotations

import warnings
from collections import defaultdict
from collections.abc import Sequence
from copy import copy
from enum import StrEnum
from functools import partial
from typing import TYPE_CHECKING, Any, overload

from pennylane.exceptions import TransformError
from pennylane.tape import QuantumScript, QuantumScriptBatch
from pennylane.typing import BatchPostprocessingFn, PostprocessingFn, ResultBatch

from .cotransform_cache import CotransformCache
from .transform import BoundTransform, Transform

if TYPE_CHECKING:
    import jax

    from pennylane.workflow import QNode


class ProtectedLevel(StrEnum):
    """Enum for the protected levels of inspection.

    These string values are reserved level names: ``CompilePipeline.add_marker``
    rejects any user marker label that collides with one of them, so that
    user-defined markers can never shadow a built-in inspection level.
    """

    # NOTE(review): the precise semantics of each level are defined by the
    # consumers of this enum elsewhere in the package — confirm before relying
    # on descriptions beyond the reserved-name role established here.
    TOP = "top"
    USER = "user"
    GRADIENT = "gradient"
    DEVICE = "device"
    ALL = "all"
    ALL_MLIR = "all-mlir"


def _batch_postprocessing(
    results: ResultBatch,
    individual_fns: list[PostprocessingFn],
    slices: list[slice] | list[int],
) -> ResultBatch:
    """Broadcast individual post processing functions onto their respective tapes.

    Args:
        results (ResultBatch): The numeric outcome from executing a batch of :class:`~.QuantumTape`

    Keyword Args:
        individual_fns (List[Callable]): postprocessing functions converting a batch of results into a single result
            corresponding to only a single :class:`~.QuantumTape`.
        slices (List[slice]): the indices for the results that correspond to each individual post processing function.

    >>> results = (1.0, 2.0, 3.0, 4.0)
    >>> def postprocessing1(results):
    ...     return results[0] + results[1]
    >>> def postprocessing2(results):
    ...     return results[0]+0.5
    >>> def postprocessing3(results):
    ...     return results[0]*2
    >>> slices = [slice(0,2), slice(2,3), slice(3,4)]
    >>> individual_fns = [postprocessing1, postprocessing2, postprocessing3]
    >>> _batch_postprocessing(results, individual_fns, slices)
    (3.0, 3.5, 8.0)

    """
    return tuple(fn(results[sl]) for fn, sl in zip(individual_fns, slices, strict=True))


def _apply_postprocessing_stack(
    results: ResultBatch,
    postprocessing_stack: list[BatchPostprocessingFn],
) -> ResultBatch:
    """Applies the postprocessing and cotransform postprocessing functions in a Last-In-First-Out LIFO manner.

    Args:
        results (ResultBatch): The numeric outcome from executing a batch of :class:`~.QuantumTape`

    Keyword Args:
        postprocessing_stack (List(BatchPostProcessingFn)): a LIFO stack of post processing functions.

    Returns:
        ResultBatch: the post processed results.

    >>> results = (1.0, 2.0, 3.0, 4.0)
    >>> def postprocessing1(results):
    ...     return (results[0] + results[1], results[2] + results[3])
    >>> def postprocessing2(results):
    ...     return (results[0] + 1, results[1] + 2)
    >>> _apply_postprocessing_stack(results, [postprocessing1])
    (3.0, 7.0)
    >>> _apply_postprocessing_stack(results, [postprocessing2, postprocessing1])
    (4.0, 9.0)

    """
    for postprocessing in reversed(postprocessing_stack):
        results = postprocessing(results)
    return results


def null_postprocessing(results: ResultBatch) -> ResultBatch:
    """Identity postprocessing: return the results batch unchanged.

    Used as the postprocessing function when a pipeline contains no transforms.

    Args:
        results (ResultBatch): Results from executing a batch of :class:`~.QuantumTape`.

    Returns:
        ResultBatch: the input to the function.

    """
    return results


class CompilePipeline:
    """A sequence of transforms to be applied to a quantum function or a :class:`~pennylane.QNode`.

    Args:
        *transforms (Optional[Sequence[Transform | BoundTransform]]): A sequence of transforms
            with which to initialize the program.
        cotransform_cache (Optional[CotransformCache]): A named tuple containing the ``qnode``,
            ``args``, and ``kwargs`` required to compute classical cotransforms.

    **Example:**

    Chain together multiple transforms to create a custom circuit-optimization
    pipeline, then apply it to a quantum function or QNode as a decorator:

    .. code-block:: python

        pipeline = qml.CompilePipeline(
            qml.transforms.commute_controlled,
            qml.transforms.cancel_inverses(recursive=True),
            qml.transforms.merge_rotations,
        )

        # Add a marker for inspectibility
        pipeline.add_marker("no-transforms", 0)

        @pipeline
        @qml.qnode(qml.device("default.qubit"))
        def circuit(x, y):
            ...

    Pipelines compose like Python sequences: transforms and pipelines can be
    combined with ``+``, repeated with ``*``, sliced with ``[...]``, and modified
    in place with ``insert``, ``append``, ``extend`` and ``pop``.

    Markers act as named checkpoints between transforms. A marker's *level* is the
    number of transforms that precede it (level 0 sits before the first transform).
    Markers are added with ``add_marker``, removed with ``remove_marker``, queried
    with ``get_marker_level``, and are preserved and re-indexed under pipeline
    manipulations such as addition, multiplication, slicing and popping.
    """

    @overload
    def __init__(
        self,
        transforms: Sequence[Transform | BoundTransform],
        /,
        *,
        cotransform_cache: CotransformCache | None = None,
    ): ...
    @overload
    def __init__(
        self,
        *transforms: CompilePipeline | BoundTransform | Transform,
        cotransform_cache: CotransformCache | None = None,
    ): ...
    def __init__(
        self,
        *transforms: (
            CompilePipeline | BoundTransform | Transform | Sequence[Transform | BoundTransform]
        ),
        cotransform_cache: CotransformCache | None = None,
    ):
        # A single sequence argument is treated as the full list of transforms.
        if len(transforms) == 1 and isinstance(transforms[0], Sequence):
            transforms = list(transforms[0])

        # Fast path: every element is already a BoundTransform, so no expansion is
        # required. Copy into a fresh list: the previous code stored the raw
        # ``*args`` tuple directly, which broke all later in-place mutation
        # (append/extend/pop/insert raise on a tuple) — e.g. ``CompilePipeline()``
        # stored ``()`` and could never be extended.
        if all(isinstance(t, BoundTransform) for t in transforms):
            self._compile_pipeline = list(transforms)
            self._markers: dict[str, int] = {}
            self.cotransform_cache = cotransform_cache
            return

        self._compile_pipeline = []
        self._markers: dict[str, int] = {}
        self.cotransform_cache = cotransform_cache
        for obj in transforms:
            if not isinstance(obj, (CompilePipeline, BoundTransform, Transform)):
                raise TypeError(
                    "CompilePipeline can only be constructed with a series of transforms "
                    "or compile pipelines, or with a single list of transforms."
                )
            # __iadd__ handles expand transforms, markers, and cotransform caches.
            self += obj

    def __copy__(self):
        """Shallow copy: the transform list and markers are copied, the bound
        transforms themselves (and the cotransform cache) are shared."""
        new_pipeline = CompilePipeline(
            self._compile_pipeline[:], cotransform_cache=self.cotransform_cache
        )
        new_pipeline._markers = self._markers.copy()
        return new_pipeline

    def __iter__(self):
        """list[BoundTransform]: Return an iterator to the underlying compile pipeline."""
        return iter(self._compile_pipeline)

    def __len__(self) -> int:
        """int: Return the number transforms in the program."""
        return len(self._compile_pipeline)

    @overload
    def __getitem__(self, idx: int) -> BoundTransform: ...
    @overload
    def __getitem__(self, idx: slice) -> CompilePipeline: ...
    def __getitem__(self, idx):
        """(BoundTransform, List[BoundTransform]): Return the indexed transform
        container from the underlying compile pipeline."""
        if isinstance(idx, slice):
            start, stop, step = idx.indices(len(self._compile_pipeline))
            new_markers = {}
            if step == 1:
                # Keep the markers that fall inside the slice, re-indexed to the
                # new origin. A marker at the very end (level == len) is only
                # retained when the slice actually reaches the end.
                slice_reaches_end = stop == len(self._compile_pipeline)
                upper = stop + 1 if slice_reaches_end else stop
                new_markers = {
                    k: v - start for k, v in self._markers.items() if start <= v < upper
                }
            elif self._markers:
                warnings.warn(
                    "Slicing a CompilePipeline that contains markers with a step size other than 1 is not supported. "
                    "Markers have been dropped from the result. ",
                    UserWarning,
                )
            compile_pipeline = CompilePipeline(self._compile_pipeline[idx])
            compile_pipeline._markers = new_markers
            return compile_pipeline
        return self._compile_pipeline[idx]

    def __bool__(self) -> bool:
        """True when the pipeline contains at least one transform."""
        return bool(self._compile_pipeline)

    def __add__(self, other: CompilePipeline | BoundTransform | Transform) -> CompilePipeline:
        """Concatenate with another transform or pipeline, producing a new pipeline.

        Markers from both operands are preserved (the right operand's levels are
        shifted past this pipeline's transforms). At most one operand may carry a
        cotransform cache, and at most one may hold a terminal transform.
        """
        # Normalize a bare Transform into a BoundTransform.
        if isinstance(other, Transform):
            other = BoundTransform(other)
        # Normalize a BoundTransform (plus its expand transform) into a pipeline.
        if isinstance(other, BoundTransform):
            transforms = [other]
            if expand_transform := other.expand_transform:
                transforms.insert(0, expand_transform)
            other = CompilePipeline(transforms)
        if isinstance(other, CompilePipeline):
            offset = len(self._compile_pipeline)
            new_markers = self._markers.copy()
            for name, pos in other._markers.items():
                new_markers[name] = pos + offset
            if self.has_final_transform and other.has_final_transform:
                raise TransformError("The compile pipeline already has a terminal transform.")
            transforms = self._compile_pipeline[:]
            transforms.extend(other._compile_pipeline)
            cotransform_cache = None
            if self.cotransform_cache:
                if other.cotransform_cache:
                    raise ValueError("Cannot add two compile pipelines with cotransform caches.")
                cotransform_cache = self.cotransform_cache
            elif other.cotransform_cache:
                cotransform_cache = other.cotransform_cache
            new_pipeline = CompilePipeline(transforms, cotransform_cache=cotransform_cache)
            new_pipeline._markers = new_markers
            return new_pipeline
        return NotImplemented

    def __radd__(self, other: BoundTransform | Transform) -> CompilePipeline:
        """Right addition to prepend a transform to the program.

        Args:
            other: A BoundTransform or Transform to prepend.

        Returns:
            CompilePipeline: A new program with the transform prepended.
        """
        if isinstance(other, Transform):
            other = BoundTransform(other)
        if not isinstance(other, BoundTransform):
            return NotImplemented
        if self.has_final_transform and other.is_final_transform:
            raise TransformError("The compile pipeline already has a terminal transform.")
        prefix = [other]
        if expand_transform := other.expand_transform:
            prefix.insert(0, expand_transform)
        new_pipeline = CompilePipeline(
            prefix + self._compile_pipeline, cotransform_cache=self.cotransform_cache
        )
        # Consistency fix: previously this pipeline's markers were silently
        # dropped, unlike __add__. Shift them past the prepended transforms.
        new_pipeline._markers = {name: pos + len(prefix) for name, pos in self._markers.items()}
        return new_pipeline

    def __iadd__(self, other: CompilePipeline | BoundTransform | Transform) -> CompilePipeline:
        """In-place addition to append a transform to the program.

        Args:
            other: A BoundTransform, Transform, or CompilePipeline to append.

        Returns:
            CompilePipeline: This program with the transform(s) appended.
        """
        if isinstance(other, Transform):
            other = BoundTransform(other)
        if isinstance(other, BoundTransform):
            transforms = [other]
            if expand_transform := other.expand_transform:
                transforms.insert(0, expand_transform)
            other = CompilePipeline(transforms)
        if isinstance(other, CompilePipeline):
            offset = len(self._compile_pipeline)
            for name, pos in other._markers.items():
                self._markers[name] = pos + offset
            if self.has_final_transform and other.has_final_transform:
                raise TransformError("The compile pipeline already has a terminal transform.")
            self._compile_pipeline.extend(other._compile_pipeline)
            if other.cotransform_cache:
                if self.cotransform_cache:
                    raise ValueError("Cannot add two compile pipelines with cotransform caches.")
                self.cotransform_cache = other.cotransform_cache
            return self
        return NotImplemented

    def __mul__(self, n: int) -> CompilePipeline:
        """Right multiplication to repeat a program n times.

        Args:
            n (int): Number of times to repeat this program.

        Returns:
            CompilePipeline: A new program with this program repeated n times.
        """
        if not isinstance(n, int):
            return NotImplemented
        if n < 0:
            raise ValueError("Cannot multiply compile pipeline by negative integer")
        if self.has_final_transform:
            raise TransformError("Cannot multiply a compile pipeline that has a terminal transform.")
        transforms = self._compile_pipeline * n
        new_pipeline = CompilePipeline(transforms, cotransform_cache=self.cotransform_cache)
        # Markers are copied once, not duplicated per repetition.
        new_pipeline._markers = self._markers.copy()
        return new_pipeline

    __rmul__ = __mul__

    def _ipython_display_(self):
        """Displays __str__ in ipython instead of __repr__"""
        # See https://ipython.readthedocs.io/en/stable/config/integrating.html#custom-methods
        print(str(self))

    def __str__(self) -> str:
        """Returns a user friendly representation of the compile pipeline."""

        def truncate(val: Any) -> str:
            # Long argument reprs are elided so the listing stays readable.
            _CHAR_THRESHOLD = 50
            val_str = str(val)
            return "..." if len(val_str) > _CHAR_THRESHOLD else val_str

        lines = []
        # Invert the marker map: level -> list of labels at that level.
        inv_marker_map = defaultdict(list)
        for label, index in self._markers.items():
            inv_marker_map[index].append(label)
        if marker := inv_marker_map.get(0):
            symbol = "├─▶" if self else "└─▶"
            lines.append(f" {symbol} {', '.join(marker)}")
        for i, transform in enumerate(self):
            args_str = ", ".join(truncate(a) for a in transform.args)
            kwargs_str = ", ".join(f"{k}={truncate(v)}" for k, v in transform.kwargs.items())
            sep = ", " if args_str and kwargs_str else ""
            transform_str = f"{transform.tape_transform.__name__}({args_str}{sep}{kwargs_str})"
            lines.append(f" [{i + 1}] {transform_str}" + "," * bool(i != len(self) - 1))
            if marker := inv_marker_map.get(i + 1):
                symbol = "├─▶" if i != len(self) - 1 else "└─▶"
                lines.append(f" {symbol} {', '.join(marker)}")
        if lines:
            contents = "\n".join(lines)
            return f"CompilePipeline(\n{contents}\n)"
        return "CompilePipeline()"

    def __repr__(self) -> str:
        """The detailed string representation of the compile pipeline."""
        lines = []
        inv_marker_map = defaultdict(list)
        for label, index in self._markers.items():
            inv_marker_map[index].append(label)
        if marker := inv_marker_map.get(0):
            symbol = "├─▶" if self else "└─▶"
            lines.append(f" {symbol} {', '.join(marker)}")
        for i, transform in enumerate(self):
            lines.append(f" [{i + 1}] {repr(transform)}" + "," * bool(i != len(self) - 1))
            if marker := inv_marker_map.get(i + 1):
                symbol = "├─▶" if i != len(self) - 1 else "└─▶"
                lines.append(f" {symbol} {', '.join(marker)}")
        if lines:
            contents = "\n".join(lines)
            return f"CompilePipeline(\n{contents}\n)"
        return "CompilePipeline()"

    def __eq__(self, other) -> bool:
        # NOTE: the cotransform cache is deliberately excluded from equality.
        if not isinstance(other, CompilePipeline):
            return False
        return self._compile_pipeline == other._compile_pipeline and (
            self._markers == other._markers
        )

    def __contains__(self, obj) -> bool:
        """A BoundTransform matches by exact equality; a Transform matches any
        container wrapping the same tape transform."""
        if isinstance(obj, BoundTransform):
            return obj in self._compile_pipeline
        if isinstance(obj, Transform):
            return any(obj.tape_transform == t.tape_transform for t in self)
        return False
[docs] def remove(self, obj: BoundTransform | Transform): """In place remove the input containers, specifically, 1. if the input is a Transform, remove all containers matching the transform; 2. if the input is a BoundTransform, remove all containers exactly matching the input. Args: obj (BoundTransform or Transform): The object to remove from the program. """ if not isinstance(obj, (Transform, BoundTransform)): raise TypeError("Only BoundTransform or Transform can be removed.") i = len(self) - 1 while i >= 0: # pylint: disable=protected-access if (isinstance(obj, Transform) and obj == self[i]._transform) or self[i] == obj: removed = self._compile_pipeline.pop(i) self._markers = {k: (v - 1 if v > i else v) for k, v in self._markers.items()} # Remove the associated expand_transform if present if i > 0 and removed.expand_transform == self[i - 1]: self._compile_pipeline.pop(i - 1) self._markers = { k: (v - 1 if v > i - 1 else v) for k, v in self._markers.items() } i -= 1 i -= 1
[docs] def append(self, transform: BoundTransform | Transform): """Add a transform to the end of the program. Args: transform (Transform or BoundTransform): A transform represented by its container. """ if isinstance(transform, (list, tuple, CompilePipeline)): raise TypeError( "append() expects a single transform, not a sequence. " "Use extend() to add multiple transforms at once." ) if not isinstance(transform, BoundTransform): transform = BoundTransform(transform) # Program can only contain one informative transform and at the end of the program if self.has_final_transform and transform.is_final_transform: raise TransformError("The compile pipeline already has a terminal transform.") if expand_transform := transform.expand_transform: self._compile_pipeline.append(expand_transform) self._compile_pipeline.append(transform)
[docs] def extend(self, transforms: CompilePipeline | Sequence[BoundTransform | Transform]): """Extend the pipeline by appending transforms from an iterable. Args: transforms (CompilePipeline, or Sequence[BoundTransform | Transform]): A CompilePipeline or an iterable of transforms to append. """ # Handle CompilePipeline by using __iadd__ which already handles this case if isinstance(transforms, CompilePipeline): self += transforms return # Handle iterables (list, tuple, etc.) for t in transforms: self += t
[docs] def remove_marker(self, label: str) -> None: """Remove the marker corresponding to the label.""" if label not in self._markers: raise ValueError(f"No marker found with label '{label}'.") del self._markers[label]
[docs] def add_marker(self, label: str, level: int | None = None) -> None: """Add a marker to the compilation pipeline at a given level. Args: label (str): The label for the marker. level (int | None): The level position for the marker. If ``None``, the marker will be append to the end of the compilation pipeline. Raises: ValueError: If the label corresponds to a protected level, if the label is not unique, or if the level is out of bounds. **Example:** >>> pipeline = CompilePipeline() >>> pipeline += qml.transforms.merge_rotations >>> pipeline.add_marker("after-merge-rotations") >>> print(pipeline) CompilePipeline( [1] merge_rotations() └─▶ after-merge-rotations ) >>> pipeline.add_marker("no-transforms", 0) >>> print(pipeline) CompilePipeline( ├─▶ no-transforms [1] merge_rotations() └─▶ after-merge-rotations ) """ if not isinstance(label, str): raise ValueError("Marker label must be a string.") if label in (protected_levels := [level.value for level in ProtectedLevel]): raise ValueError( f"Found marker for protected level '{label}'. Protected options are {', '.join(protected_levels)}." ) if label in self._markers: raise ValueError(f"Found multiple markers for level '{label}'. Markers must be unique.") if level is None: level = len(self._compile_pipeline) elif not isinstance(level, int) or (level < 0 or level > len(self._compile_pipeline)): raise ValueError( f"Marker level must be an integer between 0 and the number of transforms in the pipeline ({len(self._compile_pipeline)}), inclusive." ) self._markers[label] = level
@property def markers(self) -> list[str]: """Retrieve list of markers present in the compiliation pipeline.""" return list(self._markers.keys())
[docs] def get_marker_level(self, label: str) -> int | None: """Retrieve the level corresponding to a marker label.""" return self._markers.get(label)
[docs] def add_transform(self, transform: Transform, *targs, **tkwargs): """Add a transform to the end of the program. Note that this should be a function decorated with/called by ``qml.transform``, and not a ``BoundTransform``. Args: transform (Transform): The transform to add to the compile pipeline. *targs: Any additional arguments that are passed to the transform. Keyword Args: **tkwargs: Any additional keyword arguments that are passed to the transform. """ if not isinstance(transform, Transform): raise TransformError("Only transforms can be added to the compile pipeline.") self.append(BoundTransform(transform, args=targs, kwargs=tkwargs))
[docs] def insert(self, index: int, transform: Transform | BoundTransform): """Insert a transform at a given index. Args: index (int): The index to insert the transform. transform (transform or BoundTransform): the transform to insert """ self._markers = {k: (v + 1 if v >= index else v) for k, v in self._markers.items()} if not isinstance(transform, BoundTransform): transform = BoundTransform(transform) # Program can only contain one informative transform and at the end of the program if self and transform.is_final_transform: raise TransformError("Terminal transform can only be added to the end of the pipeline.") self._compile_pipeline.insert(index, transform) if expand_transform := transform.expand_transform: self._compile_pipeline.insert(index, expand_transform)
[docs] def pop(self, index: int = -1) -> BoundTransform: """Pop the transform container at a given index of the program. Args: index (int): the index of the transform to remove. Returns: BoundTransform: The removed transform. """ transform = self._compile_pipeline.pop(index) # Adjust index to be non-negative index = index if index >= 0 else len(self) + index + 1 # Decrement relevant markers self._markers = {k: (v - 1 if v > index else v) for k, v in self._markers.items()} if index > 0 and transform.expand_transform == self._compile_pipeline[index - 1]: self._compile_pipeline.pop(index - 1) # Decrement relevant markers self._markers = {k: (v - 1 if v > index - 1 else v) for k, v in self._markers.items()} return transform
@property def is_informative(self) -> bool: """``True`` if the compile pipeline is informative. Returns: bool: Boolean """ return any(t.is_informative for t in self) @property def has_final_transform(self) -> bool: """``True`` if the compile pipeline has a terminal transform.""" return any(t.is_final_transform for t in self)
[docs] def has_classical_cotransform(self) -> bool: """Check if the compile pipeline has some classical cotransforms. Returns: bool: Boolean """ return any(t.classical_cotransform is not None for t in self)
[docs] def set_classical_component(self, qnode, args, kwargs): """Set the classical jacobians and argnums if the transform is hybrid with a classical cotransform.""" # pylint: disable=no-member if self.has_classical_cotransform() and self[-1].kwargs.get("hybrid", True): self.cotransform_cache = CotransformCache(qnode, args, kwargs)
def __call_tapes( self, tapes: QuantumScript | QuantumScriptBatch ) -> tuple[QuantumScriptBatch, BatchPostprocessingFn]: if not self: return tapes, null_postprocessing if isinstance(tapes, QuantumScript): tapes = (tapes,) processing_fns_stack = [] for bound_transform in self: transform, targs, tkwargs, cotransform, _, _, _ = bound_transform tkwargs = { key: value for key, value in tkwargs.items() if key not in {"argnums", "hybrid"} } execution_tapes, fns, slices, classical_fns = [], [], [], [] start = 0 argnums = ( self.cotransform_cache.get_argnums(bound_transform) if self.cotransform_cache else None ) classical_jacobians = [] for tape_idx, tape in enumerate(tapes): if argnums is not None: tape.trainable_params = argnums[tape_idx] if transform is None: raise NotImplementedError( f"transform {bound_transform} has no defined tape transform." ) new_tapes, fn = transform(tape, *targs, **tkwargs) execution_tapes.extend(new_tapes) fns.append(fn) end = start + len(new_tapes) slices.append(slice(start, end)) start = end jac = ( self.cotransform_cache.get_classical_jacobian(bound_transform, tape_idx) if self.cotransform_cache else None ) classical_jacobians.append(jac) if cotransform and classical_jacobians[-1] is not None: classical_fns.append( partial(cotransform, cjac=classical_jacobians[-1], tape=tape) ) if cotransform and classical_fns: slices_classical = list(range(len(tapes))) batch_postprocessing_classical = partial( _batch_postprocessing, individual_fns=classical_fns, slices=slices_classical ) batch_postprocessing_classical.__doc__ = _batch_postprocessing.__doc__ processing_fns_stack.append(batch_postprocessing_classical) batch_postprocessing = partial(_batch_postprocessing, individual_fns=fns, slices=slices) batch_postprocessing.__doc__ = _batch_postprocessing.__doc__ processing_fns_stack.append(batch_postprocessing) # set input tapes for next iteration. 
tapes = execution_tapes postprocessing_fn = partial( _apply_postprocessing_stack, postprocessing_stack=processing_fns_stack, ) postprocessing_fn.__doc__ = _apply_postprocessing_stack.__doc__ # Reset classical jacobians return tuple(tapes), postprocessing_fn def __call_jaxpr( self, jaxpr: jax.extend.core.Jaxpr, consts: Sequence, *args ) -> jax.extend.core.ClosedJaxpr: # pylint: disable=import-outside-toplevel import jax cur_jaxpr = jax.extend.core.ClosedJaxpr(jaxpr, consts) for container in self: _, targs, tkwargs, _, plxpr_transform, _, _ = container cur_jaxpr = plxpr_transform(cur_jaxpr.jaxpr, cur_jaxpr.consts, targs, tkwargs, *args) return cur_jaxpr def __call_generic(self, obj): """Apply the transform program to a generic object (QNode, device, callable, etc.). This method chain-applies each transform using the generic dispatch system. Args: obj: The object to transform (QNode, device, callable, etc.). Returns: The transformed object. """ result = obj for container in self: result = container(result) # Merge markers, don't overwrite. # NOTE: Duck type a QNode object so we don't have to import it here if hasattr(result, "compile_pipeline"): offset = len(result.compile_pipeline) - len(self) # pylint: disable=protected-access for name, pos in self._markers.items(): result.compile_pipeline._markers[name] = pos + offset return result @overload def __call__( self, jaxpr: jax.extend.core.Jaxpr, consts: Sequence, *args ) -> jax.extend.core.ClosedJaxpr: ... @overload def __call__(self, qnode: QNode, *args, **kwargs) -> QNode: ... @overload def __call__( self, tape: QuantumScript | QuantumScriptBatch ) -> tuple[QuantumScriptBatch, BatchPostprocessingFn]: ... 
def __call__(self, *args, **kwargs): if type(args[0]).__name__ == "Jaxpr": return self.__call_jaxpr(*args, **kwargs) first_arg = args[0] # Sequence of QuantumScripts: QuantumScriptBatch if isinstance(first_arg, (QuantumScript, Sequence)): return self.__call_tapes(*args, **kwargs) # For any other object (QNode, device, callable, etc.), # chain-apply each transform using the generic dispatch system return self.__call_generic(first_arg)
@Transform.generic_register
def _apply_to_program(obj: CompilePipeline, transform, *targs, **tkwargs):
    """Generic-dispatch hook: applying a transform to a ``CompilePipeline``
    returns a copy of the pipeline with the bound transform appended."""
    targs, tkwargs = transform.setup_inputs(*targs, **tkwargs)
    new_program = copy(obj)
    new_program.append(BoundTransform(transform, args=targs, kwargs=tkwargs))
    return new_program


# Backwards-compatible alias for the pre-rename class name.
TransformProgram = CompilePipeline