# Source code for general_python.lattices.visualization.plotting

"""
Matplotlib-based visualisation helpers for lattice objects.

This module provides a comprehensive set of plotting utilities for visualizing
lattice structures, reciprocal space vectors, Brillouin zones, and other
geometric properties. It supports 1D, 2D, and 3D lattices.

Functions:
    - plot_real_space           : Scatter plot of real-space sites.
    - plot_reciprocal_space     : Scatter plot of reciprocal lattice vectors.
    - plot_brillouin_zone       : Visualization of the Brillouin Zone.
    - plot_lattice_structure    : Detailed connectivity plot with boundaries.

Classes:
    - LatticePlotter            : Convenience wrapper for lattice plotting.

----------------------------------------------------------------------
File    : general_python/lattices/visualization/plotting.py
Author  : Maksymilian Kliczkowski
Date    : 2025-02-01
----------------------------------------------------------------------
"""

from    __future__          import annotations
import  numpy               as np
import  matplotlib.pyplot   as plt

from    collections         import defaultdict
from    dataclasses         import dataclass
from    typing              import Optional, Tuple, List, Set, Dict, Union, Any, Iterable

from    matplotlib.axes     import Axes
from    matplotlib.figure   import Figure

# Optional dependencies for 3D and ConvexHull
try:
    from mpl_toolkits.mplot3d           import Axes3D
    from mpl_toolkits.mplot3d.art3d     import Poly3DCollection
except ImportError:
    Poly3DCollection = None

try:
    from scipy.spatial                  import ConvexHull
except ImportError:
    ConvexHull = None

from ..lattice import Lattice

# ==============================================================================
# Helpers
# ==============================================================================

def _ensure_numpy(vectors) -> np.ndarray:
    """
    Coerce ``vectors`` into a 2D float array of shape ``(n_vectors, n_components)``.

    A single non-empty 1D vector becomes one row (shape ``(1, D)``). An empty
    1D input (e.g. ``[]``) means "no vectors" and becomes shape ``(0, 0)``
    instead of failing inside ``reshape``.

    Raises
    ------
    ValueError
        If the input is a scalar or has more than two dimensions.
    """
    arr = np.asarray(vectors, dtype=float)
    if arr.ndim == 1:
        # ``reshape(-1, 0)`` cannot be resolved by numpy for empty input, so
        # treat the empty case explicitly and use the unambiguous row reshape
        # otherwise.
        arr = arr.reshape(1, -1) if arr.size else arr.reshape(0, 0)
    if arr.ndim != 2:
        raise ValueError(f"Expected a 2D array of vectors, got shape {arr.shape!r}.")
    return arr

def _init_axes(ax: Optional[Axes], dim: int, projection: Optional[str] = None) -> Tuple[Figure, Axes]:
    """
    Return ``(figure, axes)`` to draw on, creating them only when needed.

    Parameters
    ----------
    ax : Axes, optional
        Existing axes. When given, it is returned unchanged together with its
        parent figure.
    dim : int
        Plot dimensionality. ``dim >= 3`` creates a 3D subplot; anything else
        creates a plain 2D subplot.
    projection : str, optional
        Projection for the 3D subplot (defaults to ``"3d"``). Not used for
        ``dim < 3``.
    """
    if ax is None:
        if dim < 3:
            new_fig, new_ax = plt.subplots()
            return new_fig, new_ax
        new_fig = plt.figure()
        return new_fig, new_fig.add_subplot(111, projection=projection or "3d")
    return ax.figure, ax

def _annotate_indices(ax: Axes, coords: np.ndarray, *, zorder: int = 5, color: str = 'k', fontsize: int = 8, padding: float = 0.0) -> None:
    """
    Write each site's row index next to its position.

    Parameters
    ----------
    ax : Axes
        Target axes; ``ax.text`` is called once per row of ``coords``.
    coords : np.ndarray
        One position per row. Rows with three or more components are labelled
        with (x, y, z); two components with (x, y); a single component is
        placed at (x, 0).
    zorder, color, fontsize :
        Forwarded unchanged to ``ax.text``.
    padding : float, default=0.0
        When non-zero, each label is shifted away from the origin by
        ``padding * sign(component)`` along every component.
    """
    text_style = dict(zorder=zorder, color=color, fontsize=fontsize)
    for label, site in enumerate(coords):
        anchor = site if padding == 0 else site + padding * np.sign(site)
        if anchor.size >= 3:
            xyz = (anchor[0], anchor[1], anchor[2])
        elif anchor.size == 2:
            xyz = (anchor[0], anchor[1])
        else:
            xyz = (anchor[0], 0.0)
        ax.text(*xyz, str(label), **text_style)

def _finalise_figure(fig: Figure, *, top_padding: float = 0.88) -> None:
    """
    Best-effort layout tweak: reserve ``top_padding`` of the figure height for
    the plot area so titles are not clipped.

    Nothing is done when the figure already has a layout engine (e.g.
    constrained layout), since ``subplots_adjust`` would conflict with it.
    Any exception raised while adjusting is deliberately ignored: layout is
    cosmetic and must never break plotting.
    """
    try:
        layout_managed = hasattr(fig, 'get_layout_engine') and fig.get_layout_engine() is not None
        if not layout_managed:
            # tight_layout() is intentionally not used: it misbehaves with 3D
            # axes in some matplotlib versions.
            fig.subplots_adjust(top=top_padding)
    except Exception:
        pass

def _apply_planar_aspect(axis: Axes, *, fix_aspect: bool = True) -> None:
    """
    Lock (``fix_aspect=True``) or release (``False``) equal x/y scaling on a 2D
    axes, resizing the axes box rather than the data limits.
    """
    aspect_mode = {True: "equal", False: "auto"}[bool(fix_aspect)]
    axis.set_aspect(aspect_mode, adjustable="box")

def _apply_spatial_limits(
    axis        : Axes, 
    coords      : np.ndarray, 
    dim         : int, 
    show_axes   : bool, 
    margin      : float = 0.08,
    labels      : Optional[Tuple[str, ...]] = None,
    fix_aspect  : bool = True,
) -> None:
    """
    Set padded axis limits and axis decorations for a cloud of points.

    Padding is ``margin`` times a length scale: the x-extent in 1D, or the
    diagonal of the bounding box in 2D/3D (so padding is the same on every
    axis). A zero extent falls back to a length scale of 1.0 so that
    degenerate data does not collapse the view.

    Parameters
    ----------
    axis : Axes
        Target axes (must support ``set_zlim``/``set_zlabel`` when ``dim >= 3``).
    coords : np.ndarray
        Point coordinates, one row per point; at least ``min(dim, 3)`` columns.
        Empty input leaves the axes untouched.
    dim : int
        1 -> x only (y fixed to [-0.5, 0.5]); 2 -> x/y plus aspect handling;
        anything else -> x/y/z.
    show_axes : bool
        True sets axis labels; False hides ticks/spines (in 1D: the y axis and
        spines) while keeping the axes object alive for layout engines.
    margin : float, default=0.08
        Relative padding factor.
    labels : tuple of str, optional
        Axis labels, defaulting to ("x", "y", "z").
    fix_aspect : bool, default=True
        Forwarded to the 2D aspect helper.
    """
    if coords.size == 0:
        return

    lower  = coords.min(axis=0)
    upper  = coords.max(axis=0)
    extent = upper - lower

    def _padding(length: float) -> float:
        # Degenerate (zero or NaN) extent -> use a unit length scale.
        return margin * (length if length > 0 else 1.0)

    if dim == 1:
        pad = _padding(extent[0])
        axis.set_xlim(lower[0] - pad, upper[0] + pad)
        axis.set_ylim(-0.5, 0.5)
    elif dim == 2:
        pad = _padding(np.sqrt(extent[0]**2 + extent[1]**2))
        axis.set_xlim(lower[0] - pad, upper[0] + pad)
        axis.set_ylim(lower[1] - pad, upper[1] + pad)
        _apply_planar_aspect(axis, fix_aspect=fix_aspect)
    else:
        pad = _padding(np.sqrt(np.sum(extent[:3]**2)))
        for set_limits, k in ((axis.set_xlim, 0), (axis.set_ylim, 1), (axis.set_zlim, 2)):
            set_limits(lower[k] - pad, upper[k] + pad)

    if show_axes:
        names = labels or ("x", "y", "z")
        axis.set_xlabel(names[0])
        if dim >= 2:
            axis.set_ylabel(names[1])
        if dim >= 3:
            axis.set_zlabel(names[2])
        return

    if dim == 1:
        axis.get_yaxis().set_visible(False)
        for spine in axis.spines.values():
            spine.set_visible(False)
        return

    # Hide ticks/spines/background instead of set_axis_off() so layout
    # engines (e.g. constrained_layout) can still size the axes.
    axis.set_xticks([])
    axis.set_yticks([])
    if hasattr(axis, 'set_zticks'):
        axis.set_zticks([])
    for spine in axis.spines.values():
        spine.set_visible(False)
    axis.patch.set_alpha(0.0)

# ==============================================================================
# Plotting Functions
# ==============================================================================

def plot_real_space(
    lattice         : Lattice,
    *,
    ax              : Optional[Axes] = None,
    show_indices    : bool = False,
    show_axes       : bool = True,
    color           : str = "C0",
    marker          : str = "o",
    figsize         : Optional[Tuple[float, float]] = None,
    fix_aspect      : bool = True,
    title           : Optional[str] = None,
    title_kwargs    : Optional[Dict[str, object]] = None,
    tight_layout    : bool = True,
    elev            : Optional[float] = None,
    azim            : Optional[float] = None,
    **scatter_kwargs,
) -> Tuple[Figure, Axes]:
    r"""
    Scatter-plot the real-space site positions of a lattice.

    The plot dimensionality is ``lattice.dim`` (or the number of coordinate
    columns when ``lattice.dim`` is falsy), clamped to [1, 3]; extra columns
    of ``lattice.rvectors`` are ignored. 1D sites are drawn on the line y = 0.

    Parameters
    ----------
    lattice : Lattice
        Lattice whose ``rvectors`` are plotted.
    ax : Axes, optional
        Axes to draw on; a new figure/axes is created when omitted.
    show_indices : bool, default=False
        Annotate each site with its index.
    show_axes : bool, default=True
        Show axis labels; False hides ticks and spines.
    color : str, default="C0"
        Marker colour.
    marker : str, default="o"
        Marker style.
    figsize : tuple, optional
        Figure size in inches (width, height); only applied when the target
        axes is the first axes of its figure.
    fix_aspect : bool, default=True
        Keep equal x/y scaling in 2D. Set to ``False`` to let the requested
        ``figsize`` control the on-screen aspect.
    title : str, optional
        Axes title.
    title_kwargs : dict, optional
        Extra ``set_title`` options, merged over the default ``pad=12``.
    tight_layout : bool, default=True
        Apply the module's best-effort layout adjustment at the end.
    elev, azim : float, optional
        3D view angles; an angle left as None keeps the current value.
    **scatter_kwargs
        Forwarded to ``ax.scatter``.

    Returns
    -------
    fig, ax : Tuple[Figure, Axes]
        The figure and axes that were drawn on.
    """
    positions = _ensure_numpy(lattice.rvectors)
    n_columns = positions.shape[1]
    requested = lattice.dim or n_columns
    dim       = max(1, min(n_columns, requested, 3))
    positions = positions[:, :dim]

    fig, axis = _init_axes(ax, dim)
    if figsize is not None and axis is fig.axes[0]:
        fig.set_size_inches(*figsize, forward=True)

    if dim == 3 and not (elev is None and azim is None):
        axis.view_init(
            elev=getattr(axis, "elev", None) if elev is None else elev,
            azim=getattr(axis, "azim", None) if azim is None else azim,
        )

    # One column per plotted axis; 1D sites get an explicit y = 0 column.
    columns = [positions[:, k] for k in range(dim)]
    if dim == 1:
        columns.append(np.zeros_like(columns[0]))
    axis.scatter(*columns, color=color, marker=marker, **scatter_kwargs)

    _apply_spatial_limits(axis, positions, dim, show_axes, fix_aspect=fix_aspect)

    if title:
        title_opts: Dict[str, object] = {"pad": 12}
        title_opts.update(title_kwargs or {})
        axis.set_title(title, **title_opts)

    if show_indices:
        _annotate_indices(axis, positions)

    if tight_layout:
        _finalise_figure(fig)

    return fig, axis
