diff --git a/examples/compare_viz.py b/examples/compare_viz.py new file mode 100644 index 00000000..3d854489 --- /dev/null +++ b/examples/compare_viz.py @@ -0,0 +1,43 @@ +""" +Visualization Comparison Demo +============================== +This script opens both the static and the new interactive visualizers +side-by-side (or one after another) using the exact same pattern +so you can compare them visually as requested by the reviewers. +""" + +from __future__ import annotations + +import matplotlib.pyplot as plt + +from graphix.command import E, M, N, X, Z +from graphix.measurements import Measurement +from graphix.pattern import Pattern +from graphix.visualization_interactive import InteractiveGraphVisualizer + + +def main() -> None: + # Create the same simple pattern used in the interactive demo + p = Pattern(input_nodes=[0, 1]) + p.add(N(node=2)) + p.add(E(nodes=(0, 2))) + p.add(E(nodes=(1, 2))) + p.add(M(node=0, measurement=Measurement.XY(0.5))) + p.add(M(node=1, measurement=Measurement.XY(0.25))) + p.add(X(node=2, domain={0, 1})) + p.add(Z(node=2, domain={0})) + + print("Pattern created with", len(p), "commands.") + print("Close the static plot window to open the interactive one.") + + # 1. Show the static visualizer plot + p.draw_graph(flow_from_pattern=False, show_measurement_planes=True) + plt.show() # Blocks until closed + + # 2. Show the interactive visualizer plot + viz = InteractiveGraphVisualizer(p) + viz.visualize() + + +if __name__ == "__main__": + main() diff --git a/examples/interactive_viz_demo.py b/examples/interactive_viz_demo.py new file mode 100644 index 00000000..31accf69 --- /dev/null +++ b/examples/interactive_viz_demo.py @@ -0,0 +1,41 @@ +""" +Interactive Visualization Demo +============================== + +This example demonstrates the interactive graph visualizer using a simple +manually constructed pattern. It shows how to step through the visualization +and observe state changes. +""" + +from __future__ import annotations + +from graphix.command import E, M, N, X, Z +from graphix.measurements import Measurement +from graphix.pattern import Pattern +from graphix.visualization_interactive import InteractiveGraphVisualizer + + +def main() -> None: + # optimized pattern for QFT + # Create a simple pattern manually for demonstration + p = Pattern(input_nodes=[0, 1]) + p.add(N(node=2)) + p.add(E(nodes=(0, 2))) + p.add(E(nodes=(1, 2))) + p.add(M(node=0, measurement=Measurement.XY(0.5))) + p.add(M(node=1, measurement=Measurement.XY(0.25))) + p.add(X(node=2, domain={0, 1})) + p.add(Z(node=2, domain={0})) + + # Or standardization to make it interesting + # p.standardize() + + print("Pattern created with", len(p), "commands.") + print("Launching interactive visualization with real-time simulation...") + + viz = InteractiveGraphVisualizer(p) + viz.visualize() + + +if __name__ == "__main__": + main() diff --git a/examples/interactive_viz_qaoa.py b/examples/interactive_viz_qaoa.py new file mode 100644 index 00000000..b830e560 --- /dev/null +++ b/examples/interactive_viz_qaoa.py @@ -0,0 +1,64 @@ +""" +QAOA Interactive Visualization (Optimized) +========================================== + +This example generates a QAOA pattern using the Graphix Circuit API +and launches the interactive visualizer in simulation-free mode +to demonstrate performance on complex patterns. +""" + +from __future__ import annotations + +import networkx as nx +import numpy as np + +from graphix import Circuit +from graphix.visualization_interactive import InteractiveGraphVisualizer + + +def main() -> None: + print("Generating QAOA pattern...") + + # 1. Define QAOA Circuit + n_qubits = 4 + rng = np.random.default_rng(42) # Fixed seed for reproducibility + + # Random parameters for the circuit + xi = rng.random(6) + theta = rng.random(4) + + # Create a complete graph for the problem hamiltonian + g = nx.complete_graph(n_qubits) + circuit = Circuit(n_qubits) + + # Apply unitary evolution for the problem Hamiltonian + for i, (u, v) in enumerate(g.edges): + circuit.cnot(u, v) + circuit.rz(v, float(xi[i])) # Rotation by random angle + circuit.cnot(u, v) + + # Apply unitary evolution for the mixing Hamiltonian + for v in g.nodes: + circuit.rx(v, float(theta[v])) + + # 2. Transpile to MBQC Pattern + # This automatically generates the measurement pattern from the gate circuit + pattern = circuit.transpile().pattern + + # Standardize the pattern to ensure it follows the standard MBQC form (N, E, M, C) + pattern.standardize() + pattern.shift_signals() + + print(f"Pattern generated with {len(pattern)} commands.") + print("Launching interactive visualizer...") + print("Optimization enabled: Simulation is DISABLED for performance.") + print("You will see the graph structure and command flow without quantum state calculation.") + + # 3. Launch Visualization + # enable_simulation=False prevents high RAM usage for this complex pattern + viz = InteractiveGraphVisualizer(pattern, node_distance=(1.5, 1.5), enable_simulation=False) + viz.visualize() + + +if __name__ == "__main__": + main() diff --git a/graphix/visualization.py b/graphix/visualization.py index 4322b82b..45be3f78 100644 --- a/graphix/visualization.py +++ b/graphix/visualization.py @@ -24,6 +24,7 @@ from typing import TypeAlias, TypeVar import numpy.typing as npt + from matplotlib.axes import Axes from graphix.clifford import Clifford from graphix.flow.core import CausalFlow, PauliFlow @@ -52,41 +53,25 @@ class GraphVisualizer: og: OpenGraph[Measurement] local_clifford: Mapping[int, Clifford] | None = None - def visualize( + def get_layout( self, - show_pauli_measurement: bool = True, - show_local_clifford: bool = False, - show_measurement_planes: bool = False, - show_loop: bool = True, - node_distance: tuple[float, float] = (1, 1), - figsize: tuple[int, int] | None = None, - filename: Path | None = None, - ) -> None: - """ - Visualize the graph with flow or gflow structure. - - If there exists a flow structure, then the graph is visualized with the flow structure. - If flow structure is not found and there exists a gflow structure, then the graph is visualized - with the gflow structure. - If neither flow nor gflow structure is found, then the graph is visualized without any structure. + ) -> tuple[ + Mapping[int, _Point], + Callable[ + [Mapping[int, _Point]], tuple[Mapping[_Edge, Sequence[_Point]], Mapping[_Edge, Sequence[_Point]] | None] + ], + Mapping[int, int] | None, + ]: + """Determine the layout (positions, paths, layers) for the graph. - Parameters - ---------- - show_pauli_measurement : bool - If True, the nodes with Pauli measurement angles are colored light blue. - show_local_clifford : bool - If True, indexes of the local Clifford operator are displayed adjacent to the nodes. - show_measurement_planes : bool - If True, the measurement planes are displayed adjacent to the nodes. - show_loop : bool - whether or not to show loops for graphs with gflow. defaulted to True. - node_distance : tuple - Distance multiplication factor between nodes for x and y directions. - figsize : tuple - Figure size of the plot. - filename : Path | None - If not None, filename of the png file to save the plot. If None, the plot is not saved. - Default in None. + Returns + ------- + pos : dict + Node positions. + place_paths : callable + Function to place edges and arrows. + l_k : dict or None + Layer mapping. """ try: bloch_graph = self.og.downcast_bloch() @@ -131,6 +116,46 @@ def place_paths( ) -> tuple[Mapping[_Edge, Sequence[_Point]], Mapping[_Edge, Sequence[_Point]] | None]: return (self.place_edge_paths_without_structure(pos), None) + return pos, place_paths, l_k + + def visualize( + self, + show_pauli_measurement: bool = True, + show_local_clifford: bool = False, + show_measurement_planes: bool = False, + show_loop: bool = True, + node_distance: tuple[float, float] = (1, 1), + figsize: tuple[int, int] | None = None, + filename: Path | None = None, + ) -> None: + """ + Visualize the graph with flow or gflow structure. + + If there exists a flow structure, then the graph is visualized with the flow structure. + If flow structure is not found and there exists a gflow structure, then the graph is visualized + with the gflow structure. + If neither flow nor gflow structure is found, then the graph is visualized without any structure. + + Parameters + ---------- + show_pauli_measurement : bool + If True, the nodes with Pauli measurement angles are colored light blue. + show_local_clifford : bool + If True, indexes of the local Clifford operator are displayed adjacent to the nodes. + show_measurement_planes : bool + If True, the measurement planes are displayed adjacent to the nodes. + show_loop : bool + whether or not to show loops for graphs with gflow. defaulted to True. + node_distance : tuple + Distance multiplication factor between nodes for x and y directions. + figsize : tuple + Figure size of the plot. + filename : Path | None + If not None, filename of the png file to save the plot. If None, the plot is not saved. + Default in None. + """ + pos, place_paths, l_k = self.get_layout() + self.visualize_graph( pos, place_paths, @@ -253,11 +278,176 @@ def _shorten_path(path: Sequence[_Point]) -> list[_Point]: return new_path def _draw_labels(self, pos: Mapping[int, _Point]) -> None: - fontsize = 12 - if max(self.og.graph.nodes(), default=0) >= 100: - fontsize = int(fontsize * 2 / len(str(max(self.og.graph.nodes())))) + fontsize = self.get_label_fontsize(max(self.og.graph.nodes(), default=0)) nx.draw_networkx_labels(self.og.graph, pos, font_size=fontsize) + def draw_node_labels( + self, + ax: Axes, + pos: Mapping[int, _Point], + extra_labels: Mapping[int, str] | None = None, + fontsize: int | None = None, + ) -> None: + """Draw node labels onto a given axes object. + + This is an axis-aware counterpart of :meth:`_draw_labels` intended for + use in contexts where the caller manages the :class:`~matplotlib.axes.Axes` + directly (e.g. the interactive visualizer). + + Parameters + ---------- + ax : Axes + The matplotlib axes to draw onto. + pos : Mapping[int, tuple[float, float]] + Dictionary mapping each node to its ``(x, y)`` position. + extra_labels : Mapping[int, str] or None, optional + If provided, appends the corresponding string below the node number. + fontsize : int or None, optional + Font size for the labels. If ``None``, it is computed automatically. + """ + if fontsize is None: + fontsize = self.get_label_fontsize(max(self.og.graph.nodes(), default=0)) + for node, (x, y) in pos.items(): + label_text = str(node) + if extra_labels is not None and node in extra_labels: + label_text += f"\n{extra_labels[node]}" + ax.text(x, y, label_text, ha="center", va="center", fontsize=fontsize, zorder=3) + + @staticmethod + def get_label_fontsize(max_node: int, base_size: int = 12) -> int: + """Compute the font size for node labels. + + When the largest node number has many digits the font is reduced + so that labels still fit inside the scatter markers. + + Parameters + ---------- + max_node : int + The largest node number in the graph. + base_size : int, optional + The default font size used for small node numbers. + Defaults to ``12``. + + Returns + ------- + int + The computed font size, never smaller than ``7``. + """ + if max_node >= 100: + return max(7, int(base_size * 2 / len(str(max_node)))) + return base_size + + def draw_edges( + self, + ax: Axes, + pos: Mapping[int, _Point], + edge_subset: Iterable[tuple[int, ...]] | None = None, + ) -> None: + """Draw graph edges as plain lines onto a given axes object. + + This axis-aware method is intended for use in contexts where the caller + manages the :class:`~matplotlib.axes.Axes` directly (e.g. the + interactive visualizer). + + Parameters + ---------- + ax : Axes + The matplotlib axes to draw onto. + pos : Mapping[int, tuple[float, float]] + Dictionary mapping each node to its ``(x, y)`` position. + edge_subset : Iterable[tuple[int, int]] or None, optional + If provided, only these edges are drawn. When ``None`` + (the default), all edges in :attr:`og.graph` are drawn. + """ + edges: Iterable[tuple[int, ...]] = self.og.graph.edges() if edge_subset is None else edge_subset + for u, v in edges: + if u in pos and v in pos: + x1, y1 = pos[u] + x2, y2 = pos[v] + ax.plot([x1, x2], [y1, y2], color="black", alpha=0.7, zorder=1) + + def draw_nodes_role( + self, + ax: Axes, + pos: Mapping[int, _Point], + show_pauli_measurement: bool = False, + node_facecolors: Mapping[int, str] | None = None, + node_edgecolors: Mapping[int, str] | None = None, + node_alpha: Mapping[int, float] | None = None, + node_linewidths: Mapping[int, float] | None = None, + node_size: int = 350, + ) -> None: + """Draw nodes onto a given axes object, coloured by their role. + + This is an axis-aware counterpart of the private ``__draw_nodes_role`` + method, intended for use in contexts where the caller manages the + :class:`~matplotlib.axes.Axes` directly (e.g. the interactive + visualizer). Nodes are styled as follows: + + * Input nodes: red border, white fill. + * Output nodes: black border, light-gray fill. + * Pauli-measured nodes (when *show_pauli_measurement* is ``True``): + black border, light-blue fill. + * All other nodes: black border, white fill. + + When *node_facecolors*, *node_edgecolors*, or *node_alpha* are provided, + their values override the role-based defaults for the corresponding nodes. + + Parameters + ---------- + ax : Axes + The matplotlib axes to draw onto. + pos : Mapping[int, tuple[float, float]] + Dictionary mapping each node to its ``(x, y)`` position. + show_pauli_measurement : bool, optional + If ``True``, nodes with Pauli measurement angles are coloured + light blue. Defaults to ``False``. + node_facecolors : Mapping[int, str] or None, optional + Per-node fill colour overrides. + node_edgecolors : Mapping[int, str] or None, optional + Per-node border colour overrides. + node_alpha : Mapping[int, float] or None, optional + Per-node opacity overrides. (Default role alpha is 1.0). + node_linewidths : Mapping[int, float] or None, optional + Per-node line width overrides for the marker edge. + node_size : int, optional + Marker size for :meth:`~matplotlib.axes.Axes.scatter`. + Defaults to ``350``. + """ + for node in self.og.graph.nodes(): + if node not in pos: + continue + edgecolor = "black" + facecolor = "white" + alpha = 1.0 + if node in self.og.input_nodes: + edgecolor = "red" + if node in self.og.output_nodes: + facecolor = "lightgray" + elif show_pauli_measurement and isinstance(self.og.measurements[node], PauliMeasurement): + facecolor = "lightblue" + # Apply per-node overrides if provided + if node_facecolors is not None and node in node_facecolors: + facecolor = node_facecolors[node] + if node_edgecolors is not None and node in node_edgecolors: + edgecolor = node_edgecolors[node] + if node_alpha is not None and node in node_alpha: + alpha = node_alpha[node] + + lw = 1.5 + if node_linewidths is not None and node in node_linewidths: + lw = node_linewidths[node] + + ax.scatter( + *pos[node], + edgecolors=edgecolor, + facecolors=facecolor, + s=node_size, + zorder=2, + linewidths=lw, + alpha=alpha, + ) + def __draw_nodes_role(self, pos: Mapping[int, _Point], show_pauli_measurement: bool = False) -> None: """ Draw the nodes with different colors based on their role (input, output, or other). @@ -347,49 +537,13 @@ def visualize_graph( plt.figure(figsize=figsize) - for edge, path in edge_path.items(): - if len(path) == 2: - nx.draw_networkx_edges(self.og.graph, pos, edgelist=[edge], style="dashed", alpha=0.7) - else: - curve = self._bezier_curve_linspace(path) - plt.plot(curve[:, 0], curve[:, 1], "k--", linewidth=1, alpha=0.7) + ax = plt.gca() if arrow_path is not None: - for arrow, path in arrow_path.items(): - if corrections is None: - color = "k" - else: - xflow, zflow = corrections - if arrow[1] not in xflow.get(arrow[0], set()): - color = "tab:green" - elif arrow[1] not in zflow.get(arrow[0], set()): - color = "tab:red" - else: - color = "tab:brown" - if arrow[0] == arrow[1]: # self loop - if show_loop: - curve = self._bezier_curve_linspace(path) - plt.plot(curve[:, 0], curve[:, 1], c="k", linewidth=1) - plt.annotate( - "", - xy=curve[-1], - xytext=curve[-2], - arrowprops={"arrowstyle": "->", "color": color, "lw": 1}, - ) - elif len(path) == 2: # straight line - nx.draw_networkx_edges( - self.og.graph, pos, edgelist=[arrow], edge_color=color, arrowstyle="->", arrows=True - ) - else: - new_path = GraphVisualizer._shorten_path(path) - curve = self._bezier_curve_linspace(new_path) - plt.plot(curve[:, 0], curve[:, 1], c=color, linewidth=1) - plt.annotate( - "", - xy=curve[-1], - xytext=curve[-2], - arrowprops={"arrowstyle": "->", "color": color, "lw": 1}, - ) + self.draw_edges_with_routing(ax, edge_path) + self.draw_flow_arrows(ax, pos, arrow_path, corrections, show_loop) + else: + self.draw_edges_with_routing(ax, edge_path) self.__draw_nodes_role(pos, show_pauli_measurement) @@ -409,21 +563,13 @@ def visualize_graph( plt.plot([], [], color="tab:brown", label="xflow and zflow") plt.legend(loc="center left", fontsize=10, bbox_to_anchor=(1, 0.5)) - x_min = min((pos[node][0] for node in self.og.graph.nodes()), default=0) # Get the minimum x coordinate - x_max = max((pos[node][0] for node in self.og.graph.nodes()), default=0) # Get the maximum x coordinate - y_min = min((pos[node][1] for node in self.og.graph.nodes()), default=0) # Get the minimum y coordinate - y_max = max((pos[node][1] for node in self.og.graph.nodes()), default=0) # Get the maximum y coordinate - - if l_k is not None and l_k: - # Draw the vertical lines to separate different layers - for layer in range(min(l_k.values()), max(l_k.values())): - plt.axvline( - x=(layer + 0.5) * node_distance[0], color="gray", linestyle="--", alpha=0.5 - ) # Draw line between layers - for layer in range(min(l_k.values()), max(l_k.values()) + 1): - plt.text( - layer * node_distance[0], y_min - 0.5, f"L: {max(l_k.values()) - layer}", ha="center", va="top" - ) # Add layer label at bottom + x_min = min((pos[node][0] for node in self.og.graph.nodes()), default=0) + x_max = max((pos[node][0] for node in self.og.graph.nodes()), default=0) + y_min = min((pos[node][1] for node in self.og.graph.nodes()), default=0) + y_max = max((pos[node][1] for node in self.og.graph.nodes()), default=0) + + if l_k is not None: + self.draw_layer_separators(ax, pos, l_k, node_distance) plt.xlim( x_min - 0.5 * node_distance[0], x_max + 0.5 * node_distance[0] @@ -480,6 +626,162 @@ def determine_figsize( height = len({pos[node][1] for node in self.og.graph.nodes()}) if pos is not None else len(self.og.output_nodes) return (width * node_distance[0], height * node_distance[1]) + def draw_edges_with_routing( + self, + ax: Axes, + edge_path: Mapping[_Edge, Sequence[_Point]], + edge_subset: Iterable[_Edge] | None = None, + edge_colors: Mapping[_Edge, str] | None = None, + edge_linewidths: Mapping[_Edge, float] | None = None, + ) -> None: + """Draw graph edges along provided routed paths. + + Parameters + ---------- + ax : Axes + The matplotlib axes to draw onto. + edge_path : Mapping[_Edge, Sequence[_Point]] + Mapping from edge to its routed path (from ``get_layout``). + edge_subset : Iterable[tuple[int, int]] or None, optional + If provided, only these edges are drawn. When ``None``, all provided + edges in ``edge_path`` are drawn. + edge_colors : Mapping[tuple[int, int], str] or None, optional + Per-edge colour overrides. Default is black (k). + edge_linewidths : Mapping[tuple[int, int], float] or None, optional + Per-edge linewidth overrides. Default is 1.0. + """ + allowed_edges = ( + {(min(e), max(e)) for e in self.og.graph.edges()} + if edge_subset is None + else {(min(e), max(e)) for e in edge_subset} + ) + for edge, path in edge_path.items(): + if (min(edge), max(edge)) not in allowed_edges: + continue + + e_sorted = (min(edge), max(edge)) + color = "k" + lw = 1.0 + if edge_colors is not None and e_sorted in edge_colors: + color = edge_colors[e_sorted] + if edge_linewidths is not None and e_sorted in edge_linewidths: + lw = edge_linewidths[e_sorted] + + if len(path) == 2: + ax.plot( + [path[0][0], path[1][0]], + [path[0][1], path[1][1]], + color=color, + linewidth=lw, + linestyle="--", + alpha=0.7, + ) + else: + curve = self._bezier_curve_linspace(path) + ax.plot(curve[:, 0], curve[:, 1], color=color, linewidth=lw, linestyle="--", alpha=0.7) + + def draw_flow_arrows( + self, + ax: Axes, + pos: Mapping[int, _Point], + arrow_path: Mapping[_Edge, Sequence[_Point]], + corrections: tuple[Mapping[int, AbstractSet[int]], Mapping[int, AbstractSet[int]]] | None = None, + show_loop: bool = True, + arrow_subset: Iterable[_Edge] | None = None, + ) -> None: + """Draw flow/gflow arrows along routed paths. + + Parameters + ---------- + ax : Axes + The matplotlib axes to draw onto. + pos : Mapping[int, tuple[float, float]] + Dictionary mapping each node to its ``(x, y)`` position. + arrow_path : Mapping[_Edge, Sequence[_Point]] + Mapping from edge to its routed path (from ``get_layout``). + corrections : tuple or None + X and Z corrections, used to color the arrows. + show_loop : bool, optional + Whether to show loops for graphs with gflow. Defaults to True. + arrow_subset : Iterable[tuple[int, int]] or None, optional + If provided, only these arrows are drawn. + """ + if arrow_subset is not None: + allowed = {(min(e), max(e)) for e in arrow_subset} + arrows_to_draw = [a for a in arrow_path if (min(a), max(a)) in allowed] + else: + arrows_to_draw = list(arrow_path.keys()) + + for arrow in arrows_to_draw: + if arrow not in arrow_path: + continue + path = arrow_path[arrow] + if corrections is None: + color = "k" + else: + xflow, zflow = corrections + if arrow[1] not in xflow.get(arrow[0], set()): + color = "tab:green" + elif arrow[1] not in zflow.get(arrow[0], set()): + color = "tab:red" + else: + color = "tab:brown" + if arrow[0] == arrow[1]: # self loop + if show_loop: + curve = self._bezier_curve_linspace(path) + ax.plot(curve[:, 0], curve[:, 1], c="k", linewidth=1) + ax.annotate( + "", + xy=curve[-1], + xytext=curve[-2], + arrowprops={"arrowstyle": "->", "color": color, "lw": 1}, + ) + elif len(path) == 2: # straight line + # nx draws standard arrows + nx.draw_networkx_edges( + self.og.graph, pos, edgelist=[arrow], edge_color=color, arrowstyle="->", arrows=True, ax=ax + ) + else: + new_path = GraphVisualizer._shorten_path(path) + curve = self._bezier_curve_linspace(new_path) + ax.plot(curve[:, 0], curve[:, 1], c=color, linewidth=1) + ax.annotate( + "", + xy=curve[-1], + xytext=curve[-2], + arrowprops={"arrowstyle": "->", "color": color, "lw": 1}, + ) + + @staticmethod + def draw_layer_separators( + ax: Axes, + pos: Mapping[int, _Point], + l_k: Mapping[int, int], + node_distance: tuple[float, float] = (1, 1), + ) -> None: + """Draw vertical dashed lines and labels to separate graph layers. + + Parameters + ---------- + ax : Axes + The matplotlib axes to draw onto. + pos : Mapping[int, tuple[float, float]] + Dictionary mapping each node to its ``(x, y)`` position. + l_k : Mapping[int, int] + Mapping from node to its layer index. + node_distance : tuple[float, float], optional + Distance scaling for the positions. + """ + if not l_k: + return + y_vals = [p[1] for p in pos.values()] + y_min = min(y_vals) if y_vals else 0 + min_l, max_l = min(l_k.values()), max(l_k.values()) + for layer in range(min_l, max_l): + ax.axvline(x=(layer + 0.5) * node_distance[0], color="gray", linestyle="--", alpha=0.5) + for layer in range(min_l, max_l + 1): + ax.text(layer * node_distance[0], y_min - 0.5, f"L: {max_l - layer}", ha="center", va="top") + def place_edge_paths( self, flow: Mapping[int, AbstractSet[int]], pos: Mapping[int, _Point] ) -> tuple[dict[_Edge, list[_Point]], dict[_Edge, list[_Point]]]: diff --git a/graphix/visualization_interactive.py b/graphix/visualization_interactive.py new file mode 100644 index 00000000..76c48679 --- /dev/null +++ b/graphix/visualization_interactive.py @@ -0,0 +1,637 @@ +"""Interactive visualization for MBQC patterns.""" + +from __future__ import annotations + +import sys +import traceback +from typing import TYPE_CHECKING, Any + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.text import Text +from matplotlib.widgets import Button, Slider + +from graphix.clifford import Clifford +from graphix.command import CommandKind +from graphix.opengraph import OpenGraph +from graphix.pretty_print import OutputFormat, command_to_str +from graphix.sim.statevec import StatevectorBackend +from graphix.visualization import GraphVisualizer + +if TYPE_CHECKING: + from graphix.pattern import Pattern + + +class InteractiveGraphVisualizer: + """Interactive visualization tool for MBQC patterns. + + Attributes + ---------- + pattern : Pattern + The MBQC pattern to visualize. + node_distance : tuple[float, float] + Scale factors (x, y) for the node positions. + enable_simulation : bool + If True, simulates the state vector and measurement outcomes. + marker_fill_ratio : float + Fraction of the inter-node spacing used by each marker (0-1). + label_size_ratio : float + Label font size as a fraction of the marker diameter. + max_label_fontsize : int + Upper bound for label font size in points. + min_inches_per_node : float + Minimum vertical inches per node for adaptive figure height. + active_node_color : str + Border colour for active (current-step) nodes. + measured_node_color : str + Fill colour for already-measured nodes. + """ + + def __init__( + self, + pattern: Pattern, + node_distance: tuple[float, float] = (1, 1), + enable_simulation: bool = True, + *, + marker_fill_ratio: float = 0.80, + label_size_ratio: float = 0.55, + max_label_fontsize: int = 12, + min_inches_per_node: float = 0.3, + active_node_color: str = "#2060cc", + measured_node_color: str = "lightgray", + ) -> None: + """Construct an interactive visualizer. + + Parameters + ---------- + pattern : Pattern + The MBQC pattern to visualize. + node_distance : tuple[float, float], optional + Scale factors (x, y) for node positions. Defaults to (1, 1). + enable_simulation : bool, optional + If True, enables state vector simulation. Defaults to True. + marker_fill_ratio : float, optional + Fraction of the inter-node spacing used by marker diameter. + Defaults to 0.80. + label_size_ratio : float, optional + Label font size as a fraction of the marker diameter in points. + Defaults to 0.55. + max_label_fontsize : int, optional + Upper bound for label font size. Prevents text from overflowing + the marker in sparse graphs. Defaults to 12. + min_inches_per_node : float, optional + Minimum vertical inches allocated per node when computing the + adaptive figure height. Defaults to 0.3. + active_node_color : str, optional + Border colour for active nodes. Defaults to ``"#2060cc"``. + measured_node_color : str, optional + Fill colour for measured nodes. Defaults to ``"lightgray"``. + """ + self.pattern = pattern + self.node_positions: dict[int, tuple[float, float]] = {} + self.node_distance = node_distance + self.enable_simulation = enable_simulation + self.marker_fill_ratio = marker_fill_ratio + self.label_size_ratio = label_size_ratio + self.max_label_fontsize = max_label_fontsize + self.min_inches_per_node = min_inches_per_node + self.active_node_color = active_node_color + self.measured_node_color = measured_node_color + + # Prepare graph layout reusing GraphVisualizer + self._prepare_layout() + + # Figure height and width adapts to graph density like GraphVisualizer + ax_h_frac = 0.80 + fig_width, needed_height = self._graph_visualizer.determine_figsize(self._l_k, pos=self.node_positions) + fig_height = max(7.0, needed_height / ax_h_frac) + + # Enforce minimum window width for controls + fig_width = max(4.0, fig_width) + + self.fig = plt.figure(figsize=(fig_width, fig_height)) + + # Dynamically scale command strip capacity down on narrow graphs + self.command_window_size = max(1, min(7, int(fig_width / 1.8))) + + # Ensure command_window_size is always an odd number so the current command is perfectly centered. + if self.command_window_size % 2 == 0: + self.command_window_size -= 1 + + self.node_size = 350 + self.label_fontsize = 10 + + # Axes layout fractions [left, bottom, width, height] + self.ax_graph = self.fig.add_axes((0.00, 0.15, 1.0, 0.85)) + self.ax_commands = self.fig.add_axes((0.10, 0.05, 0.80, 0.08)) + + self.ax_prev = self.fig.add_axes((0.32, 0.015, 0.04, 0.04)) + self.ax_slider = self.fig.add_axes((0.40, 0.015, 0.20, 0.04)) + self.ax_next = self.fig.add_axes((0.64, 0.015, 0.04, 0.04)) + + # Turn off axes frame for command list and graph + self.ax_commands.axis("off") + self.ax_graph.axis("off") + + # Interaction state + self.current_step = 0 + self.total_steps = len(pattern) + self.command_window_size = 7 # Increased for horizontal viewing + + # Widget placeholders + self.slider: Slider | None = None + self.btn_prev: Button | None = None + self.btn_next: Button | None = None + + def _prepare_layout(self) -> None: + """Compute node positions by reusing :class:`GraphVisualizer` layout. + + Builds the full graph from the pattern commands, delegates layout + computation to :meth:`GraphVisualizer.get_layout`, and normalizes + the resulting positions to fit the interactive panel area. + The flow-based layout is always preserved. + """ + # Build the full graph from all commands + g: Any = __import__("networkx").Graph() + measurements: dict[int, Any] = {} + for cmd in self.pattern: + if cmd.kind == CommandKind.N: + g.add_node(cmd.node) + elif cmd.kind == CommandKind.E: + g.add_edge(cmd.nodes[0], cmd.nodes[1]) + elif cmd.kind == CommandKind.M: + measurements[cmd.node] = cmd.measurement + + # Delegate layout to GraphVisualizer (shares flow-detection logic) + og = OpenGraph(g, self.pattern.input_nodes, self.pattern.output_nodes, measurements) + og = og.infer_pauli_measurements() + + vis = GraphVisualizer(og) + pos_mapping, self._place_paths, self._l_k = vis.get_layout() + self.node_positions = dict(pos_mapping) + + # Apply user-provided scaling + self.node_positions = { + k: (v[0] * self.node_distance[0], v[1] * self.node_distance[1]) for k, v in self.node_positions.items() + } + # Store the visualizer for reuse in drawing helpers + self._graph_visualizer = vis + + def visualize(self) -> None: + """Launch the interactive visualization window.""" + # Initial state simulation + state = self._update_graph_state(0) + + # Initial draw + self._draw_command_list(state[4]) # pass results dict + self._draw_graph(state) + self._update(0) + + # Step slider (horizontal, bottom centered, without label text) + self.slider = Slider(self.ax_slider, "", 0, self.total_steps, valinit=0, valstep=1, color="lightblue") + self.slider.valtext.set_visible(False) + self.slider.on_changed(self._update) + + # Buttons config + self.btn_prev = Button(self.ax_prev, "<") + self.btn_prev.on_clicked(self._prev_step) + + self.btn_next = Button(self.ax_next, ">") + self.btn_next.on_clicked(self._next_step) + + # Key events + self.fig.canvas.mpl_connect("key_press_event", self._on_key) + + # Pick events for command list + self.fig.canvas.mpl_connect("pick_event", self._on_pick) + + plt.show() + + def _draw_command_list(self, results: dict[int, int]) -> None: + self.ax_commands.clear() + self.ax_commands.axis("off") + + # Use current step as center of the visible window + half_window = self.command_window_size // 2 + start = max(0, self.current_step - half_window) + end = min(self.total_steps, self.current_step + half_window + 1) + + def _get_props(abs_idx: int, cmd: Any) -> tuple[str, str, str, str, int, float]: + text_str = command_to_str(cmd, OutputFormat.Unicode) + meas_str = "" + if cmd.kind == CommandKind.M and abs_idx <= self.current_step and cmd.node in results: + meas_str = f"m={results[cmd.node]}" + + color = "gray" + weight = "normal" + fontsize = 11 + alpha = 1.0 + + if abs_idx == self.current_step: + color = "black" + weight = "bold" + fontsize = 13 + alpha = 1.0 + elif abs_idx < self.current_step: + color = "black" + alpha = 0.4 + elif abs_idx > self.current_step: + color = "lightgray" + alpha = 0.7 + + return text_str, meas_str, color, weight, fontsize, alpha + + artists: dict[int, Any] = {} + + # Handle out-of-bounds slider focus + focus_idx = min(self.current_step, end - 1) + if focus_idx < start: + return + + cmd = self.pattern[focus_idx] + txt, meas_str, color, weight, fsize, alpha = _get_props(focus_idx, cmd) + artists[focus_idx] = self.ax_commands.text( + 0.5, + 0.5, + txt, + color=color, + weight=weight, + fontsize=fsize, + alpha=alpha, + transform=self.ax_commands.transAxes, + ha="center", + va="center", + picker=True, + clip_on=True, + ) + artists[focus_idx].index = focus_idx + + if meas_str: + self.ax_commands.annotate( + meas_str, + xy=(0.5, 1.0), + xycoords=artists[focus_idx], + xytext=(0, 2), + textcoords="offset points", + color=color, + fontsize=10, + alpha=alpha, + ha="center", + va="bottom", + annotation_clip=True, + clip_on=True, + ) + + # Draw past commands + prev_idx = focus_idx + for abs_idx in range(focus_idx - 1, start - 1, -1): + cmd = self.pattern[abs_idx] + txt, meas_str, color, weight, fsize, alpha = _get_props(abs_idx, cmd) + + artists[abs_idx] = self.ax_commands.annotate( + txt, + xy=(0, 0.5), + xycoords=artists[prev_idx], + xytext=(-4, 0), + textcoords="offset points", + color=color, + weight=weight, + fontsize=fsize, + alpha=alpha, + ha="right", + va="center", + picker=True, + annotation_clip=True, + clip_on=True, + ) + artists[abs_idx].index = abs_idx + + if meas_str: + self.ax_commands.annotate( + meas_str, + xy=(0.5, 1.0), + xycoords=artists[abs_idx], + xytext=(0, 2), + textcoords="offset points", + color=color, + fontsize=9, + alpha=alpha, + ha="center", + va="bottom", + annotation_clip=True, + clip_on=True, + ) + + prev_idx = abs_idx + + # Draw future commands + prev_idx = focus_idx + for abs_idx in range(focus_idx + 1, end): + cmd = self.pattern[abs_idx] + txt, meas_str, color, weight, fsize, alpha = _get_props(abs_idx, cmd) + + artists[abs_idx] = self.ax_commands.annotate( + txt, + xy=(1, 0.5), + xycoords=artists[prev_idx], + xytext=(4, 0), + textcoords="offset points", + color=color, + weight=weight, + fontsize=fsize, + alpha=alpha, + ha="left", + va="center", + picker=True, + annotation_clip=True, + clip_on=True, + ) + artists[abs_idx].index = abs_idx + + if meas_str: + self.ax_commands.annotate( + meas_str, + xy=(0.5, 1.0), + xycoords=artists[abs_idx], + xytext=(0, 2), + textcoords="offset points", + color=color, + fontsize=9, + alpha=alpha, + ha="center", + va="bottom", + annotation_clip=True, + clip_on=True, + ) + + prev_idx = abs_idx + + def _update_graph_state( + self, step: int + ) -> tuple[set[int], set[int], list[tuple[int, int]], dict[int, set[str]], dict[int, int]]: + """Calculate the graph state by simulating the pattern up to *step*. + + Parameters + ---------- + step : int + The command index up to which the pattern is executed. + + Returns + ------- + active_nodes : set[int] + Nodes that have been initialised but not yet measured. + measured_nodes : set[int] + Nodes that have been measured. + active_edges : list[tuple[int, int]] + Edges currently present in the graph (both endpoints active). + corrections : dict[int, set[str]] + Accumulated byproduct corrections per node (``"X"`` and/or ``"Z"``). + results : dict[int, int] + Measurement outcomes keyed by node (only populated when + *enable_simulation* is ``True``). + """ + active_nodes = set() + measured_nodes = set() + active_edges = [] + corrections: dict[int, set[str]] = {} + results: dict[int, int] = {} + + if self.enable_simulation: + backend = StatevectorBackend() + + # Prerun input nodes (standard MBQC initialization) + for node in self.pattern.input_nodes: + backend.add_nodes([node]) + + rng = np.random.default_rng(42) # Fixed seed for determinism + + for i in range(min(step + 1, len(self.pattern))): + cmd = self.pattern[i] + if cmd.kind == CommandKind.N: + backend.add_nodes([cmd.node], data=cmd.state) + elif cmd.kind == CommandKind.E: + backend.entangle_nodes(cmd.nodes) + elif cmd.kind == CommandKind.M: + # Adaptive measurement (feedforward) + s_signal = sum(results.get(j, 0) for j in cmd.s_domain) if cmd.s_domain else 0 + t_signal = sum(results.get(j, 0) for j in cmd.t_domain) if cmd.t_domain else 0 + + clifford = Clifford.I + if s_signal % 2 == 1: + clifford = Clifford.X @ clifford + if t_signal % 2 == 1: + clifford = Clifford.Z @ clifford + + measurement = cmd.measurement.clifford(clifford) + result = backend.measure(cmd.node, measurement, rng=rng) + results[cmd.node] = result + elif cmd.kind == CommandKind.X: + if cmd.node not in corrections: + corrections[cmd.node] = set() + corrections[cmd.node].add("X") + backend.correct_byproduct(cmd) + elif cmd.kind == CommandKind.Z: + if cmd.node not in corrections: + corrections[cmd.node] = set() + corrections[cmd.node].add("Z") + backend.correct_byproduct(cmd) + + # ---- Topological tracking (independent of simulation) ---- + current_active: set[int] = set(self.pattern.input_nodes) + current_edges: set[tuple[int, int]] = set() + current_measured: set[int] = set() + + for i in range(min(step + 1, len(self.pattern))): + cmd = self.pattern[i] + if cmd.kind == CommandKind.N: + current_active.add(cmd.node) + elif cmd.kind == CommandKind.E: + u, v = cmd.nodes + if u in current_active and v in current_active: + current_edges.add((min(u, v), max(u, v))) + elif cmd.kind == CommandKind.M and cmd.node in current_active: + current_active.remove(cmd.node) + current_measured.add(cmd.node) + + active_nodes = current_active + measured_nodes = current_measured + active_edges = list(current_edges) + + return active_nodes, measured_nodes, active_edges, corrections, results + + def _draw_graph( + self, state: tuple[set[int], set[int], list[tuple[int, int]], dict[int, set[str]], dict[int, int]] + ) -> None: + """Draw nodes and edges onto the graph axes. + + Delegates to :class:`GraphVisualizer` for edge and node rendering, + passing per-node colour overrides to distinguish measured (grey) from + active (blue border) nodes. Labels are drawn locally because they + include dynamic content (measurement results, corrections). + """ + try: + self.ax_graph.clear() + + active_nodes, measured_nodes, active_edges, corrections, results = state + + # Highlight logic + highlight_nodes: set[int] = set() + highlight_edges: set[tuple[int, int]] = set() + if self.current_step > 0: + last_cmd = self.pattern[self.current_step - 1] + if last_cmd.kind in {CommandKind.N, CommandKind.M, CommandKind.C, CommandKind.X, CommandKind.Z}: + highlight_nodes.add(last_cmd.node) # type: ignore[union-attr] + elif last_cmd.kind == CommandKind.E: + highlight_nodes.update(last_cmd.nodes) + highlight_edges.add(last_cmd.nodes) + + # Axis limits + xs = [p[0] for p in self.node_positions.values()] + ys = [p[1] for p in self.node_positions.values()] + if xs and ys: + self.ax_graph.set_xlim(min(xs) - 0.1 * self.node_distance[0], max(xs) + 0.1 * self.node_distance[0]) + self.ax_graph.set_ylim(min(ys) - 0.4, max(ys) + 0.4) + + # Layer separators + if self._l_k is not None: + self._graph_visualizer.draw_layer_separators( + self.ax_graph, self.node_positions, self._l_k, self.node_distance + ) + + # Edges and arrows + edge_path, arrow_path = self._place_paths(self.node_positions) + + edge_colors: dict[tuple[int, int], str] = {} + edge_linewidths: dict[tuple[int, int], float] = {} + for edge in highlight_edges: + e_sorted = (min(edge), max(edge)) + edge_colors[e_sorted] = "black" + edge_linewidths[e_sorted] = 2.0 + + if arrow_path is not None: + self._graph_visualizer.draw_edges_with_routing( + self.ax_graph, + edge_path, + edge_subset=active_edges, + edge_colors=edge_colors, + edge_linewidths=edge_linewidths, + ) + self._graph_visualizer.draw_flow_arrows( + self.ax_graph, self.node_positions, arrow_path, arrow_subset=active_edges + ) + else: + self._graph_visualizer.draw_edges_with_routing( + self.ax_graph, + edge_path, + edge_subset=active_edges, + edge_colors=edge_colors, + edge_linewidths=edge_linewidths, + ) + + # Nodes + node_facecolors: dict[int, str] = {} + node_edgecolors: dict[int, str] = {} + node_alpha: dict[int, float] = {} + node_linewidths: dict[int, float] = {} + for node in highlight_nodes: + if node not in self._graph_visualizer.og.input_nodes: + node_edgecolors[node] = "black" + node_linewidths[node] = 2.0 + + self._graph_visualizer.draw_nodes_role( + self.ax_graph, + self.node_positions, + node_facecolors=node_facecolors, + node_edgecolors=node_edgecolors, + node_alpha=node_alpha, + node_linewidths=node_linewidths, + node_size=self.node_size, + ) + + # Labels + label_offset_pts = (14, -10) + + # Show "XY", "XZ" etc for non-measured nodes + for node in self.node_positions: + if node not in measured_nodes and node in self._graph_visualizer.og.measurements: + meas = self._graph_visualizer.og.measurements[node] + plane = meas.to_plane_or_axis().name + if isinstance(plane, str): + xy = self.node_positions[node] + self.ax_graph.annotate( + plane, + xy=xy, + xytext=label_offset_pts, + textcoords="offset points", + fontsize=9, + zorder=3, + ) + + for node in measured_nodes: + if node in results: + xy = self.node_positions[node] + self.ax_graph.annotate( + f"m={results[node]}", + xy=xy, + xytext=label_offset_pts, + textcoords="offset points", + fontsize=9, + zorder=3, + ) + + for node in active_nodes: + if node in corrections: + lbl = "".join(sorted(corrections[node])) + if lbl: + xy = self.node_positions[node] + self.ax_graph.annotate( + lbl, + xy=xy, + xytext=label_offset_pts, + textcoords="offset points", + fontsize=9, + zorder=3, + ) + + self._graph_visualizer.draw_node_labels( + self.ax_graph, self.node_positions, extra_labels=None, fontsize=self.label_fontsize + ) + + self.ax_graph.axis("off") + + except Exception as e: # noqa: BLE001 + traceback.print_exc() + print(f"Error drawing graph: {e}", file=sys.stderr) + + def _update(self, val: float) -> None: + step = int(val) + if step != self.current_step: + self.current_step = step + + # Fetch state once per tick to feed both visual layers + state = self._update_graph_state(self.current_step) + results = state[4] + + self._draw_command_list(results) + self._draw_graph(state) + self.fig.canvas.draw_idle() + + def _prev_step(self, _event: Any) -> None: + if self.current_step > 0 and self.slider is not None: + self.slider.set_val(self.current_step - 1) + + def _next_step(self, _event: Any) -> None: + if self.current_step < self.total_steps and self.slider is not None: + self.slider.set_val(self.current_step + 1) + + def _on_key(self, event: Any) -> None: + if event.key == "right": + self._next_step(None) + elif event.key == "left": + self._prev_step(None) + + def _on_pick(self, event: Any) -> None: + if isinstance(event.artist, Text): + idx = getattr(event.artist, "index", None) + if idx is not None and self.slider is not None: + self.slider.set_val(idx + 1) # Jump to state AFTER the clicked command diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 030ada64..adf12d6d 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -3,6 +3,7 @@ from pathlib import Path from tempfile import TemporaryDirectory from typing import TYPE_CHECKING +from unittest.mock import MagicMock import matplotlib.pyplot as plt import networkx as nx @@ -10,7 +11,7 @@ from graphix import Circuit, Pattern, command, visualization from graphix.fundamentals import ANGLE_PI -from graphix.measurements import Measurement +from graphix.measurements import Measurement, PauliMeasurement from graphix.opengraph import OpenGraph, OpenGraphError from graphix.visualization import GraphVisualizer @@ -250,3 +251,129 @@ def test_draw_graph_reference(flow_and_not_pauli_presimulate: bool) -> Figure: flow_from_pattern=flow_and_not_pauli_presimulate, node_distance=(0.7, 0.6), show_measurement_planes=True ) return plt.gcf() + + +def test_draw_edges_with_routing_skips_non_subset_edge() -> None: + vis, _pattern = example_visualizer() + ax = MagicMock() + edge_path = {(0, 1): [(0.0, 0.0), (1.0, 0.0)], (1, 2): [(1.0, 0.0), (2.0, 0.0)]} + vis.draw_edges_with_routing(ax, edge_path, edge_subset=[(0, 1)]) + assert ax.plot.call_count == 1 + + +def test_draw_edges_with_routing_color_and_linewidth_overrides() -> None: + vis, _pattern = example_visualizer() + ax = MagicMock() + edge_path = {(0, 1): [(0.0, 0.0), (1.0, 0.0)]} + vis.draw_edges_with_routing( + ax, + edge_path, + edge_colors={(0, 1): "red"}, + edge_linewidths={(0, 1): 2.5}, + ) + assert ax.plot.call_count == 1 + call_kwargs = ax.plot.call_args + assert call_kwargs.kwargs["color"] == "red" + assert call_kwargs.kwargs["linewidth"] == pytest.approx(2.5) + + +def test_draw_flow_arrows_with_subset() -> None: + vis, _pattern = example_visualizer() + ax = MagicMock() + pos = {0: (0.0, 0.0), 1: (1.0, 0.0), 2: (2.0, 0.0)} + arrow_path = {(0, 1): [(0.0, 0.0), (1.0, 0.0)], (1, 2): [(1.0, 0.0), (2.0, 0.0)]} + vis.draw_flow_arrows(ax, pos, arrow_path, arrow_subset=[(0, 1)]) + assert ax.annotate.call_count == 0 # no self-loop, no annotate + + +def test_draw_flow_arrows_self_loop() -> None: + vis, _pattern = example_visualizer() + ax = MagicMock() + pos = {0: (0.5, 0.5)} + loop_path = [(0.5, 0.5), (0.7, 0.7), (0.9, 0.5), (0.7, 0.3), (0.5, 0.5)] + arrow_path = {(0, 0): loop_path} + vis.draw_flow_arrows(ax, pos, arrow_path, show_loop=True) + assert ax.plot.called + assert ax.annotate.called + + +def test_draw_node_labels_auto_fontsize() -> None: + vis, _pattern = example_visualizer() + ax = MagicMock() + pos = {0: (0.0, 0.0), 1: (1.0, 0.0)} + vis.draw_node_labels(ax, pos) + assert ax.text.call_count == 2 + + +def test_draw_node_labels_with_extra_labels() -> None: + vis, _pattern = example_visualizer() + ax = MagicMock() + pos = {0: (0.0, 0.0), 1: (1.0, 0.0)} + vis.draw_node_labels(ax, pos, extra_labels={0: "m=1"}) + calls = ax.text.call_args_list + label_args = [call.args[2] for call in calls] + assert any("\n" in lbl for lbl in label_args) + + +def test_draw_nodes_role_skips_node_not_in_pos() -> None: + vis, _pattern = example_visualizer() + ax = MagicMock() + pos = {99: (0.0, 0.0)} + vis.draw_nodes_role(ax, pos) + assert ax.scatter.call_count == 0 + + +def test_draw_nodes_role_pauli_measurement_lightblue() -> None: + mock_og = MagicMock() + mock_og.graph.nodes.return_value = [0] + mock_og.input_nodes = [] + mock_og.output_nodes = [] + mock_og.measurements = {0: MagicMock(spec=PauliMeasurement)} + + vis = GraphVisualizer(og=mock_og) + ax = MagicMock() + pos = {0: (0.0, 0.0)} + + vis.draw_nodes_role(ax, pos, show_pauli_measurement=True) + call_kwargs = ax.scatter.call_args_list[0].kwargs + assert call_kwargs["facecolors"] == "lightblue" + + +def test_draw_nodes_role_node_alpha_override() -> None: + mock_og = MagicMock() + mock_og.graph.nodes.return_value = [0] + mock_og.input_nodes = [] + mock_og.output_nodes = [] + mock_og.measurements = {} + + vis = GraphVisualizer(og=mock_og) + ax = MagicMock() + pos = {0: (0.0, 0.0)} + + vis.draw_nodes_role(ax, pos, node_alpha={0: 0.3}) + call_kwargs = ax.scatter.call_args_list[0].kwargs + assert call_kwargs["alpha"] == pytest.approx(0.3) + + +def test_draw_nodes_role_node_linewidths_override() -> None: + mock_og = MagicMock() + mock_og.graph.nodes.return_value = [0] + mock_og.input_nodes = [] + mock_og.output_nodes = [] + mock_og.measurements = {} + + vis = GraphVisualizer(og=mock_og) + ax = MagicMock() + pos = {0: (0.0, 0.0)} + + vis.draw_nodes_role(ax, pos, node_linewidths={0: 4.0}) + call_kwargs = ax.scatter.call_args_list[0].kwargs + assert call_kwargs["linewidths"] == pytest.approx(4.0) + + +def test_draw_layer_separators_empty_l_k() -> None: + vis, _pattern = example_visualizer() + ax = MagicMock() + pos = {0: (0.0, 0.0)} + vis.draw_layer_separators(ax, pos, l_k={}) + ax.axvline.assert_not_called() diff --git a/tests/test_visualization_interactive.py b/tests/test_visualization_interactive.py new file mode 100644 index 00000000..d65ee81e --- /dev/null +++ b/tests/test_visualization_interactive.py @@ -0,0 +1,379 @@ +"""Tests for the interactive visualization module.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +from matplotlib.text import Text + +from graphix.command import E, M, N, X, Z +from graphix.measurements import Measurement +from graphix.pattern import Pattern +from graphix.visualization import GraphVisualizer +from graphix.visualization_interactive import InteractiveGraphVisualizer + + +class TestInteractiveGraphVisualizer: + @pytest.fixture + def pattern(self) -> Pattern: + """Fixture to provide a standard pattern for testing.""" + pattern = Pattern(input_nodes=[0, 1]) + pattern.add(N(node=0)) + pattern.add(N(node=1)) + pattern.add(N(node=2)) + pattern.add(E(nodes=(0, 1))) + pattern.add(E(nodes=(1, 2))) + pattern.add(M(node=0, measurement=Measurement.XY(0.5), s_domain={1}, t_domain={2})) + pattern.add(M(node=1, measurement=Measurement.XY(0.0), s_domain={2}, t_domain=set())) + pattern.add(X(node=2, domain={0})) + pattern.add(Z(node=2, domain={1})) + return pattern + + def test_init_and_layout(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test initialization, layout scaling, OpenGraph inference, and command_window_size parity.""" + mock_og_class = mocker.patch("graphix.visualization_interactive.OpenGraph") + mock_og_instance = mock_og_class.return_value + mock_og_instance.infer_pauli_measurements.return_value = mock_og_instance + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("matplotlib.pyplot.figure") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + expected_pos = {0: (10, 10), 1: (20, 20), 2: (30, 30)} + mock_vis_obj.get_layout.return_value = (expected_pos, {}, {}) + + mock_vis_obj.determine_figsize.return_value = (14.0, 7.0) + viz = InteractiveGraphVisualizer(pattern) + + assert viz.total_steps == len(pattern) + assert viz.enable_simulation + assert len(viz.node_positions) == 3 + mock_vis_obj.get_layout.assert_called_once() + mock_og_instance.infer_pauli_measurements.assert_called_once() + for node, (ex, ey) in expected_pos.items(): + ax, ay = viz.node_positions[node] + assert ax == pytest.approx(ex * viz.node_distance[0]) + assert ay == pytest.approx(ey * viz.node_distance[1]) + + # Narrow figure: max(1, min(7, int(4.0 / 1.8))) = 2 → even → decremented to 1. + mock_vis_obj.determine_figsize.return_value = (4.0, 7.0) + viz_narrow = InteractiveGraphVisualizer(pattern) + assert viz_narrow.command_window_size % 2 == 1 + + def test_graph_state_with_simulation(self, mocker: MagicMock) -> None: + """Test simulation path: topology, adaptive Clifford signals, Z-correction dict init, and label format.""" + # Node 1 uses s_domain={0} and t_domain={0}; with result[0]=1 both signals are odd, + # exercising Clifford.X (s_signal) and Clifford.Z (t_signal) branches. + # Z(node=3) is the first correction on node 3, exercising the dict-initialisation branch. + sim_pattern = Pattern(input_nodes=[0, 1]) + sim_pattern.add(N(node=2)) + sim_pattern.add(N(node=3)) + sim_pattern.add(E(nodes=(0, 2))) + sim_pattern.add(E(nodes=(1, 3))) + sim_pattern.add(M(node=0, measurement=Measurement.XY(0.0))) + sim_pattern.add(M(node=1, measurement=Measurement.XY(0.0), s_domain={0}, t_domain={0})) + sim_pattern.add(X(node=2, domain={0})) + sim_pattern.add(Z(node=3, domain={0})) + + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") + mocker.patch("matplotlib.pyplot.figure") + mock_backend = mocker.patch("graphix.visualization_interactive.StatevectorBackend") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.determine_figsize.return_value = (14.0, 7.0) + mock_place_paths = MagicMock(return_value=({}, {})) + mock_vis_obj.get_layout.return_value = ( + {0: (0, 0), 1: (1, 0), 2: (0, 1), 3: (1, 1)}, + mock_place_paths, + {}, + ) + + backend_instance = mock_backend.return_value + backend_instance.measure.return_value = 1 + + viz = InteractiveGraphVisualizer(sim_pattern, enable_simulation=True) + + active, measured, _, corrections, results = viz._update_graph_state(len(sim_pattern)) + + assert 0 in measured + assert 1 in measured + assert 2 in active + assert 3 in active + assert results[0] == 1 + assert results[1] == 1 + assert backend_instance.measure.call_count == 2 + backend_instance.correct_byproduct.assert_called() + assert "Z" in corrections.get(3, set()) + + viz.ax_graph = MagicMock() + viz.ax_commands = MagicMock() + viz.slider = MagicMock() + viz._update(len(sim_pattern)) + + mock_vis_obj.draw_node_labels.assert_called() + annotate_calls = viz.ax_graph.annotate.call_args_list + label_strings = [call.args[0] for call in annotate_calls if call.args] + assert any("m=" in str(lbl) for lbl in label_strings) + assert not any(str(lbl).endswith(("\n=1", "\n=0")) for lbl in label_strings) + + def test_graph_state_without_simulation(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test that topology tracking works and results are empty when simulation is disabled.""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") + mocker.patch("matplotlib.pyplot.figure") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.determine_figsize.return_value = (14.0, 7.0) + mock_place_paths = MagicMock(return_value=({}, {})) + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, mock_place_paths, {}) + + viz = InteractiveGraphVisualizer(pattern, enable_simulation=False) + + active, measured, _, _, results = viz._update_graph_state(len(pattern)) + + assert 0 in measured + assert 1 in measured + assert 2 in active + assert results == {} + + viz.ax_graph = MagicMock() + viz.ax_commands = MagicMock() + viz.slider = MagicMock() + viz._update(len(pattern)) + + assert viz.ax_commands.text.call_count > 0 + mock_vis_obj.draw_node_labels.assert_called() + + def test_visualize(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test the main visualize method (smoke test).""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") + mocker.patch("matplotlib.pyplot.figure") + mock_show = mocker.patch("matplotlib.pyplot.show") + mocker.patch("graphix.visualization_interactive.Slider") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.determine_figsize.return_value = (14.0, 7.0) + mock_place_paths = MagicMock(return_value=({}, {})) + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, mock_place_paths, {}) + + viz = InteractiveGraphVisualizer(pattern) + viz.visualize() + + mock_show.assert_called_once() + assert viz.ax_commands is not None + assert viz.ax_graph is not None + assert viz.slider is not None + assert viz.btn_next is not None + assert viz.btn_prev is not None + + def test_navigation_and_events(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test step navigation boundaries, keyboard dispatch, and pick-event slider sync.""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") + mocker.patch("matplotlib.pyplot.figure") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.determine_figsize.return_value = (14.0, 7.0) + mock_place_paths = MagicMock(return_value=({}, {})) + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, mock_place_paths, {}) + + viz = InteractiveGraphVisualizer(pattern) + viz.slider = MagicMock() + viz.total_steps = 10 + viz.current_step = 5 + + viz._prev_step(None) + viz.slider.set_val.assert_called_with(4) + viz._next_step(None) + viz.slider.set_val.assert_called_with(6) + + viz.current_step = 0 + viz.slider.reset_mock() + viz._prev_step(None) + viz.slider.set_val.assert_not_called() + + viz.current_step = 10 + viz.slider.reset_mock() + viz._next_step(None) + viz.slider.set_val.assert_not_called() + + key_event = MagicMock() + key_event.key = "right" + mock_next = mocker.patch.object(viz, "_next_step") + viz._on_key(key_event) + mock_next.assert_called_once() + + key_event.key = "left" + mock_prev = mocker.patch.object(viz, "_prev_step") + viz._on_key(key_event) + mock_prev.assert_called_once() + + viz.ax_graph = MagicMock() + viz.ax_commands = MagicMock() + viz.total_steps = len(pattern) # restore after manual override + viz.current_step = 0 + viz._update(len(pattern)) + mock_artist = MagicMock(spec=Text) + mock_artist.index = 5 + pick_event = MagicMock() + pick_event.artist = mock_artist + viz.slider.reset_mock() + viz._on_pick(pick_event) + viz.slider.set_val.assert_called_with(6) + + def test_draw_graph_rendering(self, pattern: Pattern, mocker: MagicMock) -> None: + """Test edge and node delegation to GraphVisualizer, and silent exception handling.""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") + mocker.patch("matplotlib.pyplot.figure") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.determine_figsize.return_value = (14.0, 7.0) + mock_place_paths = MagicMock(return_value=({}, {})) + mock_vis_obj.get_layout.return_value = ({0: (0, 0), 1: (1, 0), 2: (0, 1)}, mock_place_paths, {}) + + viz = InteractiveGraphVisualizer(pattern) + viz.ax_graph = MagicMock() + viz.slider = MagicMock() + + viz._update(5) + mock_vis_obj.draw_edges_with_routing.assert_called() + assert "edge_colors" in mock_vis_obj.draw_edges_with_routing.call_args.kwargs + + viz._update(len(pattern)) + mock_vis_obj.draw_nodes_role.assert_called() + assert "node_facecolors" in mock_vis_obj.draw_nodes_role.call_args.kwargs + assert "node_edgecolors" in mock_vis_obj.draw_nodes_role.call_args.kwargs + + viz.ax_graph.clear.side_effect = ValueError("boom") + mocker.patch("traceback.print_exc") + mock_state: tuple[set[int], set[int], list[tuple[int, int]], dict[int, set[str]], dict[int, int]] = ( + set(), set(), [], {}, {} + ) + viz._draw_graph(mock_state) + + def test_draw_graph_no_flow_and_plane_annotation(self, mocker: MagicMock) -> None: + """Test the no-flow else branch and plane label annotation for active nodes.""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") + mocker.patch("matplotlib.pyplot.figure") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.determine_figsize.return_value = (14.0, 7.0) + mock_place_paths = MagicMock(return_value=({(0, 1): [(0.0, 0.0), (1.0, 0.0)]}, None)) + mock_vis_obj.get_layout.return_value = ( + {0: (0.0, 0.0), 1: (1.0, 0.0), 2: (0.0, 1.0)}, + mock_place_paths, + None, + ) + mock_meas = MagicMock() + mock_meas.to_plane_or_axis.return_value.name = "XY" + mock_vis_obj.og.measurements = {0: mock_meas, 1: mock_meas} + mock_vis_obj.og.input_nodes = [] + + no_flow_pattern = Pattern( + input_nodes=[0, 1], + cmds=[N(node=0), N(node=1), E(nodes=(0, 1)), M(node=0), M(node=1)], + ) + viz = InteractiveGraphVisualizer(no_flow_pattern, enable_simulation=False) + viz.ax_graph = MagicMock() + viz.ax_commands = MagicMock() + viz.slider = MagicMock() + viz.node_positions = {0: (0.0, 0.0), 1: (1.0, 0.0), 2: (0.0, 1.0)} + + viz._update(len(no_flow_pattern)) + mock_vis_obj.draw_edges_with_routing.assert_called() + mock_vis_obj.draw_flow_arrows.assert_not_called() + + viz.current_step = 0 + viz.ax_graph.reset_mock() + viz._update(1) + annotate_calls = viz.ax_graph.annotate.call_args_list + plane_labels = [c.args[0] for c in annotate_calls if c.args and isinstance(c.args[0], str)] + assert any(lbl == "XY" for lbl in plane_labels) + + def test_draw_command_list_early_return(self, mocker: MagicMock) -> None: + """Test that _draw_command_list returns without drawing when the window has no visible commands.""" + mock_visualizer = mocker.patch("graphix.visualization_interactive.GraphVisualizer") + mocker.patch("graphix.visualization_interactive.OpenGraph") + mocker.patch("matplotlib.pyplot.figure") + + mock_vis_obj = MagicMock() + mock_visualizer.return_value = mock_vis_obj + mock_vis_obj.determine_figsize.return_value = (14.0, 7.0) + mock_vis_obj.get_layout.return_value = ({}, {}, {}) + + empty_pattern = Pattern(input_nodes=[]) + viz = InteractiveGraphVisualizer(empty_pattern, enable_simulation=False) + viz.ax_commands = MagicMock() + viz.current_step = 0 + + viz._draw_command_list({}) + viz.ax_commands.text.assert_not_called() + + +class TestGraphVisualizerSharedAPI: + """Tests for the shared drawing API exposed by GraphVisualizer.""" + + def test_get_label_fontsize(self) -> None: + """Test font-size computation for small, large, and custom-base node numbers.""" + assert GraphVisualizer.get_label_fontsize(0) == 12 + assert GraphVisualizer.get_label_fontsize(99) == 12 + large = GraphVisualizer.get_label_fontsize(100) + assert 7 <= large < 12 + assert GraphVisualizer.get_label_fontsize(0, base_size=10) == 10 + custom = GraphVisualizer.get_label_fontsize(1000, base_size=10) + assert 7 <= custom < 10 + + def test_draw_nodes_role_with_overrides(self) -> None: + """Test draw_nodes_role applies per-node colour overrides.""" + mock_og = MagicMock() + mock_og.graph.nodes.return_value = [0, 1, 2] + mock_og.input_nodes = [0] + mock_og.output_nodes = [2] + mock_og.measurements = {0: MagicMock(), 1: MagicMock(), 2: MagicMock()} + + vis = GraphVisualizer(og=mock_og) + ax = MagicMock() + pos = {0: (0.0, 0.0), 1: (1.0, 0.0), 2: (2.0, 0.0)} + + vis.draw_nodes_role( + ax, + pos, + node_facecolors={0: "yellow", 1: "pink"}, + node_edgecolors={0: "green"}, + ) + + scatter_calls = ax.scatter.call_args_list + assert len(scatter_calls) == 3 + assert scatter_calls[0].kwargs["facecolors"] == "yellow" + assert scatter_calls[0].kwargs["edgecolors"] == "green" + assert scatter_calls[1].kwargs["facecolors"] == "pink" + assert scatter_calls[1].kwargs["edgecolors"] == "black" + assert scatter_calls[2].kwargs["facecolors"] == "lightgray" + assert scatter_calls[2].kwargs["edgecolors"] == "black" + + def test_draw_edges(self) -> None: + """Test draw_edges with and without an edge subset.""" + mock_og = MagicMock() + mock_og.graph.edges.return_value = [(0, 1), (1, 2), (2, 3)] + + vis = GraphVisualizer(og=mock_og) + ax = MagicMock() + pos = {0: (0.0, 0.0), 1: (1.0, 0.0), 2: (2.0, 0.0), 3: (3.0, 0.0)} + + vis.draw_edges(ax, pos, edge_subset=[(0, 1), (2, 3)]) + assert ax.plot.call_count == 2 + + ax.reset_mock() + vis.draw_edges(ax, pos) + assert ax.plot.call_count == 3