Source code for pennylane.drawer.draw

# pylint: disable=too-many-arguments

# Copyright 2018-2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains the drawing function.
"""
from functools import wraps
from importlib.metadata import distribution
import warnings

import pennylane as qml
from .tape_mpl import tape_mpl
from .tape_text import tape_text


def catalyst_qjit(qnode):
    """The ``catalyst.while`` wrapper method"""
    try:
        distribution("pennylane_catalyst")
        return qnode.__class__.__name__ == "QJIT"
    except ImportError:
        return False


[docs]def draw( qnode, wire_order=None, show_all_wires=False, decimals=2, max_length=100, show_matrices=True, expansion_strategy=None, ): """Create a function that draws the given qnode or quantum function. Args: qnode (.QNode or Callable): the input QNode or quantum function that is to be drawn. wire_order (Sequence[Any]): the order (from top to bottom) to print the wires of the circuit. If not provided, the wire order defaults to the device wires. If device wires are not available, the circuit wires are sorted if possible. show_all_wires (bool): If True, all wires, including empty wires, are printed. decimals (int): How many decimal points to include when formatting operation parameters. ``None`` will omit parameters from operation labels. max_length (int): Maximum string width (columns) when printing the circuit show_matrices=False (bool): show matrix valued parameters below all circuit diagrams expansion_strategy (str): The strategy to use when circuit expansions or decompositions are required. Note that this is ignored if the input is not a QNode. - ``gradient``: The QNode will attempt to decompose the internal circuit such that all circuit operations are supported by the gradient method. - ``device``: The QNode will attempt to decompose the internal circuit such that all circuit operations are natively supported by the device. Returns: A function that has the same argument signature as ``qnode``. When called, the function will draw the QNode/qfunc. **Example** .. code-block:: python3 @qml.qnode(qml.device('lightning.qubit', wires=2)) def circuit(a, w): qml.Hadamard(0) qml.CRX(a, wires=[0, 1]) qml.Rot(*w, wires=[1], id="arbitrary") qml.CRX(-a, wires=[0, 1]) return qml.expval(qml.Z(0) @ qml.Z(1)) >>> print(qml.draw(circuit)(a=2.3, w=[1.2, 3.2, 0.7])) 0: ──H─╭●─────────────────────────────────────────╭●─────────┤ ╭<Z@Z> 1: ────╰RX(2.30)──Rot(1.20,3.20,0.70,"arbitrary")─╰RX(-2.30)─┤ ╰<Z@Z> .. details:: :title: Usage Details By specifying the ``decimals`` keyword, parameters are displayed to the specified precision. >>> print(qml.draw(circuit, decimals=4)(a=2.3, w=[1.2, 3.2, 0.7])) 0: ──H─╭●─────────────────────────────────────────────────╭●───────────┤ ╭<Z@Z> 1: ────╰RX(2.3000)──Rot(1.2000,3.2000,0.7000,"arbitrary")─╰RX(-2.3000)─┤ ╰<Z@Z> Parameters can be omitted by requesting ``decimals=None``: >>> print(qml.draw(circuit, decimals=None)(a=2.3, w=[1.2, 3.2, 0.7])) 0: ──H─╭●────────────────────╭●──┤ ╭<Z@Z> 1: ────╰RX──Rot("arbitrary")─╰RX─┤ ╰<Z@Z> If the parameters are not acted upon by classical processing like ``-a``, then ``qml.draw`` can handle string-valued parameters as well: >>> @qml.qnode(qml.device('lightning.qubit', wires=1)) ... def circuit2(x): ... qml.RX(x, wires=0) ... return qml.expval(qml.Z(0)) >>> print(qml.draw(circuit2)("x")) 0: ──RX(x)─┤ <Z> When requested with ``show_matrices=True`` (the default), matrix valued parameters are printed below the circuit. For ``show_matrices=False``, they are not printed: >>> @qml.qnode(qml.device('default.qubit', wires=2)) ... def circuit3(): ... qml.QubitUnitary(np.eye(2), wires=0) ... qml.QubitUnitary(-np.eye(4), wires=(0,1)) ... return qml.expval(qml.Hermitian(np.eye(2), wires=1)) >>> print(qml.draw(circuit3)()) 0: ──U(M0)─╭U(M1)─┤ 1: ────────╰U(M1)─┤ <𝓗(M0)> M0 = [[1. 0.] [0. 1.]] M1 = [[-1. -0. -0. -0.] [-0. -1. -0. -0.] [-0. -0. -1. -0.] [-0. -0. -0. -1.]] >>> print(qml.draw(circuit3, show_matrices=False)()) 0: ──U(M0)─╭U(M1)─┤ 1: ────────╰U(M1)─┤ <𝓗(M0)> The ``max_length`` keyword warps long circuits: .. code-block:: python rng = np.random.default_rng(seed=42) shape = qml.StronglyEntanglingLayers.shape(n_wires=3, n_layers=3) params = rng.random(shape) @qml.qnode(qml.device('lightning.qubit', wires=3)) def longer_circuit(params): qml.StronglyEntanglingLayers(params, wires=range(3)) return [qml.expval(qml.Z(i)) for i in range(3)] print(qml.draw(longer_circuit, max_length=60)(params)) .. code-block:: none 0: ──Rot(0.77,0.44,0.86)─╭●────╭X──Rot(0.45,0.37,0.93)─╭●─╭X 1: ──Rot(0.70,0.09,0.98)─╰X─╭●─│───Rot(0.64,0.82,0.44)─│──╰● 2: ──Rot(0.76,0.79,0.13)────╰X─╰●──Rot(0.23,0.55,0.06)─╰X─── ───Rot(0.83,0.63,0.76)──────────────────────╭●────╭X─┤ <Z> ──╭X────────────────────Rot(0.35,0.97,0.89)─╰X─╭●─│──┤ <Z> ──╰●────────────────────Rot(0.78,0.19,0.47)────╰X─╰●─┤ <Z> The ``wire_order`` keyword specifies the order of the wires from top to bottom: >>> print(qml.draw(circuit, wire_order=[1,0])(a=2.3, w=[1.2, 3.2, 0.7])) 1: ────╭RX(2.30)──Rot(1.20,3.20,0.70)─╭RX(-2.30)─┤ ╭<Z@Z> 0: ──H─╰●─────────────────────────────╰●─────────┤ ╰<Z@Z> If the device or ``wire_order`` has wires not used by operations, those wires are omitted unless requested with ``show_all_wires=True`` >>> empty_qfunc = lambda : qml.expval(qml.Z(0)) >>> empty_circuit = qml.QNode(empty_qfunc, qml.device('lightning.qubit', wires=3)) >>> print(qml.draw(empty_circuit, show_all_wires=True)()) 0: ───┤ <Z> 1: ───┤ 2: ───┤ Drawing also works on batch transformed circuits: .. code-block:: python from functools import partial @partial(qml.gradients.param_shift, shifts=[(0.1,)]) @qml.qnode(qml.device('default.qubit', wires=1)) def transformed_circuit(x): qml.RX(x, wires=0) return qml.expval(qml.Z(0)) print(qml.draw(transformed_circuit)(np.array(1.0, requires_grad=True))) .. code-block:: none 0: ──RX(1.10)─┤ <Z> 0: ──RX(0.90)─┤ <Z> The function also accepts quantum functions rather than QNodes. This can be especially helpful if you want to visualize only a part of a circuit that may not be convertible into a QNode, such as a sub-function that does not return any measurements. >>> def qfunc(x): ... qml.RX(x, wires=[0]) ... qml.CNOT(wires=[0, 1]) >>> print(qml.draw(qfunc)(1.1)) 0: ──RX(1.10)─╭●─┤ 1: ───────────╰X─┤ """ if catalyst_qjit(qnode): qnode = qnode.user_function if hasattr(qnode, "construct"): return _draw_qnode( qnode, wire_order=wire_order, show_all_wires=show_all_wires, decimals=decimals, max_length=max_length, show_matrices=show_matrices, expansion_strategy=expansion_strategy, ) if expansion_strategy is not None: warnings.warn( "When the input to qml.draw is not a QNode, the expansion_strategy argument is ignored.", UserWarning, ) @wraps(qnode) def wrapper(*args, **kwargs): tape = qml.tape.make_qscript(qnode)(*args, **kwargs) if wire_order: _wire_order = wire_order else: try: _wire_order = sorted(tape.wires) except TypeError: _wire_order = tape.wires return tape_text( tape, wire_order=_wire_order, show_all_wires=show_all_wires, decimals=decimals, show_matrices=show_matrices, max_length=max_length, ) return wrapper
def _draw_qnode( qnode, wire_order=None, show_all_wires=False, decimals=2, max_length=100, show_matrices=True, expansion_strategy=None, ): @wraps(qnode) def wrapper(*args, **kwargs): if isinstance(qnode.device, qml.devices.Device) and ( expansion_strategy == "device" or getattr(qnode, "expansion_strategy", None) == "device" ): qnode.construct(args, kwargs) tapes = qnode.transform_program([qnode.tape])[0] program, _ = qnode.device.preprocess() tapes = program(tapes)[0] else: original_expansion_strategy = getattr(qnode, "expansion_strategy", None) try: qnode.expansion_strategy = expansion_strategy or original_expansion_strategy tapes = qnode.construct(args, kwargs) program = qnode.transform_program tapes = program([qnode.tape])[0] finally: qnode.expansion_strategy = original_expansion_strategy if wire_order: _wire_order = wire_order elif qnode.device.wires: _wire_order = qnode.device.wires else: try: _wire_order = sorted(tapes[0].wires) except TypeError: _wire_order = tapes[0].wires if tapes is not None: cache = {"tape_offset": 0, "matrices": []} res = [ tape_text( t, wire_order=_wire_order, show_all_wires=show_all_wires, decimals=decimals, show_matrices=False, max_length=max_length, cache=cache, ) for t in tapes ] if show_matrices and cache["matrices"]: mat_str = "" for i, mat in enumerate(cache["matrices"]): mat_str += f"\nM{i} = \n{mat}" if mat_str: mat_str = "\n" + mat_str return "\n\n".join(res) + mat_str return "\n\n".join(res) return tape_text( qnode.qtape, wire_order=_wire_order, show_all_wires=show_all_wires, decimals=decimals, show_matrices=show_matrices, max_length=max_length, ) return wrapper
[docs]def draw_mpl( qnode, wire_order=None, show_all_wires=False, decimals=None, expansion_strategy=None, style=None, *, fig=None, **kwargs, ): """Draw a qnode with matplotlib Args: qnode (.QNode or Callable): the input QNode/quantum function that is to be drawn. Keyword Args: wire_order (Sequence[Any]): the order (from top to bottom) to print the wires of the circuit. If not provided, the wire order defaults to the device wires. If device wires are not available, the circuit wires are sorted if possible. show_all_wires (bool): If True, all wires, including empty wires, are printed. decimals (int): How many decimal points to include when formatting operation parameters. Default ``None`` will omit parameters from operation labels. style (str): visual style of plot. Valid strings are ``{'black_white', 'black_white_dark', 'sketch', 'pennylane', 'pennylane_sketch', 'sketch_dark', 'solarized_light', 'solarized_dark', 'default'}``. If no style is specified, the global style set with :func:`~.use_style` will be used, and the initial default is 'black_white'. If you would like to use your environment's current rcParams, set ``style`` to "rcParams". Setting style does not modify matplotlib global plotting settings. fontsize (float or str): fontsize for text. Valid strings are ``{'xx-small', 'x-small', 'small', 'medium', large', 'x-large', 'xx-large'}``. Default is ``14``. wire_options (dict): matplotlib formatting options for the wire lines label_options (dict): matplotlib formatting options for the wire labels active_wire_notches (bool): whether or not to add notches indicating active wires. Defaults to ``True``. expansion_strategy (str): The strategy to use when circuit expansions or decompositions are required. - ``gradient``: The QNode will attempt to decompose the internal circuit such that all circuit operations are supported by the gradient method. - ``device``: The QNode will attempt to decompose the internal circuit such that all circuit operations are natively supported by the device. fig (None or matplotlib.Figure): Matplotlib figure to plot onto. If None, then create a new figure Returns: A function that has the same argument signature as ``qnode``. When called, the function will draw the QNode as a tuple of (``matplotlib.figure.Figure``, ``matplotlib.axes._axes.Axes``) **Example**: .. code-block:: python dev = qml.device('lightning.qubit', wires=(0,1,2,3)) @qml.qnode(dev) def circuit(x, z): qml.QFT(wires=(0,1,2,3)) qml.IsingXX(1.234, wires=(0,2)) qml.Toffoli(wires=(0,1,2)) mcm = qml.measure(1) mcm_out = qml.measure(2) qml.CSWAP(wires=(0,2,3)) qml.RX(x, wires=0) qml.cond(mcm, qml.RY)(np.pi / 4, wires=3) qml.CRZ(z, wires=(3,0)) return qml.expval(qml.Z(0)), qml.probs(op=mcm_out) fig, ax = qml.draw_mpl(circuit)(1.2345,1.2345) fig.show() .. figure:: ../../_static/draw_mpl/main_example.png :align: center :width: 60% :target: javascript:void(0); .. details:: :title: Usage Details **Decimals:** The keyword ``decimals`` controls how many decimal points to include when labelling the operations. The default value ``None`` omits parameters for brevity. .. code-block:: python @qml.qnode(dev) def circuit2(x, y): qml.RX(x, wires=0) qml.Rot(*y, wires=0) return qml.expval(qml.Z(0)) fig, ax = qml.draw_mpl(circuit2, decimals=2)(1.23456, [1.2345,2.3456,3.456]) fig.show() .. figure:: ../../_static/draw_mpl/decimals.png :align: center :width: 60% :target: javascript:void(0); **Wires:** The keywords ``wire_order`` and ``show_all_wires`` control the location of wires from top to bottom. .. code-block:: python fig, ax = qml.draw_mpl(circuit, wire_order=[3,2,1,0])(1.2345,1.2345) fig.show() .. figure:: ../../_static/draw_mpl/wire_order.png :align: center :width: 60% :target: javascript:void(0); If a wire is in ``wire_order``, but not in the ``tape``, it will be omitted by default. Only by selecting ``show_all_wires=True`` will empty wires be displayed. .. code-block:: python fig, ax = qml.draw_mpl(circuit, wire_order=["aux"], show_all_wires=True)(1.2345,1.2345) fig.show() .. figure:: ../../_static/draw_mpl/show_all_wires.png :align: center :width: 60% :target: javascript:void(0); **Integration with matplotlib:** This function returns matplotlib figure and axes objects. Using these objects, users can perform further customization of the graphic. .. code-block:: python fig, ax = qml.draw_mpl(circuit)(1.2345,1.2345) fig.suptitle("My Circuit", fontsize="xx-large") options = {'facecolor': "white", 'edgecolor': "#f57e7e", "linewidth": 6, "zorder": -1} box1 = plt.Rectangle((-0.5, -0.5), width=3.0, height=4.0, **options) ax.add_patch(box1) ax.annotate("CSWAP", xy=(5, 2.5), xycoords='data', xytext=(5.8,1.5), textcoords='data', arrowprops={'facecolor': 'black'}, fontsize=14) ax.annotate("classical control flow", xy=(3.5, 4.2), xycoords='data', xytext=(0.8,4.2), textcoords='data', arrowprops={'facecolor': 'blue'}, fontsize=14, va="center") fig.show() .. figure:: ../../_static/draw_mpl/postprocessing.png :align: center :width: 60% :target: javascript:void(0); **Formatting:** PennyLane has inbuilt styles for controlling the appearance of the circuit drawings. All available styles can be determined by evaluating ``qml.drawer.available_styles()``. Any available string can then be passed via the kwarg ``style`` to change the settings for that plot. This will not affect style settings for subsequent matplotlib plots. .. code-block:: python fig, ax = qml.draw_mpl(circuit, style='sketch')(1.2345,1.2345) fig.show() .. figure:: ../../_static/draw_mpl/sketch_style.png :align: center :width: 60% :target: javascript:void(0); You can also control the appearance with matplotlib's provided tools, see the `matplotlib docs <https://matplotlib.org/stable/tutorials/introductory/customizing.html>`_ . For example, we can customize ``plt.rcParams``. To use a customized appearance based on matplotlib's ``plt.rcParams``, ``qml.draw_mpl`` must be run with ``style="rcParams"``: .. code-block:: python plt.rcParams['patch.facecolor'] = 'mistyrose' plt.rcParams['patch.edgecolor'] = 'maroon' plt.rcParams['text.color'] = 'maroon' plt.rcParams['font.weight'] = 'bold' plt.rcParams['patch.linewidth'] = 4 plt.rcParams['patch.force_edgecolor'] = True plt.rcParams['lines.color'] = 'indigo' plt.rcParams['lines.linewidth'] = 2 plt.rcParams['figure.facecolor'] = 'ghostwhite' fig, ax = qml.draw_mpl(circuit, style="rcParams")(1.2345,1.2345) fig.show() .. figure:: ../../_static/draw_mpl/rcparams.png :align: center :width: 60% :target: javascript:void(0); The wires and wire labels can be manually formatted by passing in dictionaries of keyword-value pairs of matplotlib options. ``wire_options`` accepts options for lines, and ``label_options`` accepts text options. .. code-block:: python fig, ax = qml.draw_mpl(circuit, wire_options={'color':'teal', 'linewidth': 5}, label_options={'size': 20})(1.2345,1.2345) fig.show() .. figure:: ../../_static/draw_mpl/wires_labels.png :align: center :width: 60% :target: javascript:void(0); """ if catalyst_qjit(qnode): qnode = qnode.user_function if hasattr(qnode, "construct"): return _draw_mpl_qnode( qnode, wire_order=wire_order, show_all_wires=show_all_wires, decimals=decimals, expansion_strategy=expansion_strategy, style=style, fig=fig, **kwargs, ) if expansion_strategy is not None: warnings.warn( "When the input to qml.draw is not a QNode, the expansion_strategy argument is ignored.", UserWarning, ) @wraps(qnode) def wrapper(*args, **kwargs): tape = qml.tape.make_qscript(qnode)(*args, **kwargs) if wire_order: _wire_order = wire_order else: try: _wire_order = sorted(tape.wires) except TypeError: _wire_order = tape.wires return tape_mpl( tape, wire_order=_wire_order, show_all_wires=show_all_wires, decimals=decimals, style=style, fig=fig, **kwargs, ) return wrapper
def _draw_mpl_qnode( qnode, wire_order=None, show_all_wires=False, decimals=None, expansion_strategy=None, style="black_white", *, fig=None, **kwargs, ): @wraps(qnode) def wrapper(*args, **kwargs_qnode): if expansion_strategy == "device" and isinstance(qnode.device, qml.devices.Device): qnode.construct(args, kwargs) tapes, _ = qnode.transform_program([qnode.tape]) program, _ = qnode.device.preprocess() tapes, _ = program(tapes) tape = tapes[0] else: original_expansion_strategy = getattr(qnode, "expansion_strategy", None) try: qnode.expansion_strategy = expansion_strategy or original_expansion_strategy qnode.construct(args, kwargs_qnode) program = qnode.transform_program [tape], _ = program([qnode.tape]) finally: qnode.expansion_strategy = original_expansion_strategy if wire_order: _wire_order = wire_order elif qnode.device.wires: _wire_order = qnode.device.wires else: try: _wire_order = sorted(tape.wires) except TypeError: _wire_order = tape.wires return tape_mpl( tape, wire_order=_wire_order, show_all_wires=show_all_wires, decimals=decimals, style=style, fig=fig, **kwargs, ) return wrapper