# ------------------------------------------------------------------------------- #! Reciprocal Space Plotter # -------------------------------------------------------------------------------
def plot_reciprocal_space(
    lattice         : Lattice,
    *,
    ax              : Optional[Axes] = None,
    show_indices    : bool = False,
    show_axes       : bool = True,
    color           : str = "C1",
    marker          : str = "o",
    figsize         : Optional[Tuple[float, float]] = None,
    fix_aspect      : bool = True,
    title           : Optional[str] = None,
    title_kwargs    : Optional[Dict[str, object]] = None,
    tight_layout    : bool = True,
    elev            : Optional[float] = None,
    azim            : Optional[float] = None,
    # extension
    extend_kpoints  : bool = False,
    extend_copies   : Union[int, Iterable[int]] = 2,
    extend_tol      : float = 1e-10,
    **scatter_kwargs,
) -> Tuple[Figure, Axes]:
    """
    Scatter-plot of reciprocal lattice vectors (k-points).

    Parameters mirror :func:`plot_real_space`.

    Parameters
    ----------
    lattice : Lattice
        The lattice object whose ``kvectors`` are plotted.
    ax : Axes, optional
        Matplotlib axes to plot on. If None, a new figure is created.
    show_indices : bool, default=False
        If True, annotate each (original) k-point with its index.
    show_axes : bool, default=True
        If False, hides the coordinate axes.
    color : str, default="C1"
        Color of the k-point markers (falls back to "C1" if falsy).
    marker : str, default="o"
        Marker style.
    figsize : tuple, optional
        Figure size in inches (width, height); only applied when the target
        axes is the first axes of its figure.
    fix_aspect : bool, default=True
        If True, preserve equal axis scaling in 2D plots. Set to ``False`` to
        let the requested ``figsize`` control the on-screen aspect.
    title : str, optional
        Title of the plot.
    title_kwargs : dict, optional
        Extra ``set_title`` options, merged over the default ``pad=12``.
    tight_layout : bool, default=True
        Apply the module's best-effort layout adjustment at the end.
    elev, azim : float, optional
        Elevation and azimuth angles for 3D plots; an angle left as None keeps
        the current value.
    extend_kpoints : bool, default=False
        If True, draw translated reciprocal-space copies around the original
        mesh (via ``lattice.wigner_seitz_extend``).
    extend_copies : int or iterable of int, default=2
        Number of copies per reciprocal direction used when
        ``extend_kpoints=True``. Scalars are applied to all plotted
        reciprocal directions.
    extend_tol : float, default=1e-10
        Tolerance used to identify which extended points are already present
        in the original reciprocal mesh.
    **scatter_kwargs
        The following styling keys are consumed here and never forwarded:

        - ``point_edgecolor`` (alias ``edgecolor``): marker edge color of the
          original k-points (default "white"). The ``point_`` name wins if
          both are given.
        - ``point_zorder`` (alias ``zorder``): z-order of the original
          k-points (default 5); translated copies use ``point_zorder - 1``.
        - ``color_extended``: color for translated copies (default "C2").
        - ``edgecolor_extended``: edge color for translated copies
          (default "gray").
        - ``marker_extended``: marker for translated copies
          (default ``marker``).

        Any other key is passed unchanged to every ``ax.scatter`` call.

    Returns
    -------
    fig, ax : Tuple[Figure, Axes]
        The figure and axes objects.

    Raises
    ------
    ValueError
        If ``extend_copies`` is an iterable with fewer entries than plotted
        reciprocal directions.
    """
    coords      = _ensure_numpy(lattice.kvectors)
    target_dim  = lattice.dim if lattice.dim else coords.shape[1]
    dim         = max(1, min(coords.shape[1], target_dim, 3))
    coords      = coords[:, :dim]

    fig, axis = _init_axes(ax, dim)
    if figsize is not None and axis is fig.axes[0]:
        fig.set_size_inches(*figsize, forward=True)

    # Set 3D view angles if specified; an unspecified angle keeps the current view.
    if dim == 3 and (elev is not None or azim is not None):
        current_elev = elev if elev is not None else getattr(axis, "elev", None)
        current_azim = azim if azim is not None else getattr(axis, "azim", None)
        axis.view_init(elev=current_elev, azim=current_azim)

    # Consume every custom styling key so none of them reaches axis.scatter
    # (matplotlib rejects unknown properties such as ``point_edgecolor``).
    edgecolor_alias     = scatter_kwargs.pop("edgecolor", "white")
    point_edgecolor     = scatter_kwargs.pop("point_edgecolor", edgecolor_alias)
    zorder_alias        = scatter_kwargs.pop("zorder", 5)
    point_zorder        = scatter_kwargs.pop("point_zorder", zorder_alias)
    # ``color`` is a named parameter, so it can never appear in scatter_kwargs.
    point_color         = color if color else "C1"
    color_extended      = scatter_kwargs.pop("color_extended", "C2")
    edgecolor_extended  = scatter_kwargs.pop("edgecolor_extended", "gray")
    marker_extended     = scatter_kwargs.pop("marker_extended", marker)

    def _scatter(points: np.ndarray, **style) -> None:
        """Scatter ``points`` in ``dim`` dimensions; 1D points are drawn on y = 0."""
        columns = [points[:, k] for k in range(dim)]
        if dim == 1:
            columns.append(np.zeros_like(columns[0]))
        axis.scatter(*columns, **style, **scatter_kwargs)

    _scatter(coords, color=point_color, marker=marker,
             edgecolor=point_edgecolor, zorder=point_zorder)
    plotted_coords = coords

    # Set the title if necessary
    if title:
        kw = {"pad": 12}
        if title_kwargs:
            kw.update(title_kwargs)
        axis.set_title(title, **kw)

    if show_indices:
        _annotate_indices(axis, coords)

    if extend_kpoints:
        # ``dim`` is already clamped to [1, 3].
        if np.isscalar(extend_copies):
            copy_spec = int(extend_copies)
        else:
            copy_values = [int(copy) for copy in extend_copies]
            if len(copy_values) < dim:
                raise ValueError("extend_copies must provide at least one value per plotted reciprocal direction")
            copy_spec = tuple(copy_values[:dim])

        extended_k_points, _ = lattice.wigner_seitz_extend(k_points=coords, copies=copy_spec)
        # Accept list-like results and keep only the plotted components so the
        # rows are comparable with ``coords`` and can be stacked onto it.
        extended_k_points = np.asarray(extended_k_points, dtype=float)
        if extended_k_points.ndim == 2:
            extended_k_points = extended_k_points[:, :dim]

        if extended_k_points.size:
            # Compare by rounded integer row keys: O(N + M) set lookups instead
            # of a quadratic allclose scan.
            scale           = max(float(extend_tol), np.finfo(float).eps)
            original_keys   = {tuple(key) for key in np.rint(coords / scale).astype(np.int64)}
            extended_keys   = np.rint(extended_k_points / scale).astype(np.int64)
            is_new          = np.array([tuple(key) not in original_keys for key in extended_keys], dtype=bool)

            # Plot the translated copies in a different style, just below the originals.
            if np.any(is_new):
                extended_coords = extended_k_points[is_new]
                plotted_coords  = np.vstack((coords, extended_coords))
                _scatter(extended_coords, color=color_extended, marker=marker_extended,
                         edgecolor=edgecolor_extended, zorder=point_zorder - 1)

    # Spatial Limits & Visibility (limits cover any extended copies too)
    k_labels = (r"$k_x$", r"$k_y$", r"$k_z$")
    _apply_spatial_limits(axis, plotted_coords, dim, show_axes, labels=k_labels, fix_aspect=fix_aspect)

    if tight_layout:
        _finalise_figure(fig)
    return fig, axis
# -------------------------------------------------------------------------------
#! Brillouin Zone Plotter
# -------------------------------------------------------------------------------

def _draw_bz_region(axis: Axes, points: np.ndarray, *,
        dim         : int,
        lattice     : Optional[Lattice] = None,
        offset      : Optional[np.ndarray] = None,
        shells      : int = 2,
        facecolor   : str,
        edgecolor   : str,
        alpha       : float,
        linewidth   : float = 1.5,
        fix_aspect  : bool = True,
        show_points : bool = True,
        point_kwargs: Optional[Dict[str, Any]] = None,
        zorder      : float = 0.0,
        **kwargs,
    ) -> None:
    """
    Draw a Brillouin-zone region from sampled boundary points.

    The region is drawn as follows, depending on ``dim``:

    - 1D: a horizontal band between the smallest and largest coordinate.
    - 2D with ``lattice``: the lattice's Wigner-Seitz mask, evaluated on a
      500x500 grid.
    - 2D without ``lattice``: the convex hull of the points, or their bounding
      box if no hull is available.
    - 3D: the convex hull of the points (bounding box as fallback), drawn as a
      Poly3DCollection.

    Parameters
    ----------
    axis : Axes
        The matplotlib axes to draw on (3D axes for ``dim == 3``).
    points : array-like
        Shape (N, D) boundary/sample coordinates. For ``dim == 1`` a flat
        array of N scalar values is also accepted. Empty input draws nothing.
    dim : int
        Dimensionality of the region (1, 2, or 3).
    lattice : Lattice, optional
        2D only: if given, ``lattice.wigner_seitz_mask`` (built from
        ``lattice.k1``/``lattice.k2``) defines the region instead of the hull.
    offset : array-like, optional
        2D mask only: centre of the evaluated grid.
    shells : int, default=2
        2D mask only: forwarded to ``lattice.wigner_seitz_mask``.
    facecolor, edgecolor : str
        Fill and outline colors.
    alpha : float
        Fill transparency (0.0 transparent, 1.0 opaque).
    linewidth : float, default=1.5
        Outline width.
    fix_aspect : bool, default=True
        2D only: keep equal x/y scaling.
    show_points : bool, default=True
        Overlay the sampled points (1D: only the two end points).
    point_kwargs : dict, optional
        Extra ``scatter`` options for the overlaid points.
    zorder : float, default=0.0
        z-order of the fill; outline uses ``zorder + 0.1``, points ``+ 0.2``.
    **kwargs
        Forwarded to ``axis.plot`` for the 1D outline only; ignored otherwise.

    Raises
    ------
    RuntimeError
        For ``dim == 3`` when ``mpl_toolkits.mplot3d`` is unavailable.
    """
    pts = np.asarray(points, dtype=float)
    if pts.ndim == 1 and dim == 1:
        # Flat array of scalar k-values -> one coordinate per row.
        pts = pts.reshape(-1, 1)
    if pts.size == 0:
        # Nothing to outline (min/max/hull of an empty set are undefined).
        return
    point_kwargs = {} if point_kwargs is None else dict(point_kwargs)

    if dim == 1:
        x_min, x_max = pts[:, 0].min(), pts[:, 0].max()
        # ymin/ymax are axes fractions; with ylim (0, 1) the band is centred at y = 0.5.
        axis.axvspan(x_min, x_max, ymin=0.25, ymax=0.75, facecolor=facecolor, alpha=alpha, zorder=zorder)
        axis.plot([x_min, x_max], [0.5, 0.5], color=edgecolor, linewidth=linewidth, zorder=zorder + 0.1, **kwargs)
        if show_points:
            axis.scatter([x_min, x_max], [0.5, 0.5], color=edgecolor, edgecolor="white", zorder=zorder + 0.2, **point_kwargs)
        axis.set_ylim(0, 1)
        axis.set_yticks([])
        axis.set_xlabel("k")
        return

    if dim == 2:
        if lattice is not None:
            # Use the lattice's Wigner-Seitz mask to draw the Brillouin zone region.
            b1 = np.asarray(lattice.k1, dtype=float).ravel()[:2]
            b2 = np.asarray(lattice.k2, dtype=float).ravel()[:2]
            # Grid half-width: 1.35x the longest reciprocal vector, so the whole
            # Wigner-Seitz cell fits inside the evaluated window.
            pad = 1.35
            kmax = max(np.linalg.norm(b1), np.linalg.norm(b2)) * pad
            gx = np.linspace(-kmax, kmax, 500)
            gy = np.linspace(-kmax, kmax, 500)
            if offset is not None:
                shift = np.asarray(offset, dtype=float).ravel()[:2]
                gx += shift[0]
                gy += shift[1]
            else:
                shift = np.zeros(2, dtype=float)
            GX, GY = np.meshgrid(gx, gy)
            # The mask is evaluated relative to the offset, i.e. the cell is
            # drawn centred on ``offset``.
            mask = lattice.wigner_seitz_mask(GX - shift[0], GY - shift[1], shells=shells)
            axis.contourf(GX, GY, mask.astype(float), levels=[0.5, 1.5], colors=[facecolor], alpha=alpha, zorder=zorder)
            # Outline slightly above the fill.
            axis.contour(GX, GY, mask.astype(float), levels=[0.5], colors=[edgecolor], linewidths=linewidth, zorder=zorder + 0.1)
        else:
            # Fallback: convex hull of the samples (may not match the true BZ
            # shape when sampling is sparse), else their bounding box.
            polygon = None
            if ConvexHull is not None:
                try:
                    hull = ConvexHull(pts[:, :2])
                    polygon = pts[hull.vertices, :2]
                except Exception:
                    # Deliberate best-effort: degenerate/too few points -> bounding box.
                    pass
            if polygon is None:
                x_min, y_min = pts[:, :2].min(axis=0)
                x_max, y_max = pts[:, :2].max(axis=0)
                polygon = np.array([
                    [x_min, y_min],
                    [x_max, y_min],
                    [x_max, y_max],
                    [x_min, y_max],
                ])
            axis.fill(*polygon.T, facecolor=facecolor, alpha=alpha, edgecolor=edgecolor, linewidth=linewidth, zorder=zorder)
            axis.plot(*polygon.T, color=edgecolor, linewidth=linewidth, zorder=zorder + 0.1)
            # Close the outline explicitly (last vertex back to the first).
            axis.plot([polygon[-1, 0], polygon[0, 0]], [polygon[-1, 1], polygon[0, 1]], color=edgecolor, linewidth=linewidth, zorder=zorder + 0.1)
        if show_points:
            axis.scatter(pts[:, 0], pts[:, 1], color=edgecolor, edgecolor="white", zorder=zorder + 0.2, **point_kwargs)
        _apply_planar_aspect(axis, fix_aspect=fix_aspect)
        axis.set_xlabel(r"$k_x$")
        axis.set_ylabel(r"$k_y$")
        return

    # ---------------
    # 3D: convex-hull polyhedron of the points (bounding box as fallback).
    # ---------------
    if Poly3DCollection is None:
        raise RuntimeError("3D plotting support requires mpl_toolkits.mplot3d.")
    faces = None
    if ConvexHull is not None:
        try:
            hull = ConvexHull(pts[:, :3])
            faces = [pts[simplex, :3] for simplex in hull.simplices]
        except Exception:
            # Deliberate best-effort: degenerate/too few points -> bounding box.
            pass
    if faces is None:
        mins = pts[:, :3].min(axis=0)
        maxs = pts[:, :3].max(axis=0)
        corners = np.array([
            [mins[0], mins[1], mins[2]],
            [maxs[0], mins[1], mins[2]],
            [maxs[0], maxs[1], mins[2]],
            [mins[0], maxs[1], mins[2]],
            [mins[0], mins[1], maxs[2]],
            [maxs[0], mins[1], maxs[2]],
            [maxs[0], maxs[1], maxs[2]],
            [mins[0], maxs[1], maxs[2]],
        ])
        faces = [
            corners[[0, 1, 2, 3]],
            corners[[4, 5, 6, 7]],
            corners[[0, 1, 5, 4]],
            corners[[2, 3, 7, 6]],
            corners[[1, 2, 6, 5]],
            corners[[3, 0, 4, 7]],
        ]
    collection = Poly3DCollection(
        faces, facecolor=facecolor, edgecolor=edgecolor, alpha=alpha, linewidths=linewidth, zorder=zorder
    )
    axis.add_collection3d(collection)
    if show_points:
        axis.scatter(pts[:, 0], pts[:, 1], pts[:, 2], color=edgecolor, edgecolor="white", zorder=zorder + 0.2, **point_kwargs)
    axis.set_xlabel(r"$k_x$")
    axis.set_ylabel(r"$k_y$")
    axis.set_zlabel(r"$k_z$")


def _plot_1d_bz(axis: Axes, bounds: Tuple[float, float], *, facecolor: str, alpha: float, edgecolor: str = "black", **kwargs) -> None:
    '''
    Plot a 1D Brillouin Zone as a horizontal band between ``bounds``.

    ``bounds`` are two scalar k-values; they are passed on as two 1-component
    points so the band spans from the smaller to the larger one. ``edgecolor``
    defaults to "black" because ``_draw_bz_region`` requires it.
    '''
    points = np.asarray(bounds, dtype=float).reshape(-1, 1)
    _draw_bz_region(axis, points=points, dim=1, facecolor=facecolor, edgecolor=edgecolor, alpha=alpha, **kwargs)


def _plot_2d_bz(axis: Axes, points: np.ndarray, *, facecolor: str, edgecolor: str, alpha: float, **kwargs) -> None:
    '''
    Plot a 2D Brillouin Zone as a filled polygon (see ``_draw_bz_region``).
    '''
    _draw_bz_region(axis, points, dim=2, facecolor=facecolor, edgecolor=edgecolor, alpha=alpha, **kwargs)


def _plot_3d_bz(axis: Axes, points: np.ndarray, *, facecolor: str, edgecolor: str, alpha: float,
                elev: Optional[float] = 30.0, azim: Optional[float] = 45.0, **kwargs) -> None:
    '''
    Plot a 3D Brillouin Zone as a convex hull polyhedron.

    ``elev``/``azim`` set the view; an angle passed as None keeps the axes'
    current value, and if both are None the view is left untouched
    (consistent with the other plotting functions).
    '''
    _draw_bz_region(axis, points, dim=3, facecolor=facecolor, edgecolor=edgecolor, alpha=alpha, **kwargs)
    if elev is None and azim is None:
        return
    axis.view_init(
        elev=getattr(axis, "elev", None) if elev is None else elev,
        azim=getattr(axis, "azim", None) if azim is None else azim,
    )
def plot_brillouin_zone(
    lattice      : Lattice,
    *,
    ax           : Optional[Axes] = None,
    facecolor    : str = "tab:blue",
    edgecolor    : str = "black",
    alpha        : float = 0.25,
    figsize      : Optional[Tuple[float, float]] = None,
    fix_aspect   : bool = True,
    title        : Optional[str] = None,
    title_kwargs : Optional[Dict[str, object]] = None,
    tight_layout : bool = True,
    elev         : Optional[float] = None,
    azim         : Optional[float] = None,
) -> Tuple[Figure, Axes]:
    """
    Plot the Brillouin Zone approximation based on sampled k-vectors.

    1D draws a band spanning the sampled k-values; 2D/3D draw the convex hull
    of the k-points (bounding box as fallback).

    Parameters
    ----------
    lattice : Lattice
        The lattice object containing ``kvectors``.
    ax : Axes, optional
        Matplotlib axes to plot on. If None, a new figure is created.
    facecolor : str, default="tab:blue"
        Color to fill the Brillouin Zone area.
    edgecolor : str, default="black"
        Color for the Brillouin Zone boundary (all dimensions).
    alpha : float, default=0.25
        Transparency level for the Brillouin Zone fill.
    figsize : tuple, optional
        Figure size in inches (width, height); only applied when the target
        axes is the first axes of its figure.
    fix_aspect : bool, default=True
        If True, preserve equal axis scaling in 2D plots. Set to ``False`` to
        let the requested ``figsize`` control the on-screen aspect.
    title : str, optional
        Title of the plot.
    title_kwargs : dict, optional
        Extra ``set_title`` options, merged over the default ``pad=12``.
    tight_layout : bool, default=True
        Apply the module's best-effort layout adjustment at the end.
    elev, azim : float, optional
        Elevation and azimuth angles for 3D plots.

    Returns
    -------
    fig, ax : Tuple[Figure, Axes]
        The figure and axes objects.
    """
    coords      = _ensure_numpy(lattice.kvectors)
    target_dim  = lattice.dim if lattice.dim else coords.shape[1]
    dim         = max(1, min(coords.shape[1], target_dim, 3))
    coords      = coords[:, :dim]

    fig, axis = _init_axes(ax, dim)
    if figsize is not None and axis is fig.axes[0]:
        fig.set_size_inches(*figsize, forward=True)

    if dim == 1:
        # Draw directly from the sampled k-values (an (N, 1) array): the band
        # spans their min/max. ``edgecolor`` must be passed explicitly because
        # ``_draw_bz_region`` requires it.
        _draw_bz_region(axis, coords[:, :1], dim=1, facecolor=facecolor, edgecolor=edgecolor, alpha=alpha)
    elif dim == 2:
        _plot_2d_bz(axis, coords[:, :2], facecolor=facecolor, edgecolor=edgecolor, alpha=alpha, fix_aspect=fix_aspect)
    else:
        _plot_3d_bz(axis, coords[:, :3], facecolor=facecolor, edgecolor=edgecolor, alpha=alpha, elev=elev, azim=azim)

    if title:
        kw = {"pad": 12}
        if title_kwargs:
            kw.update(title_kwargs)
        axis.set_title(title, **kw)

    if tight_layout:
        _finalise_figure(fig)
    return fig, axis
# ============================================================================== #! Structural Plotting Helpers # ============================================================================== def _gather_nn_edges(lattice: Lattice) -> List[Tuple[int, int]]: """ Extract nearest-neighbor edges from lattice. """ edges = set() for i in range(lattice.Ns): neighbors = lattice.get_nn(i) if not neighbors: continue for j in neighbors: if lattice.wrong_nei(j): continue # Canonical edge (min, max) to avoid duplicates a, b = sorted((int(i), int(j))) if a != b: edges.add((a, b)) return sorted(edges) def _infer_bipartite_coloring(adjacency: List[List[int]]) -> Optional[List[int]]: """ Try to 2-color the graph. Returns list of 0/1 colors or None if not bipartite. """ ns = len(adjacency) colors = [-1] * ns for start in range(ns): if colors[start] != -1 or not adjacency[start]: continue colors[start] = 0 queue = [start] while queue: node = queue.pop(0) for neigh in adjacency[node]: if neigh < 0: continue if colors[neigh] == -1: colors[neigh] = colors[node] ^ 1 queue.append(neigh) elif colors[neigh] == colors[node]: return None # Not bipartite # Fill any disconnected single nodes for idx, neighbours in enumerate(adjacency): if colors[idx] == -1: colors[idx] = 0 return colors def _boundary_masks(positions: np.ndarray, lattice: Lattice, *, tol_factor: float = 1e-6) -> Tuple[np.ndarray, np.ndarray]: """ Identify sites on the spatial boundaries of the lattice. 
""" if lattice.dim == 0 or positions.size == 0: return np.zeros(positions.shape[0], dtype=bool), np.ones(positions.shape[0], dtype=bool) mins = positions.min(axis=0) maxs = positions.max(axis=0) span = np.maximum(maxs - mins, tol_factor) tol = span * tol_factor boundary_mask = np.zeros(positions.shape[0], dtype=bool) # Check boundaries in each dimension for axis in range(min(positions.shape[1], 3)): boundary_axis = (np.isclose(positions[:, axis], mins[axis], atol=tol[axis]) | np.isclose(positions[:, axis], maxs[axis], atol=tol[axis])) boundary_mask |= boundary_axis interior_mask = ~boundary_mask return boundary_mask, interior_mask def _draw_primitive_cell(axis: Axes, origin: np.ndarray, basis_vectors: List[np.ndarray], dim: int, **kwargs) -> None: """ Draw the primitive unit cell vectors from an origin. """ if not basis_vectors: return color = kwargs.get("color", "0.4") linestyle = kwargs.get("linestyle", ":") linewidth = kwargs.get("linewidth", 1.0) if dim == 1 and len(basis_vectors) >= 1: points = np.vstack([origin, origin + basis_vectors[0]]) axis.plot(points[:, 0], np.zeros_like(points[:, 0]), color=color, linestyle=linestyle, linewidth=linewidth) elif dim == 2 and len(basis_vectors) >= 2: a1, a2 = basis_vectors[:2] corners = np.array([origin, origin + a1, origin + a1 + a2, origin + a2, origin]) axis.plot(corners[:, 0], corners[:, 1], color=color, linestyle=linestyle, linewidth=linewidth) elif dim == 3 and len(basis_vectors) >= 3: raise NotImplementedError("3D primitive cell plotting is not implemented yet.") def _draw_boundary_annotations( axis : Axes, positions : np.ndarray, lattice : Lattice, *, periodic_color : str, open_color : str, offset_fraction : float, ) -> None: """ Draw annotations indicating OBC/PBC on the plot axes. 
""" if positions.shape[1] < 2: return mins = positions.min(axis=0) maxs = positions.max(axis=0) mid = (mins + maxs) / 2.0 diag = np.linalg.norm(maxs[:2] - mins[:2]) padding = offset_fraction * (diag if diag > 0 else 1.0) flags = lattice.periodic_flags() labels = ("x", "y", "z") def _annotate_axis(axis_index: int, label: str) -> None: is_periodic = bool(flags[axis_index]) color = periodic_color if is_periodic else open_color # Helper for common styles style_kw = dict(color=color, lw=1.2, linestyle="--", alpha=0.8) arrow_kw = dict(arrowstyle="->", color=color, lw=1.5, linestyle="--") text_kw = dict(color=color, bbox=dict(facecolor="white", edgecolor="none", alpha=0.7, pad=1)) if axis_index == 0: # X-direction boundaries y = mid[1] if positions.shape[1] > 1 else 0.0 if is_periodic: axis.annotate(f"PBC {label}", xy=(maxs[0], y), xytext=(maxs[0] + padding, y), ha="left", va="center", arrowprops=arrow_kw, **text_kw) axis.annotate("", xy=(mins[0], y), xytext=(mins[0] - padding, y), arrowprops=arrow_kw) else: # Draw lines indicating open boundaries axis.plot([mins[0], mins[0]], [mins[1], maxs[1]], **style_kw) axis.plot([maxs[0], maxs[0]], [mins[1], maxs[1]], **style_kw) axis.text(mid[0], maxs[1] + padding, f"Open {label}", ha="center", va="bottom", **text_kw) elif axis_index == 1: # Y-direction boundaries x = mid[0] if is_periodic: axis.annotate(f"PBC {label}", xy=(x, maxs[1]), xytext=(x, maxs[1] + padding), ha="center", va="bottom", arrowprops=arrow_kw, **text_kw) axis.annotate("", xy=(x, mins[1]), xytext=(x, mins[1] - padding), arrowprops=arrow_kw) else: axis.plot([mins[0], maxs[0]], [mins[1], mins[1]], **style_kw) axis.plot([mins[0], maxs[0]], [maxs[1], maxs[1]], **style_kw) axis.text(mins[0] - padding, mid[1], f"Open {label}", ha="right", va="center", rotation=90, **text_kw) for idx in range(min(positions.shape[1], 2)): _annotate_axis(idx, labels[idx]) def plot_lattice_structure( lattice : Lattice, *, ax : Optional[Axes] = None, show_indices : bool = False, 
highlight_boundary : bool = True, # related to boundary highlighting show_axes : bool = False, edge_color : str = "0.5", node_color : str = "tab:blue", boundary_node_color : str = "tab:red", periodic_color : str = "tab:orange", open_color : str = "tab:green", bond_colors : dict = { 0 : "tab:red", 1 : "tab:blue", 2: "tab:green" }, # styling node_size : int = 30, edge_alpha : float = 0.7, label_padding : float = 0.05, label_fontsize : int = 8, boundary_offset : float = 0.05, # general plot settings figsize : Optional[Tuple[float, float]] = None, fix_aspect : bool = True, title : Optional[str] = None, title_kwargs : Optional[Dict[str, object]] = None, tight_layout : bool = True, elev : Optional[float] = None, azim : Optional[float] = None, partition_colors : Optional[Tuple[str, ...]] = None, show_periodic_connections : bool = True, show_primitive_cell : bool = True, **scatter_kwargs, ) -> Tuple[Figure, Axes]: r""" Visualise lattice geometry with connectivity, boundary cues, and sublattices. This function draws nodes and edges based on nearest-neighbor connectivity. It highlights boundaries, annotates PBCs, and can color nodes by bipartite partitioning if applicable. Parameters ---------- lattice : Lattice The lattice model. show_indices : bool If True, annotates nodes with their site indices. highlight_boundary : bool If True, draws boundary nodes with a distinct color/edge. show_axes : bool If False, hides the coordinate axes for a cleaner diagram. partition_colors : tuple of str, optional Colors to use for bipartite/sublattice coloring. If provided, nodes are colored based on sublattice parity. show_periodic_connections : bool If True, indicates wrap-around connections textually or graphically. show_primitive_cell : bool If True, overlays the primitive unit cell vectors/box. fix_aspect : bool, default=True If True, preserve equal axis scaling in 2D plots. Set to ``False`` to let the requested ``figsize`` control the on-screen aspect. ... 
other parameters mirror plot_real_space ... """ coords = _ensure_numpy(lattice.rvectors) target_dim = lattice.dim if lattice.dim else coords.shape[1] dim = max(1, min(coords.shape[1], target_dim, 3)) coords = coords[:, :dim] fig, axis = _init_axes(ax, dim) if figsize is not None and axis is fig.axes[0]: fig.set_size_inches(*figsize, forward=True) if dim == 3 and (elev is not None or azim is not None): current_elev = elev if elev is not None else getattr(axis, "elev", None) current_azim = azim if azim is not None else getattr(axis, "azim", None) axis.view_init(elev=current_elev, azim=current_azim) # Compute Connectivity edges = _gather_nn_edges(lattice) adjacency = [[] for _ in range(lattice.Ns)] for i, j in edges: adjacency[i].append(j) adjacency[j].append(i) # Periodic Edges Detection periodic_neighbors = defaultdict(list) periodic_label_counts = defaultdict(int) typical_distance = None if edges: # Heuristic: distances significantly larger than min distance are likely PBC wraps all_dists = [np.linalg.norm(coords[j] - coords[i]) for i, j in edges] valid_dists = [d for d in all_dists if d > 1e-9] if valid_dists: typical_distance = min(valid_dists) # Draw edges for i, j in edges: start = coords[i] end = coords[j] dist = np.linalg.norm(end - start) bond_type = lattice.bond_type(i, j) # Check if this edge wraps around the boundary is_periodic = False if typical_distance is not None: # 1.5x factor is a safe heuristic for regular lattices if dist > typical_distance * 1.5: is_periodic = True # Draw logic linestyle = "--" if is_periodic else "-" if is_periodic: periodic_neighbors[i].append(j) periodic_neighbors[j].append(i) # Plot the line line_args = dict(color = bond_colors.get(bond_type, edge_color), alpha = edge_alpha, linestyle = linestyle, linewidth = 1.0, zorder = 2) if dim == 3: axis.plot([start[0], end[0]], [start[1], end[1]], [start[2], end[2]], **line_args) elif dim == 2: axis.plot([start[0], end[0]], [start[1], end[1]], **line_args) else: # 1D 
axis.plot([start[0], end[0]], [0.0, 0.0], **line_args) # Node Coloring node_face_colors = [node_color] * lattice.Ns if partition_colors: partitions = _infer_bipartite_coloring(adjacency) if partitions is not None: palette = partition_colors node_face_colors = [palette[partitions[i] % len(palette)] for i in range(lattice.Ns)] # Draw Nodes scatter_defaults = dict(s=node_size, zorder=3, **scatter_kwargs) if dim == 1: axis.scatter(coords[:, 0], np.zeros_like(coords[:, 0]), c=node_face_colors, **scatter_defaults) axis.set_ylim(-0.5, 0.5) elif dim == 2: axis.scatter(coords[:, 0], coords[:, 1], c=node_face_colors, **scatter_defaults) _apply_planar_aspect(axis, fix_aspect=fix_aspect) else: axis.scatter(coords[:, 0], coords[:, 1], coords[:, 2], c=node_face_colors, **scatter_defaults) # Boundary Highlight boundary_mask, _ = _boundary_masks(coords, lattice) if highlight_boundary and np.any(boundary_mask): b_coords = coords[boundary_mask] b_args = dict(facecolors="none", edgecolors=boundary_node_color, s=node_size*1.2, linewidths=1.2, zorder=4) if dim == 3: axis.scatter(b_coords[:, 0], b_coords[:, 1], b_coords[:, 2], **b_args) elif dim == 2: axis.scatter(b_coords[:, 0], b_coords[:, 1], **b_args) else: axis.scatter(b_coords[:, 0], np.zeros_like(b_coords[:, 0]), **b_args) # Spatial Limits & Visibility # Use a larger margin if we have boundary annotations or periodic labels margin = 0.08 if dim == 2: margin = max(margin, boundary_offset * 1.5) _apply_spatial_limits(axis, coords, dim, show_axes, margin=margin, fix_aspect=fix_aspect) if title: kw = {"pad": 15} if title_kwargs: kw.update(title_kwargs) axis.set_title(title, **kw) # Indices & Annotations node_label_positions = {} if show_indices: mins = coords.min(axis=0) maxs = coords.max(axis=0) diag = np.linalg.norm(maxs - mins) if coords.size else 1.0 offset = label_padding * (diag if diag > 0 else 1.0) for idx, point in enumerate(coords): label_pos = point.copy() if dim >= 1: label_pos[0] += offset if dim >= 2: label_pos[1] += 
offset if dim >= 3: label_pos[2] += offset txt_args = dict(ha="center", va="center", color="black", fontsize=label_fontsize, bbox=dict(facecolor="white", edgecolor="none", alpha=0.7, pad=1), zorder=6) if dim == 3: axis.text(label_pos[0], label_pos[1], label_pos[2], str(idx), **txt_args) elif dim == 2: axis.text(label_pos[0], label_pos[1], str(idx), **txt_args) else: axis.text(label_pos[0], offset, str(idx), **txt_args) node_label_positions[idx] = label_pos # Boundary Annotations (2D only) if dim == 2: _draw_boundary_annotations(axis, coords, lattice, periodic_color=periodic_color, open_color=open_color, offset_fraction=boundary_offset) # Periodic Connections Text if show_periodic_connections and periodic_neighbors: diag_extent = np.linalg.norm(coords.max(axis=0) - coords.min(axis=0)) or 1.0 base_offset = label_padding * diag_extent * 0.6 for idx, neighbours in periodic_neighbors.items(): if not neighbours: continue anchor = node_label_positions.get(idx, coords[idx]) label = "↔ " + ",".join(str(n) for n in sorted(set(neighbours))) count = periodic_label_counts[idx] periodic_label_counts[idx] += 1 shift = base_offset * (count + 1) pos = anchor.copy() if dim == 1: pos[0] = pos[0] # Usually stack vertically in 1D? 
Or just use y-offset y_pos = shift elif dim == 2: pos[1] += shift # Shift up y txt_args = dict(color=periodic_color, fontsize=8, ha="center", va="bottom", bbox=dict(facecolor="white", edgecolor="none", alpha=0.7, pad=0.4), zorder=5) if dim == 3: pos[2] += shift axis.text(pos[0], pos[1], pos[2], label, **txt_args) elif dim == 2: axis.text(pos[0], pos[1], label, **txt_args) else: axis.text(pos[0], shift, label, **txt_args) # Primitive Cell if show_primitive_cell: # Try to find basis vectors basis_vectors = [] for attr in ("a1", "a2", "a3"): vec = getattr(lattice, attr, None) if vec is not None: vec = np.asarray(vec).flatten() if vec.size >= dim and np.linalg.norm(vec[:dim]) > 1e-9: basis_vectors.append(vec[:dim]) if basis_vectors: origin = coords.min(axis=0) _draw_primitive_cell(axis, origin, basis_vectors, dim) if tight_layout: _finalise_figure(fig, top_padding=0.92) return fig, axis # ============================================================================== #! Region Plotting # ============================================================================== def _region_palette(n: int) -> List: """ Return *n* high-contrast, distinguishable colours for region plots. Uses a hand-picked palette for small *n* (≤8) and falls back to the ``tab20`` colour-map for larger numbers. 
""" # High-contrast hand-picked palette (colour-blind friendly order) _BASE = [ "#1f77b4", # blue "#d62728", # red "#2ca02c", # green "#ff7f0e", # orange "#9467bd", # purple "#8c564b", # brown "#e377c2", # pink "#17becf", # cyan ] if n <= len(_BASE): return _BASE[:n] import matplotlib.cm as cm cmap = cm.get_cmap("tab20", max(n, 20)) return [cmap(i) for i in range(n)] def _normalize_site_indices(indices_like: Any, n_sites: int) -> List[int]: """Convert supported index containers to sorted unique valid site indices.""" if indices_like is None: return [] if isinstance(indices_like, np.ndarray): raw = indices_like.ravel().tolist() elif isinstance(indices_like, (list, tuple, set)): raw = list(indices_like) else: return [] out: Set[int] = set() for item in raw: if isinstance(item, (int, np.integer)): idx = int(item) if 0 <= idx < n_sites: out.add(idx) return sorted(out) def _extract_region_indices(region_spec: Any, n_sites: int, component: str = "A") -> List[int]: """ Normalize one region descriptor to a site-index list. Supports: - plain lists/arrays of indices, - dicts containing region components, - Region-like objects (with .get/.A/.to_dict()). 
""" direct = _normalize_site_indices(region_spec, n_sites) if direct: return direct comp = str(component).upper() if isinstance(region_spec, dict): if comp in region_spec: return _normalize_site_indices(region_spec.get(comp), n_sites) merged: List[int] = [] for value in region_spec.values(): v = _normalize_site_indices(value, n_sites) if v: merged.extend(v) return sorted(set(merged)) get_fn = getattr(region_spec, "get", None) if callable(get_fn): try: cand = get_fn(comp, None) except Exception: cand = None cand_norm = _normalize_site_indices(cand, n_sites) if cand_norm: return cand_norm if hasattr(region_spec, comp): cand_norm = _normalize_site_indices(getattr(region_spec, comp), n_sites) if cand_norm: return cand_norm to_dict_fn = getattr(region_spec, "to_dict", None) if callable(to_dict_fn): try: mapping = to_dict_fn() except Exception: mapping = None if isinstance(mapping, dict): if comp in mapping: return _normalize_site_indices(mapping.get(comp), n_sites) merged: List[int] = [] for value in mapping.values(): v = _normalize_site_indices(value, n_sites) if v: merged.extend(v) return sorted(set(merged)) return [] def plot_regions( lattice : Lattice, regions : Union[Dict[str, Any], Any], *, ax : Optional[Axes] = None, # showers show_indices : bool = False, show_system : bool = True, show_complement : bool = False, show_labels : bool = True, show_overlaps : bool = True, show_bonds : bool = False, show_legend : bool = True, # Other points origin : Optional[np.ndarray] = None, system_color : str = 'lightgray', system_alpha : float = 0.25, region_colors : Optional[Dict[str, str]] = None, region_alpha : float = 0.6, complement_color : str = 'lightgray', complement_alpha : float = 0.3, overlap_color : str = 'red', fill : bool = False, fill_alpha : float = 0.2, # region styling blob_radius : Optional[float] = None, blob_alpha : float = 0.12, marker_size : int = 60, edge_width : float = 1.5, # general plot settings figsize : Optional[Tuple[float, float]] = None, 
fix_aspect : bool = True, title : Optional[str] = None, title_kwargs : Optional[Dict[str, object]] = None, tight_layout : bool = True, elev : Optional[float] = None, azim : Optional[float] = None, # Region labels and legend region_descriptions : Optional[Dict[str, str]] = None, legend_loc : str = 'best', legend_fontsize : int = 9, legend_bbox : tuple = (1.05, 1), # Label styling label_fontsize : int = 11, label_offset : float = 1.2, # Indices and axes indices_padding : float = 0.05, show_axes : bool = False, region_component : str = "A", **scatter_kwargs, ) -> Tuple[Figure, Axes]: """ Plot labelled lattice regions with distinct colours and informative legend. Features -------- - High-contrast colours that are distinguishable for every region. - Labels placed *radially outward* from the plot centre so they never overlap even for Kitaev-Preskill-style pie-slice sectors. - Legend entries include the region name, site count, and optional human-readable description. - Optional convex-hull fill, translucent per-site blobs, and intra-region bond drawing. Parameters ---------- lattice : Lattice The lattice object. regions : Dict[str, Any] or Region-like Region mapping or Region-like objects. Region-like values are resolved using ``region_component`` (default ``"A"``). region_descriptions : dict[str, str], optional Optional human-readable description per region key that is appended to the legend entry (e.g. ``{'A': 'sector 0°-120°'}``). legend_loc : str Matplotlib legend location string (default ``'best'``). legend_fontsize : int Font size for legend entries (default 9). label_fontsize : int Font size for region labels drawn on the plot (default 11). label_offset : float Controls how far the label is pushed radially outward from the region centroid (default 1.2 x distance from plot centre to centroid). Values > 1 push the label outside the region. show_bonds : bool Draw NN bonds coloured by region. 
region_component : str Component to extract from Region-like entries (default ``"A"``). Ignored for plain index lists. blob_radius, blob_alpha : float Per-site circle patches (2D only). fill, fill_alpha : bool, float Convex-hull polygon fill (2D, requires scipy). (other parameters identical to previous version) show_axes : bool If False (default), hides the coordinate axes for a cleaner diagram. fix_aspect : bool, default=True If True, preserve equal axis scaling in 2D plots. Set to ``False`` to let the requested ``figsize`` control the on-screen aspect. """ coords = _ensure_numpy(lattice.rvectors) target_dim = lattice.dim if lattice.dim else coords.shape[1] dim = max(1, min(coords.shape[1], target_dim, 3)) coords = coords[:, :dim] fig, axis = _init_axes(ax, dim) if figsize is not None and axis is fig.axes[0]: fig.set_size_inches(*figsize, forward=True) # Set 3D view if requested and applicable if dim == 3 and (elev is not None or azim is not None): axis.view_init( elev=elev if elev is not None else getattr(axis, "elev", None), azim=azim if azim is not None else getattr(axis, "azim", None), ) # background: all system sites if show_system: _sc = dict(color=system_color, alpha=system_alpha, marker='o', s=30, zorder=0) if dim <= 2: y = coords[:, 1] if dim == 2 else np.zeros(len(coords)) axis.scatter(coords[:, 0], y, **_sc) else: axis.scatter(coords[:, 0], coords[:, 1], coords[:, 2], **_sc) # bonds if show_bonds and dim <= 2: for i in range(len(coords)): for j in lattice.get_nn(i): if lattice.wrong_nei(j): continue j = int(j) if i < j: # Avoid double counting ri, rj = coords[i, :2], coords[j, :2] axis.plot([ri[0], rj[0]], [ri[1], rj[1]], color=system_color, lw=0.8, alpha=system_alpha * 0.8, zorder=0) if origin is not None: _sc = dict(color='black', alpha=0.8, marker='X', s=100, zorder=5) if dim <= 2: y = origin[1] if dim == 2 else 0.0 axis.scatter(origin[0], y, **_sc) else: axis.scatter(origin[0], origin[1], origin[2], **_sc) # Normalize inputs: convert 
Region/dict/list descriptors to list[int] if not isinstance(regions, dict): regions = {"region": regions} normalized_regions : Dict[str, List[int]] = {} n_sites = len(coords) for name, spec in regions.items(): normalized_regions[str(name)] = _extract_region_indices(spec, n_sites=n_sites, component=region_component) regions = normalized_regions # site membership bookkeeping all_region_sites = set() site_counts: Dict[int, int] = {} for indices in regions.values(): for idx in indices: all_region_sites.add(idx) site_counts[idx] = site_counts.get(idx, 0) + 1 # complement and overlap sites for optional distinct styling complement_sites = [i for i in range(len(coords)) if i not in all_region_sites] overlap_sites = [i for i, c in site_counts.items() if c > 1] # complement — small faint dots if show_complement and complement_sites: cc = coords[complement_sites] _ca = dict(color=complement_color, alpha=complement_alpha, marker='o', s=marker_size * 0.25, edgecolors='none', zorder=0) if dim <= 2: y = cc[:, 1] if dim == 2 else np.zeros(len(cc)) axis.scatter(cc[:, 0], y, **_ca) else: axis.scatter(cc[:, 0], cc[:, 1], cc[:, 2], **_ca) # Global centroid (used for radial label placement) # colour palette palette = _region_palette(len(regions)) global_com = np.mean(coords, axis=0) region_descriptions = region_descriptions or {} # draw each region for i, (name, indices) in enumerate(regions.items()): if not indices: continue rc = coords[indices] color = palette[i % len(palette)] if region_colors is None else region_colors.get(name, palette[i % len(palette)]) alpha = region_alpha n_pts = len(indices) # Build informative legend text desc = region_descriptions.get(name, "") lbl = f"{name}: {n_pts} sites" if desc: lbl += f" — {desc}" # Convex-hull fill (2D) - allows visualising the overall shape of the region even if the sites are sparse if fill and dim == 2 and ConvexHull is not None and len(rc) >= 3: try: hull = ConvexHull(rc) hp = rc[hull.vertices] axis.fill(hp[:, 0], hp[:, 1], 
color=color, alpha=alpha * fill_alpha, zorder=1) except Exception: pass # Per-site blobs (2D) - gives a visual sense of the site density and extent of the region if blob_radius is not None and dim == 2: from matplotlib.patches import Circle as _Circle from matplotlib.collections import PatchCollection as _PC circles = [_Circle((x, y), blob_radius) for x, y in rc[:, :2]] pc = _PC(circles, facecolors=color, edgecolors='none', alpha=alpha * blob_alpha, zorder=1) axis.add_collection(pc) # Intra-region NN bonds (2D) if show_bonds and dim == 2: idx_set = set(indices) for si in indices: for nj in lattice.get_nn(si): if lattice.wrong_nei(nj): continue nj = int(nj) if nj in idx_set and nj > si: ri, rj = coords[si, :2], coords[nj, :2] axis.plot([ri[0], rj[0]], [ri[1], rj[1]], color=color, lw=1.2, alpha=alpha * 0.55, zorder=2) # Scatter markers sc_kw = dict(color=color, marker='o', s=marker_size, edgecolors='black', linewidths=edge_width * 0.5, label=lbl, zorder=3, alpha=alpha) sc_kw.update(scatter_kwargs) if dim <= 2: y = rc[:, 1] if dim == 2 else np.zeros(len(rc)) axis.scatter(rc[:, 0], y, **sc_kw) else: axis.scatter(rc[:, 0], rc[:, 1], rc[:, 2], **sc_kw) # Region label — placed radially outward from global centroid # Place near one of the first sites, not in centroid if show_labels and dim == 2 and len(rc) > 0: # Determine system scale for relative padding spans = rc.max(axis=0) - rc.min(axis=0) diag = np.sqrt(np.sum(spans**2)) rel_pad = 0.15 * (diag if diag > 1e-9 else 1.0) com = np.mean(rc[:, :2], axis=0) direc = com - global_com[:2] norm = np.linalg.norm(direc) if norm < 1e-9: direc = np.array([0.0, 1.0]) else: direc = direc / norm lbl_pos = com + direc * norm * (label_offset - 1.0) + direc * rel_pad axis.annotate( name, xy=com, xytext=lbl_pos, fontsize=label_fontsize, fontweight='bold', color=color, ha='center', va='center', arrowprops=dict(arrowstyle='->', color=color, lw=1.2), bbox=dict(boxstyle='round,pad=0.25', fc='white', ec=color, alpha=0.85, lw=1.0), zorder=5, ) 
# Fallback label placement for non-2D or if no sites (just put at centroid) elif show_labels and dim != 2 and len(rc) > 0: com = np.mean(rc, axis=0) txt_kw = dict(fontsize=label_fontsize, fontweight='bold', color=color, ha='center', va='center', bbox=dict(fc='white', ec=color, alpha=0.8, pad=1.0)) if dim == 1: axis.text(com[0], 0, name, **txt_kw) else: axis.text(com[0], com[1], com[2], name, **txt_kw) # overlap highlight -> draw on top of everything else with a distinct style if show_overlaps and overlap_sites: oc = coords[overlap_sites] _oa = dict(color='none', edgecolors=overlap_color, marker='o', s=marker_size * 1.5, linewidths=edge_width * 1.5, label=f'Overlaps ({len(overlap_sites)})', zorder=4) if dim <= 2: y = oc[:, 1] if dim == 2 else np.zeros(len(oc)) axis.scatter(oc[:, 0], y, **_oa) else: axis.scatter(oc[:, 0], oc[:, 1], oc[:, 2], **_oa) # site-index annotations if show_indices: _annotate_indices(axis, coords, padding=indices_padding) # Spatial Limits & Visibility # Use a larger margin if we have region labels to avoid clipping margin = 0.08 if show_labels: # Increase margin to accommodate labels and their leader lines margin = max(margin, (label_offset - 1.0) * 0.4 + 0.1) _apply_spatial_limits(axis, coords, dim, show_axes, margin=margin, fix_aspect=fix_aspect) # title if title: kw = {"pad": 12} if title_kwargs: kw.update(title_kwargs) axis.set_title(title, **kw) elif ax is None: n_total = len(coords) n_covered = len(all_region_sites) cov_pct = n_covered / n_total * 100 if n_total else 0 parts = [f"Regions ({len(regions)}) — {n_covered}/{n_total} sites ({cov_pct:.0f}%)"] if overlap_sites: parts.append(f", {len(overlap_sites)} overlaps") # legend (deduplicated, compact) if show_legend: handles, labels = axis.get_legend_handles_labels() by_label = dict(zip(labels, handles)) if by_label: axis.legend( by_label.values(), by_label.keys(), loc =legend_loc, fontsize =legend_fontsize, bbox_to_anchor =legend_bbox, framealpha =0.90, edgecolor ='lightgray', fancybox 
=True, handletextpad =0.3, labelspacing =0.25, borderpad =0.4, handlelength =1.2, markerscale =0.7, **scatter_kwargs, ) if tight_layout: _finalise_figure(fig) return fig, axis # ============================================================================== # K-space / Brillouin-zone with High-Symmetry Points # ============================================================================== def plot_high_symmetry_points( lattice : Lattice, *, path : Optional[Union[List[str], str, Iterable[Tuple[str, Iterable[float]]]]] = None, selection : Optional[Any] = None, ax : Optional[Axes] = None, show_kpoints : bool = True, show_bz : bool = True, show_path : bool = True, show_matched_kpoints : bool = True, bz_facecolor : str = "lavender", bz_edgecolor : str = "slategrey", bz_alpha : float = 0.20, kpoint_color : str = "C0", kpoint_alpha : float = 0.35, kpoint_size : int = 15, path_color : str = "crimson", path_linewidth : float = 1.8, matched_kpoint_color : str = "goldenrod", matched_kpoint_alpha : float = 1.0, matched_kpoint_size : int = 44, matched_kpoint_marker : str = "o", matched_kpoint_edgecolor: str = "black", hs_marker_size : int = 90, hs_marker_facecolor : str = "white", hs_marker_edgecolor : str = "black", hs_font_size : int = 13, hs_label_kwargs : Optional[Dict[str, object]] = None, hs_plot : str = "markers", # "none", "markers", "labels", or "both" points_per_seg : int = 40, path_match_tol : Optional[float] = None, figsize : Optional[Tuple[float, float]] = None, fix_aspect : bool = True, title : Optional[str] = None, title_kwargs : Optional[Dict[str, object]] = None, tight_layout : bool = True, extend : bool = False, extend_copies : Optional[Union[int, Iterable[int]]] = None, nx : int = 1, ny : int = 1, nz : int = 1, extended_kpoint_color : str = "gray", extended_kpoint_alpha : float = 0.15, extended_bz_facecolor : str = "lightgray", extended_bz_edgecolor : str = "dimgray", extended_bz_alpha : float = 0.10, show_background_bz : bool = False, # legend legend_kwargs : 
Optional[Dict[str, object]] = None, **kwargs, ) -> Tuple[Figure, Axes]: r""" Plot the Brillouin zone, high-symmetry path, and sampled reciprocal mesh. This view combines exact reciprocal-space geometry from the lattice with an ideal high-symmetry path and, optionally, the subset of sampled k-points that genuinely match that path within a configurable tolerance. Parameters ---------- lattice : Lattice Lattice object providing reciprocal vectors, sampled ``kvectors``, and optionally ``kvectors_frac`` and ``high_symmetry_points()``. path : list[str], str, or iterable[(label, frac)], optional High-symmetry path specification. If omitted, the lattice default path is used. selection : object, optional Precomputed output of ``lattice.bz_path_points(...)``. When provided, this exact matched set is used for path and matched-point rendering. ax : Axes, optional Existing matplotlib axes. If omitted, a new figure and axes are created. show_kpoints : bool, default=True Draw sampled reciprocal-space mesh points. show_bz : bool, default=True Draw the first Brillouin zone. show_path : bool, default=True Draw the ideal high-symmetry path. show_matched_kpoints : bool, default=True Highlight sampled k-points whose distance to the path is within the matching tolerance. bz_facecolor, bz_edgecolor, bz_alpha Style of the first Brillouin zone. kpoint_color, kpoint_alpha, kpoint_size Style of the original sampled k-mesh. path_color, path_linewidth Style of the ideal path. matched_kpoint_color, matched_kpoint_alpha, matched_kpoint_size, matched_kpoint_marker, matched_kpoint_edgecolor Style of valid matched mesh points. hs_marker_size, hs_marker_facecolor, hs_marker_edgecolor Style of exact high-symmetry vertices. hs_font_size, hs_label_kwargs Style of high-symmetry labels. hs_plot : {"none", "markers", "labels", "both"}, default="markers" Whether to draw exact high-symmetry markers, labels, or both. 
points_per_seg : int, default=40 Number of interpolation points per path segment for the ideal path. path_match_tol : float, optional Distance tolerance for highlighting mesh points near the drawn path. If omitted, a mesh-based Cartesian tolerance is estimated from the sampled reciprocal points. fix_aspect : bool, default=True If True, preserve equal axis scaling in 2D plots. Set to ``False`` to let the requested ``figsize`` control the on-screen aspect. extend : bool, default=False Draw translated copies of the sampled k-mesh. extend_copies : int or iterable[int], optional Number of reciprocal-cell copies per direction. In 2D, ``extend_copies=1`` includes the first shell around the first Brillouin zone and ``extend_copies=2`` includes the second shell as well. nx, ny, nz : int, default=(1, 1, 1) Legacy per-direction extension counts used when ``extend_copies`` is not specified. bz_upscale : float, default=1.1 Factor by which the maximum reciprocal-vector norm is multiplied to determine the plot limits when ``show_bz=True``. extended_kpoint_color, extended_kpoint_alpha Style of translated mesh points. extended_bz_facecolor, extended_bz_edgecolor, extended_bz_alpha Style of translated Brillouin-zone copies. show_background_bz : bool, default=False Draw translated Brillouin-zone copies behind the mesh. legend_kwargs : dict, optional Extra keyword arguments passed to ``axis.legend``. **kwargs Low-level style overrides such as ``zorder_path``, ``zorder_kpoints``, ``marker_hs`` or ``marker_extend``. Returns ------- fig, ax : Figure, Axes Matplotlib figure and axes containing the plot. 
""" kvecs_full = _ensure_numpy(lattice.kvectors) target_dim = lattice.dim if lattice.dim else kvecs_full.shape[1] dim = max(1, min(kvecs_full.shape[1], target_dim, 3)) coords = kvecs_full[:, :dim] kfrac = getattr(lattice, "kvectors_frac", None) fig, axis = _init_axes(ax, dim) if figsize is not None and axis is fig.axes[0]: fig.set_size_inches(*figsize, forward=True) # Set 3D view if applicable if dim == 3: axis.view_init(elev=getattr(axis, "elev", 30.0), azim=getattr(axis, "azim", 45.0)) # Determine how many extended copies to generate in each direction if extend_copies is None: copy_spec = nx if dim == 1 else ((nx, ny) if dim == 2 else (nx, ny, nz)) elif np.isscalar(extend_copies): copy_spec = int(extend_copies) else: copy_values = tuple(int(v) for v in extend_copies) if len(copy_values) < dim: raise ValueError("extend_copies must provide at least one value per reciprocal-space dimension") copy_spec = copy_values[:dim] # Check the BZ path and find the nearest k-points along it. if selection is None: selection = lattice.bz_path_points( path=path, points_per_seg=points_per_seg, k_vectors=kvecs_full, k_vectors_frac=kfrac, tol=path_match_tol, periodic=False, ) hs = lattice.high_symmetry_points() resolved_path = lattice.default_resolve_path(path if path is not None else hs) legend_kwargs = {} if legend_kwargs is None else dict(legend_kwargs) label_kmesh = kwargs.get("label_kmesh", "k-mesh") label_extended = kwargs.get("label_extended", "extended mesh") label_matched = kwargs.get("label_matched", "matched path points") plotted_coords = [coords] round_scale = 1e-10 # Determine the BZ extent for auto-scaling bz_upscale = kwargs.pop("bz_upscale", 1.1) if show_bz: _draw_bz_region(axis, coords, dim=dim, lattice=lattice, facecolor=bz_facecolor, edgecolor=bz_edgecolor, alpha=bz_alpha, fix_aspect=fix_aspect, show_points=False, zorder=0.0, **kwargs) # Ensure the BZ is fully contained in the auto-scaling plotted_coords b_norms = [] for i in range(1, 4): # Reciprocal vectors are 
always 3D in the Lattice class, but may be 2D/1D for others vec = getattr(lattice, f"k{i}", getattr(lattice, f"b{i}", None)) if vec is not None: b_norms.append(np.linalg.norm(np.asarray(vec)[:dim])) if b_norms: # Use a factor that ensures BZ and some margin is visible. # 0.5*b is the BZ face, 0.75*b covers corners, 1.1*b is a generous margin. kmax = max(b_norms) * bz_upscale bbox = np.eye(dim) * kmax plotted_coords.append(bbox) plotted_coords.append(-bbox) if show_background_bz: bz_centers = lattice.wigner_seitz_shifts(copies=copy_spec, include_origin=False) seen_shifts = set() shifts = [] for shift in bz_centers: key = tuple(np.rint(shift / round_scale).astype(np.int64)) if key in seen_shifts: continue seen_shifts.add(key) if np.allclose(shift, 0.0): continue shifts.append(shift) # Draw the extended BZ regions first (if requested) so they appear behind the original points and path if show_bz: for shift in shifts: _draw_bz_region(axis, coords + shift, dim=dim, lattice=lattice, offset=shift, facecolor=extended_bz_facecolor, edgecolor=extended_bz_edgecolor, alpha=extended_bz_alpha, fix_aspect=fix_aspect, show_points=False, zorder=-0.1, **kwargs ) if len(shifts) > 0: plotted_coords.append(np.vstack([coords + shift for shift in shifts])) # Get the path points in Cartesian coordinates and plot the path segments if show_path: path_cart = selection.path_cart[:, :dim] plotted_coords.append(path_cart) zorder_path = kwargs.get("zorder_path", 2) if dim == 1: axis.plot(path_cart[:, 0], np.zeros(len(path_cart)), color=path_color, linewidth=path_linewidth, zorder=zorder_path) elif dim == 2: axis.plot(path_cart[:, 0], path_cart[:, 1], color=path_color, linewidth=path_linewidth, zorder=zorder_path) else: axis.plot(path_cart[:, 0], path_cart[:, 1], path_cart[:, 2], color=path_color, linewidth=path_linewidth, zorder=zorder_path) # Optionally extend the k-point mesh and plot the extended points with a distinct style if extend and show_kpoints: extended_k, _ = 
lattice.wigner_seitz_extend(k_points=coords, copies=copy_spec) original_keys = {tuple(np.rint(row / round_scale).astype(np.int64)) for row in coords} seen_ext_keys = set() ext_mask_list = [] for row in extended_k: key = tuple(np.rint(row / round_scale).astype(np.int64)) keep = key not in original_keys and key not in seen_ext_keys ext_mask_list.append(keep) if keep: seen_ext_keys.add(key) ext_mask = np.array(ext_mask_list, dtype=bool) # Plot the extended k-points with a distinct style, ensuring we only plot the new points and not duplicates of the original mesh if np.any(ext_mask): ext_coords = extended_k[ext_mask] plotted_coords.append(ext_coords) marker_extend = kwargs.get("marker_extend", "o") edgecolor_extend = kwargs.get("edgecolor_extend", "none") zorder_extend = kwargs.get("zorder_extend", 3) if dim == 1: axis.scatter(ext_coords[:, 0], np.zeros(len(ext_coords)), s=kpoint_size, color=extended_kpoint_color, alpha=extended_kpoint_alpha, marker=marker_extend, edgecolors=edgecolor_extend, zorder=zorder_extend, label=label_extended) elif dim == 2: axis.scatter(ext_coords[:, 0], ext_coords[:, 1], s=kpoint_size, color=extended_kpoint_color, alpha=extended_kpoint_alpha, marker=marker_extend, edgecolors=edgecolor_extend, zorder=zorder_extend, label=label_extended) else: axis.scatter(ext_coords[:, 0], ext_coords[:, 1], ext_coords[:, 2], s=kpoint_size, color=extended_kpoint_color, alpha=extended_kpoint_alpha, marker=marker_extend, edgecolors=edgecolor_extend, zorder=zorder_extend, label=label_extended) # Plot the original k-points on top of everything else (if requested) so they are visible even if they overlap with the path or extended points if show_kpoints: zorder_kpoints = kwargs.get("zorder_kpoints", 6) marker_kpoints = kwargs.get("marker_kpoints", "o") edgecolor_kpoints = kwargs.get("edgecolor_kpoints", "none") if dim == 1: axis.scatter(coords[:, 0], np.zeros(len(coords)), s=kpoint_size, color=kpoint_color, alpha=kpoint_alpha, marker=marker_kpoints, 
edgecolors=edgecolor_kpoints, zorder=zorder_kpoints, label=label_kmesh) elif dim == 2: axis.scatter(coords[:, 0], coords[:, 1], s=kpoint_size, color=kpoint_color, alpha=kpoint_alpha, marker=marker_kpoints, edgecolors=edgecolor_kpoints, zorder=zorder_kpoints, label=label_kmesh) else: axis.scatter(coords[:, 0], coords[:, 1], coords[:, 2], s=kpoint_size, color=kpoint_color, alpha=kpoint_alpha, marker=marker_kpoints, edgecolors=edgecolor_kpoints, zorder=zorder_kpoints, label=label_kmesh) # Plot the matched k-points along the path with a distinct style, ensuring we only plot them if there are matches and if the option is enabled if show_matched_kpoints and selection.has_matches: valid_mask = selection.matched_distances <= (selection.match_tolerance + 1e-14) valid_positions = np.flatnonzero(valid_mask) seen_match_indices = set() keep_positions = [] for pos in valid_positions: key_array = selection.matched_grid_indices if len(selection.matched_grid_indices) > 0 else selection.matched_indices idx = int(key_array[pos]) if idx in seen_match_indices: continue seen_match_indices.add(idx) keep_positions.append(pos) matched = selection.matched_cart[np.asarray(keep_positions, dtype=int), :dim] if keep_positions else np.zeros((0, dim), dtype=float) if len(matched) > 0: plotted_coords.append(matched) if dim == 1: axis.scatter(matched[:, 0], np.zeros(len(matched)), s=matched_kpoint_size, color=matched_kpoint_color, alpha=matched_kpoint_alpha, marker=matched_kpoint_marker, edgecolors=matched_kpoint_edgecolor, linewidths=0.9, zorder=7, label=label_matched) elif dim == 2: axis.scatter(matched[:, 0], matched[:, 1], s=matched_kpoint_size, color=matched_kpoint_color, alpha=matched_kpoint_alpha, marker=matched_kpoint_marker, edgecolors=matched_kpoint_edgecolor, linewidths=0.9, zorder=7, label=label_matched) else: axis.scatter(matched[:, 0], matched[:, 1], matched[:, 2], s=matched_kpoint_size, color=matched_kpoint_color, alpha=matched_kpoint_alpha, marker=matched_kpoint_marker, 
edgecolors=matched_kpoint_edgecolor, linewidths=0.9, zorder=7, label=label_matched) # Plot exact high-symmetry vertices, not the nearest interpolated path samples. if hs_plot != "none": hs_points = [] seen_hs = set() b1 = lattice.b1 b2 = lattice.b2 if lattice.dim >= 2 else np.zeros(3, dtype=float) b3 = lattice.b3 if lattice.dim >= 3 else np.zeros(3, dtype=float) for lbl, frac in resolved_path: frac_arr = np.zeros(3, dtype=float) frac_arr[:len(frac)] = np.asarray(frac, dtype=float) pt_obj = hs.get(lbl) if hs is not None and hasattr(hs, "get") else None cart3 = pt_obj.to_cartesian(b1, b2, b3) if pt_obj is not None else frac_arr[0] * b1 + frac_arr[1] * b2 + frac_arr[2] * b3 cart = cart3[:dim] key = tuple(np.rint(cart / round_scale).astype(np.int64)) if key in seen_hs: continue seen_hs.add(key) hs_points.append((lbl, cart)) if hs_points: hs_coords = np.array([cart for _, cart in hs_points], dtype=float) plotted_coords.append(hs_coords) if hs_plot in ["markers", "both"]: marker_hs = kwargs.get("marker_hs", "o") zorder_hs = kwargs.get("zorder_hs", 5) lw_hs = kwargs.get("lw_hs", 1.4) if dim == 1: axis.scatter(hs_coords[:, 0], np.zeros(len(hs_coords)), s=hs_marker_size, color=hs_marker_facecolor, edgecolors=hs_marker_edgecolor, linewidths=lw_hs, zorder=zorder_hs, marker=marker_hs) elif dim == 2: axis.scatter(hs_coords[:, 0], hs_coords[:, 1], s=hs_marker_size, color=hs_marker_facecolor, edgecolors=hs_marker_edgecolor, linewidths=lw_hs, zorder=zorder_hs, marker=marker_hs) else: axis.scatter(hs_coords[:, 0], hs_coords[:, 1], hs_coords[:, 2], s=hs_marker_size, color=hs_marker_facecolor, edgecolors=hs_marker_edgecolor, linewidths=lw_hs, zorder=zorder_hs, marker=marker_hs) if hs_plot in ["labels", "both"]: for lbl, cart in hs_points: pt = hs.get(lbl) if hs is not None and hasattr(hs, "get") else None text = pt.latex_label if pt is not None else str(lbl) hs_dict = hs_label_kwargs.copy() if hs_label_kwargs else {} hs_fw = hs_dict.pop("fontweight", "bold") hs_zord = 
hs_dict.pop("zorder", 8) bbox = hs_dict.pop("bbox", None) hs_xy = hs_dict.pop("xy", (8, 8)) hs_xy = hs_xy if isinstance(hs_xy, tuple) else (hs_xy.get(lbl, (8, 8)) if isinstance(hs_xy, dict) else (8, 8)) if dim == 3: text_kwargs = dict(fontsize=hs_font_size, fontweight=hs_fw, zorder=hs_zord, ha='center', va='center', **hs_dict) if bbox is not None: text_kwargs["bbox"] = bbox axis.text(cart[0], cart[1], cart[2], text, **text_kwargs) else: ann_kwargs = dict( xy=(cart[0], cart[1] if dim > 1 else 0.0), textcoords='offset points', xytext=hs_xy, fontsize=hs_font_size, fontweight=hs_fw, zorder=hs_zord, **hs_dict, ) if bbox is not None: ann_kwargs["bbox"] = bbox axis.annotate(text, **ann_kwargs) # Final spatial limits and legend plotted_stack = np.vstack([np.asarray(arr, dtype=float).reshape(-1, dim) for arr in plotted_coords if np.asarray(arr).size > 0]) _apply_spatial_limits(axis, plotted_stack, dim, True, labels=(r'$k_x$', r'$k_y$', r'$k_z$'), fix_aspect=fix_aspect) if dim == 2: axis.axhline(0, color='grey', lw=0.4, zorder=-1) axis.axvline(0, color='grey', lw=0.4, zorder=-1) if title: kw = {"pad": 12} if title_kwargs: kw.update(title_kwargs) axis.set_title(title, **kw) handles, labels_ = axis.get_legend_handles_labels() if handles: axis.legend(loc=legend_kwargs.pop("loc", "best"), fontsize=legend_kwargs.pop("fontsize", 9), framealpha=legend_kwargs.pop("framealpha", 0.90), edgecolor=legend_kwargs.pop("edgecolor", "lightgray"), fancybox=legend_kwargs.pop("fancybox", True), **legend_kwargs) if tight_layout: _finalise_figure(fig) return fig, axis # ============================================================================== # Plotter Class # ==============================================================================
@dataclass
class LatticePlotter:
    """
    Convenience wrapper bundling plotting helpers for a single lattice.

    Each method forwards to the corresponding module-level plotting function,
    passing ``self.lattice`` first and filling in a default ``figsize`` when
    the caller does not supply one.

    Usage:
        lattice.plot.real_space()
        lattice.plot.structure(show_indices=True)
        lattice.plot.regions(regions_dict)
    """

    lattice: Lattice  # lattice passed as the first argument to every helper below

    def real_space(self, **kwargs) -> Tuple[Figure, Axes]:
        """
        Plot real-space sites.

        All keyword arguments are forwarded to :func:`plot_real_space`;
        ``figsize`` defaults to ``(5.0, 5.0)``.
        """
        kwargs.setdefault("figsize", (5.0, 5.0))
        return plot_real_space(self.lattice, **kwargs)

    def reciprocal_space(self, **kwargs) -> Tuple[Figure, Axes]:
        """
        Scatter-plot of reciprocal lattice vectors (k-points).

        Keyword arguments are forwarded to :func:`plot_reciprocal_space`;
        ``figsize`` defaults to ``(5.0, 5.0)``.
        --------------------------------------------------------------------------
        ax : Axes, optional
            Matplotlib axes to plot on. If None, a new figure is created.
        show_indices : bool, default=False
            If True, annotate each k-point with its index.
        show_axes : bool, default=True
            If False, hides the coordinate axes.
        color : str, default="C1"
            Color of the k-point markers.
        marker : str, default="o"
            Marker style.
        figsize : tuple, optional
            Figure size in inches (width, height).
        title : str, optional
            Title of the plot.
        elev, azim : float, optional
            Elevation and azimuth angles for 3D plots.
        extend_kpoints : bool, default=False
            If True, draw translated reciprocal-space copies around the original mesh.
        extend_copies : int or iterable of int, default=2
            Number of copies per reciprocal direction used when ``extend_kpoints=True``.
            Scalars are applied to all active reciprocal directions.
        extend_tol : float, default=1e-10
            Tolerance used to identify which extended points are already present in
            the original reciprocal mesh.
        **scatter_kwargs
            Include:
                - point_edgecolor: Color of the marker edges (default "white").
                - point_zorder: Z-order for the scatter points (default 5).
                - color_extended: Color for translated copies (default "C2").
                - edgecolor_extended: Edge color for translated copies (default "gray").
                - marker_extended: Marker for translated copies (default ``marker``).
                - Any other valid arguments for `ax.scatter`.
        """
        kwargs.setdefault("figsize", (5.0, 5.0))
        return plot_reciprocal_space(self.lattice, **kwargs)

    def brillouin_zone(self, **kwargs) -> Tuple[Figure, Axes]:
        """
        Plot the Brillouin Zone.

        All keyword arguments are forwarded to :func:`plot_brillouin_zone`;
        ``figsize`` defaults to ``(5.0, 4.0)``.
        """
        kwargs.setdefault("figsize", (5.0, 4.0))
        return plot_brillouin_zone(self.lattice, **kwargs)

    def structure(self, **kwargs) -> Tuple[Figure, Axes]:
        r"""
        Plot detailed lattice structure with connectivity.

        All keyword arguments are forwarded to :func:`plot_lattice_structure`;
        ``figsize`` defaults to ``(5.5, 5.5)``.

        Parameters
        ----------
        ax : Axes, optional
            Matplotlib axes to plot on. If None, a new figure is created.
        show_indices : bool
            If True, annotates nodes with their site indices.
        highlight_boundary : bool
            If True, draws boundary nodes with a distinct color/edge.
        show_axes : bool
            If False, hides the coordinate axes for a cleaner diagram.
        edge_color : str
            Color of the edges.
        node_color : str
            Color of the nodes.
        boundary_node_color : str
            Color of the boundary node edges.
        periodic_color : str
            Color for periodic boundary annotations.
        open_color : str
            Color for open boundary annotations.
        node_size : int
            Size of the node markers.
        edge_alpha : float
            Transparency of the edges.
        label_padding : float
            Fractional padding for node index labels.
        boundary_offset : float
            Fractional offset for boundary annotations.
        figsize : tuple, optional
            Figure size in inches (width, height).
        title : str, optional
            Title of the plot.
        title_kwargs : dict, optional
            Additional keyword arguments for the title.
        tight_layout : bool
            If True, applies tight layout to the figure.
        elev, azim : float, optional
            Elevation and azimuth angles for 3D plots.
        partition_colors : tuple of str, optional
            Colors to use for bipartite/sublattice coloring.
            If provided, nodes are colored based on sublattice parity.
        show_periodic_connections : bool
            If True, indicates wrap-around connections textually or graphically.
        show_primitive_cell : bool
            If True, overlays the primitive unit cell vectors/box.
        **scatter_kwargs
            Additional arguments passed to `ax.scatter`.
        """
        kwargs.setdefault("figsize", (5.5, 5.5))
        return plot_lattice_structure(self.lattice, **kwargs)

    def regions(
        self,
        regions : Union[str, Dict[str, List[int]]],
        **kwargs,
    ) -> Tuple[Figure, Axes]:
        """
        Plot specific regions on the lattice.

        Parameters
        ----------
        regions : Dict[str, List[int]] or str
            Dictionary mapping region names to lists of site indices. A plain
            string is treated as a region name: it is resolved with
            ``self.lattice.get_region(name, **kwargs)`` and plotted as a single
            region under that name.
        show_system : bool
            If True, plot all lattice sites faintly in the background.
        system_color : str
            Color for background system sites.
        cmap : str
            Colormap name for distinct regions.
        blob_radius : float, optional
            If given, draw a translucent circle around each site.
        show_bonds : bool
            If True, draw intra-region NN bonds.
        ... other args mirror plot_real_space ...

        ``figsize`` defaults to ``(6.0, 6.0)``.
        """
        if isinstance(regions, str):
            # Resolve the named region before injecting plotting defaults, so the
            # plotting-only ``figsize`` default is not forwarded to ``get_region``.
            # NOTE(review): caller-supplied kwargs still reach both ``get_region``
            # and ``plot_regions`` -- confirm both accept the keys used.
            regions = {regions: self.lattice.get_region(regions, **kwargs)}
        kwargs.setdefault("figsize", (6.0, 6.0))
        return plot_regions(self.lattice, regions, **kwargs)

    def bz_high_symmetry(self, **kwargs) -> Tuple[Figure, Axes]:
        """
        Plot the Brillouin zone, high-symmetry path, and sampled reciprocal mesh.

        Parameters
        ----------
        path : list[str], str, or iterable[(label, frac)], optional
            High-symmetry path specification. If omitted, the lattice default path is used.
        show_kpoints : bool, default=True
            Draw sampled reciprocal-space mesh points.
        show_bz : bool, default=True
            Draw the first Brillouin zone.
        show_path : bool, default=True
            Draw the ideal high-symmetry path.
        show_matched_kpoints : bool, default=True
            Highlight sampled k-points whose distance to the path is within the
            matching tolerance.
        points_per_seg : int, default=40
            Number of interpolation points per path segment for the ideal path.
        path_match_tol : float, optional
            Distance tolerance used when highlighting mesh points near the drawn path.
        extend : bool, default=False
            Draw translated copies of the sampled k-mesh.
        extend_copies : int or iterable[int], optional
            Number of reciprocal-cell copies per direction. In 2D, ``extend_copies=1``
            includes the first shell around the first Brillouin zone and
            ``extend_copies=2`` includes the second shell as well.
        show_background_bz : bool, default=False
            Draw translated Brillouin-zone copies behind the mesh.
        hs_plot : {"none", "markers", "labels", "both"}, default="markers"
            Whether to draw exact high-symmetry markers, labels, or both.
        legend_kwargs : dict, optional
            Extra keyword arguments passed to ``axis.legend``.
        **kwargs
            Additional style overrides forwarded to ``plot_high_symmetry_points``.

        ``figsize`` defaults to ``(5.5, 5.5)``.
        """
        kwargs.setdefault("figsize", (5.5, 5.5))
        return plot_high_symmetry_points(self.lattice, **kwargs)

    def subsystem(
        self,
        sites           : Iterable[int],
        *,
        show_boundary   : bool = True,
        **kwargs,
    ) -> Tuple[Figure, Axes]:
        """
        Plot a single subsystem A as one region.

        Parameters
        ----------
        sites : iterable of int
            Site indices in the subsystem. Any iterable is accepted
            (list, range, generator); it is materialised once into a list.
        show_boundary : bool, default=True
            If True and the lattice exposes a ``regions`` helper, the boundary
            size returned by ``self.lattice.regions.subsystem_boundary(sites)`` is
            included in the default title as ``∂A``. Has no effect when an
            explicit ``title`` is passed.
        **kwargs
            Passed to plot_regions. ``figsize`` defaults to ``(5.0, 5.0)`` and
            ``show_bonds`` to ``True``.

        Returns
        -------
        fig, ax : Figure, Axes

        Examples
        --------
        >>> lattice.plot.subsystem([0, 1, 4, 5])
        >>> lattice.plot.subsystem(range(8), show_bonds=True)
        """
        # ``sites`` is used several times below (boundary, size, region dict);
        # materialise it so one-shot iterators are not exhausted by the first use.
        sites = list(sites)

        kwargs.setdefault("figsize", (5.0, 5.0))
        kwargs.setdefault("show_bonds", True)

        title = kwargs.pop("title", None)
        if title is None:
            if show_boundary and hasattr(self.lattice, "regions"):
                dA = self.lattice.regions.subsystem_boundary(sites)
                title = f"Subsystem: |A|={len(sites)}, ∂A={dA}"
            else:
                title = f"Subsystem: |A|={len(sites)}"

        return plot_regions(self.lattice, {"A": sites}, title=title, **kwargs)

    def sweep(
        self,
        direction   : Optional[str] = None,
        *,
        rectangular : bool = False,
        max_panels  : int = 6,
        figsize     : Optional[Tuple[float, float]] = None,
        **kwargs,
    ) -> Tuple[Figure, np.ndarray]:
        """
        Plot subsystem sweep showing cuts with different boundary sizes.

        Subsystems are obtained from
        ``self.lattice.regions.sweep_subsystems(direction=..., rectangular=...)``,
        grouped by boundary size ∂A. For each ∂A (in ascending order) the
        median-sized subsystem is drawn in its own panel, up to ``max_panels``
        panels on a grid with at most three columns.

        Parameters
        ----------
        direction : str, optional
            Direction for sweep ('x', 'y', 'z'). Creates full-width cuts.
        rectangular : bool, default=False
            If True and direction is None, use rectangular subsystems (various shapes).
            If False, use lexicographic sweep (sequential site addition).
        max_panels : int, default=6
            Maximum number of panels to show.
        figsize : tuple, optional
            Figure size. Auto-computed if None (2.8 x 2.5 inches per panel).
        **kwargs
            Passed to plot_regions for each panel, with these special cases:

            - ``title``: used as a figure-level suptitle (each panel gets its own
              ``∂A=..., |A|=...`` title).
            - ``title_kwargs``: merged over the per-panel default ``{"fontsize": 10}``.
            - ``ax``: not supported, since this method creates its own axes.

        Returns
        -------
        fig, axes : Figure, ndarray of Axes
            ``axes`` is flattened; unused trailing axes are hidden.

        Raises
        ------
        TypeError
            If ``ax`` is passed.
        ValueError
            If the sweep produced no subsystems.

        Examples
        --------
        >>> lattice.plot.sweep(rectangular=True)   # Various rectangular shapes
        >>> lattice.plot.sweep(direction='x')      # Full-width column cuts
        >>> lattice.plot.sweep()                   # Lexicographic sweep
        """
        # These keys are set per panel below; accepting them verbatim from the
        # caller would raise "multiple values for keyword argument".
        if "ax" in kwargs:
            raise TypeError("sweep() creates its own axes; the 'ax' argument is not supported")
        suptitle = kwargs.pop("title", None)
        panel_title_kwargs = {"fontsize": 10}
        panel_title_kwargs.update(kwargs.pop("title_kwargs", None) or {})

        # Subsystem generation lives in the lattice's regions helper.
        by_dA = self.lattice.regions.sweep_subsystems(
            direction=direction, rectangular=rectangular
        )

        # One representative (the median-sized subsystem) per boundary size.
        panels = []
        for dA in sorted(by_dA.keys()):
            subs = by_dA[dA]
            if not subs:
                continue
            subs_sorted = sorted(subs, key=len)
            rep = subs_sorted[len(subs_sorted) // 2]
            panels.append((dA, rep))
            if len(panels) >= max_panels:
                break

        n = len(panels)
        if n == 0:
            raise ValueError("No subsystems generated")

        # Compact grid: at most three columns.
        ncols = min(n, 3)
        nrows = (n + ncols - 1) // ncols
        if figsize is None:
            figsize = (2.8 * ncols, 2.5 * nrows)

        fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
        axes = np.atleast_1d(axes).flatten()

        # Clean per-panel defaults; callers may override any of them.
        kwargs.setdefault("show_bonds", True)
        kwargs.setdefault("show_legend", False)
        kwargs.setdefault("show_labels", False)
        kwargs.setdefault("show_system", True)
        kwargs.setdefault("show_axes", False)
        kwargs.setdefault("marker_size", 40)
        kwargs.setdefault("tight_layout", False)  # the whole figure is laid out below

        for ax, (dA, sites) in zip(axes, panels):
            plot_regions(
                self.lattice,
                {"A": list(sites)},
                ax=ax,
                title=f"∂A={dA}, |A|={len(sites)}",
                title_kwargs=dict(panel_title_kwargs),  # fresh copy per panel
                **kwargs,
            )
            ax.set_aspect('equal', adjustable='box')

        # Hide grid cells that received no panel.
        for ax in axes[n:]:
            ax.set_visible(False)

        if suptitle is not None:
            fig.suptitle(suptitle)
        fig.tight_layout(pad=0.5)
        return fig, axes
# ------------------------------------------------------------------------------ #! EOF # ------------------------------------------------------------------------------