# Source code for EPWpy.plotting.plot_polaron_matplotlib

# Matplotlib-only reimplementation of plot_psir_plrn (no Mayavi)
# Requires: numpy, matplotlib, scikit-image
# pip install numpy matplotlib scikit-image

from typing import Tuple, Optional
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
from mpl_toolkits.mplot3d.art3d import Poly3DCollection, Line3DCollection
from skimage import measure
from EPWpy.default.default_dicts import *


def _prepare_volume(Density, Grid, order_hint="C"):
    """
    Coerce ``Density`` into a real 3-D numpy array laid out like ``Grid[..., 0]``.

    Accepted inputs:
      * anything that is 3-D after dropping singleton axes, e.g. ``(nx, ny, nz)``
        or ``(nx, ny, nz, 1)``; it is returned squeezed, and its shape is not
        checked against ``Grid``;
      * an array whose non-singleton axes match those of ``Grid[..., 0]``, e.g.
        ``(nx, ny)`` or ``(nx, ny, 1)`` for a grid with ``nz == 1``; it is
        returned reshaped to exactly ``Grid[..., 0].shape``;
      * a flat array of ``nx*ny*nz`` values; it is unflattened to
        ``Grid[..., 0].shape`` in memory order ``order_hint``.

    Complex input is reduced to its real part.

    Parameters
    ----------
    Density : array_like
        Scalar field values.
    Grid : array_like
        Coordinate array; only ``Grid[..., 0].shape`` is used.
    order_hint : {"C", "F"}
        Memory order (row-major / column-major) used to unflatten a 1-D
        ``Density``. It is passed straight to ``numpy.reshape``.

    Returns
    -------
    numpy.ndarray
        A 3-D real array.

    Raises
    ------
    ValueError
        If ``Density`` cannot be matched to a 3-D grid shape.
    """
    vol = np.squeeze(np.real(np.asarray(Density)))

    # Already 3-D once singleton axes are gone.
    if vol.ndim == 3:
        return vol

    grid_shape = np.asarray(Grid)[..., 0].shape  # (nx, ny, nz)
    if len(grid_shape) == 3:
        # The squeeze above also strips grid axes of length 1 (slabs, lines),
        # so compare against the grid's non-singleton axes and put the
        # singleton axes back. Inserting length-1 axes is order-independent.
        if vol.shape == tuple(n for n in grid_shape if n != 1):
            return vol.reshape(grid_shape)

        # Flat volume: the memory order matters here.
        if vol.ndim == 1 and vol.size == int(np.prod(grid_shape)):
            return vol.reshape(grid_shape, order=order_hint)

    raise ValueError(
        f"Density cannot be coerced to 3D. Got shape {vol.shape}; expected "
        f"(nx,ny,nz) compatible with Grid[...,0].shape = {grid_shape}."
    )

def _compute_mat(mat):
    """Number consecutive runs of equal labels, starting at 2.

    The first label gets index 2. Each time a label differs from the label
    just before it, the index grows by one. Non-adjacent repeats therefore
    get distinct indices (``['a', 'b', 'a'] -> [2., 3., 4.]``).

    Parameters
    ----------
    mat : iterable
        Per-atom labels.

    Returns
    -------
    numpy.ndarray
        Float array with one run index per label; empty for empty input.
    """
    labels = list(mat)
    if not labels:
        return np.array([])

    ids = np.empty(len(labels), dtype=float)
    current, previous = 2, labels[0]
    for pos, label in enumerate(labels):
        if label == previous:
            ids[pos] = current
            continue
        # A new run starts at this label.
        previous = label
        current += 1
        ids[pos] = current
    return ids

def _compute_matn(mat):
    """Look up a display colour for each atom label.

    Each entry of ``mat`` is used as a key into ``jmol_colors``. Keys missing
    from that table get mid-grey ``(0.5, 0.5, 0.5)``.

    Parameters
    ----------
    mat : iterable
        Per-atom keys for ``jmol_colors`` (presumably atomic numbers; confirm
        against ``EPWpy.default.default_dicts``).

    Returns
    -------
    numpy.ndarray
        One colour per entry (shape ``(len(mat), 3)`` when every colour is an
        RGB triple); empty for empty input.
    """
    grey = (0.5, 0.5, 0.5)
    return np.array([jmol_colors.get(label, grey) for label in list(mat)])

def _compute_psize(mat):
    """Return a per-atom size factor equal to half of each atom's tabulated radius.

    Each entry of ``mat`` is looked up in ``atomic_radii``. Entries missing
    from that table fall back to ``50 * pm2bohr`` (50 pm scaled by the
    ``pm2bohr`` conversion factor). The caller multiplies the result by a
    global marker-size factor.

    Parameters
    ----------
    mat : iterable
        Per-atom keys for ``atomic_radii`` (presumably atomic numbers; confirm
        against ``EPWpy.default.default_dicts``).

    Returns
    -------
    numpy.ndarray
        ``radius / 2`` for each entry of ``mat``; an empty array for empty input.
    """
    labels = list(mat)
    if not labels:
        # Empty input is well-defined: no sizes. This also avoids touching the
        # lookup tables at all.
        return np.array([])

    # Default to 50 pm if the element is not in the map. Hoisted out of the loop.
    fallback_radius = 50 * pm2bohr
    return np.array([atomic_radii.get(label, fallback_radius) / 2.0 for label in labels])


def _faces_to_poly3d(verts: np.ndarray, faces: np.ndarray):
    """Convert marching-cubes vertices and faces into a list of polygons for ``Poly3DCollection``.

    Parameters
    ----------
    verts : numpy.ndarray, shape (n_verts, 3)
        Vertex coordinates.
    faces : array_like of int, shape (n_faces, k)
        Each row lists vertex indices into ``verts`` (``k == 3`` for
        marching-cubes triangles).

    Returns
    -------
    list of numpy.ndarray
        One ``(k, 3)`` coordinate array per face; ``[]`` when there are no faces.
    """
    if len(faces) == 0:
        # An empty list/array has float dtype and cannot be used as an index.
        return []
    # One vectorized gather for all faces instead of a per-face Python loop.
    # Marching-cubes meshes easily reach 1e5+ triangles.
    return list(verts[np.asarray(faces)])

def _auto_iso_level_from_density(Density: np.ndarray) -> float:
    """Choose a default isosurface level for ``Density``.

    NaNs are ignored. Returns 0.0 when the range is unusable (a non-finite
    bound, or a constant field) or when the data take both signs. Otherwise
    returns the midpoint of the data range.

    Parameters
    ----------
    Density : array_like
        Scalar field values.

    Returns
    -------
    float
        The chosen iso level.
    """
    low = float(np.nanmin(Density))
    high = float(np.nanmax(Density))
    unusable = not (np.isfinite(low) and np.isfinite(high)) or low == high
    if unusable or low < 0.0 < high:
        # Degenerate range, or data of both signs: use the zero level.
        return 0.0
    return 0.5 * (low + high)

def _draw_outline_box(ax, xmin, xmax, ymin, ymax, zmin, zmax, color='k', lw=1.0):
    """Draw the 12 edges of an axis-aligned box on a 3-D axes.

    Each edge is drawn with its own ``ax.plot(xs, ys, zs, color=color, lw=lw)``
    call.

    Parameters
    ----------
    ax :
        Target 3-D axes; only its ``plot`` method is used.
    xmin, xmax, ymin, ymax, zmin, zmax : float
        Box extents.
    color, lw :
        Forwarded to ``ax.plot`` as ``color=`` and ``lw=``.
    """
    low = (xmin, ymin, zmin)
    high = (xmax, ymax, zmax)

    def corner(raised_axes):
        # Coordinate k comes from ``high`` when axis k is raised, otherwise from ``low``.
        return np.array([high[k] if k in raised_axes else low[k] for k in range(3)])

    # Walk the corners upward from the all-low one. From each corner, step
    # along every axis that is still low. This visits each of the 12 edges
    # exactly once.
    for raised in ((), (0,), (1,), (2,), (0, 1), (0, 2), (1, 2)):
        start = corner(raised)
        for axis in range(3):
            if axis in raised:
                continue
            end = corner(raised + (axis,))
            ax.plot([start[0], end[0]], [start[1], end[1]], [start[2], end[2]], color=color, lw=lw)

def _grid_axis_steps(Grid):
    """Return the single-step vectors ``(ex, ey, ez)`` of ``Grid`` along its three index axes.

    ``ex = Grid[1, 0, 0] - Grid[0, 0, 0]``, and likewise for the other two axes.
    An axis of length 1 gives a zero vector. Index coordinates ``(i, j, k)``
    then map to ``Grid[0, 0, 0] + i*ex + j*ey + k*ez``. This is exact only when
    the grid positions are affine in the indices; otherwise it is an
    approximation.

    Parameters
    ----------
    Grid : numpy.ndarray, shape (nx, ny, nz, 3)
        Real-space position of every grid point.

    Returns
    -------
    tuple of three numpy.ndarray, each of shape (3,)
    """
    shape_xyz = Grid[..., 0].shape
    origin = Grid[0, 0, 0, :]
    ex = Grid[1, 0, 0, :] - origin if shape_xyz[0] > 1 else np.zeros(3)
    ey = Grid[0, 1, 0, :] - origin if shape_xyz[1] > 1 else np.zeros(3)
    ez = Grid[0, 0, 1, :] - origin if shape_xyz[2] > 1 else np.zeros(3)
    return ex, ey, ez


def _bond_segments(connections, x, y, z, Grid):
    """Build ``(p, q)`` endpoint pairs for the bonds listed in ``connections``.

    Each row is ``(i, j)`` or ``(i, j, dx, dy, dz)``. Rows whose atom index
    falls outside ``[0, len(x))`` are skipped. For rows with 5 or more columns,
    the endpoint of atom ``j`` is shifted by ``dx*ex + dy*ey + dz*ez``, where
    the offsets are truncated to int and ``ex, ey, ez`` come from
    ``_grid_axis_steps(Grid)``.

    Returns
    -------
    list of tuple(numpy.ndarray, numpy.ndarray)
        Segment endpoints suitable for ``Line3DCollection``; ``[]`` if there
        are no connections.
    """
    conns = np.asarray(connections)
    if conns.size == 0:
        return []
    if conns.ndim == 1:
        conns = conns.reshape(1, -1)

    n_atoms = len(x)
    with_shift = conns.shape[1] >= 5
    if with_shift:
        # NOTE(review): ex/ey/ez are single grid *steps*, not lattice vectors.
        # If dx, dy, dz are integer periodic-image indices, the shift is
        # probably too small by roughly the number of grid points per axis.
        # Confirm the convention with the producer of 'connections'.
        ex, ey, ez = _grid_axis_steps(Grid)

    segments = []
    for row in conns:
        i, j = int(row[0]), int(row[1])
        if not (0 <= i < n_atoms and 0 <= j < n_atoms):
            continue
        p = np.array([x[i], y[i], z[i]])
        q = np.array([x[j], y[j], z[j]])
        if with_shift:
            shift = int(row[2]) * ex + int(row[3]) * ey + int(row[4]) * ez
            q = q + shift
        segments.append((p, q))
    return segments


def plot_psir_plrn_matplotlib(
    Data: dict,
    *,
    view: Optional[dict] = None,
    iso_level: Optional[float] = None,  # if None, pick automatically
    iso_alpha: float = 0.35,
    iso_color=(0.9, 0.8, 0.0),  # golden-ish
    point_size: float = 40.0,
    cmap: str = "viridis",
    quiver_length: float = 1.0,
    quiver_normalize: bool = False,
    figsize=(9, 8),
    show: bool = True,
    bond_size: float = 2.0,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Matplotlib implementation of the psir_plrn visualization (no Mayavi).

    Draws the following on a single 3-D axes:
      * bonds from ``connections``;
      * atoms as a scatter plot;
      * an isosurface of ``Density`` (``skimage.measure.marching_cubes``);
      * arrows ``(u, v, w)`` at the atom positions;
      * the bounding box of ``Grid``.

    Expected keys in ``Data``:
      - 'mat' : per-atom labels. When its length equals that of 'x', colours
        come from ``_compute_matn`` and sizes from ``_compute_psize``.
        Otherwise atoms are coloured by index through ``cmap`` and share one
        size.
      - 'x', 'y', 'z' : 1-D atom coordinates (same length).
      - 'u', 'v', 'w' : 1-D arrow components (same length as 'x').
      - 'Grid' : array of shape (nx, ny, nz, 3); ``Grid[ix, iy, iz]`` is the
        real-space position of a grid point.
      - 'Density' : scalar field, either 3-D or flat in Fortran order (see
        ``_prepare_volume``).
      - 'connections' (optional) : rows ``(i, j)`` or ``(i, j, dx, dy, dz)``.
    Other keys ('Dense', 'pts', 'in_notebook', 'backend') are not read.

    Parameters
    ----------
    view : dict, optional
        Overrides for the display settings (keys are converted with ``str``):
        'pos_color' / 'pos_alpha' (isosurface colour and opacity),
        'atom_size_factor' (multiplies the per-atom sizes), 'cmap',
        'bond_size', 'bond_color', 'depthshade', 'grid', 'figsize',
        'quiver_length', 'quiver_normalize', 'atom_kwargs' (extra keyword
        arguments for ``ax.scatter``) and 'quiver_args' (extra keyword
        arguments for ``ax.quiver``). The keys 'pos_iso_frac', 'neg_iso_frac',
        'neg_color' and 'neg_alpha' are accepted but currently unused.
    iso_level : float, optional
        Isosurface level. If None, ``_auto_iso_level_from_density`` chooses one.
    iso_alpha, iso_color, point_size :
        Currently ignored. They are overridden by view['pos_alpha'],
        view['pos_color'] and view['atom_size_factor'] and kept only for
        signature compatibility.
    cmap, quiver_length, quiver_normalize, figsize, bond_size :
        Defaults for the view keys of the same name; a value given in
        ``view`` wins.
    show : bool
        If True, call ``plt.show()`` before returning.

    Returns
    -------
    (fig, ax)
        The created figure and its 3-D axes.
    """
    mat = Data['mat']
    x = np.asarray(Data['x'], dtype=float)
    y = np.asarray(Data['y'], dtype=float)
    z = np.asarray(Data['z'], dtype=float)
    u = np.asarray(Data['u'], dtype=float)
    v = np.asarray(Data['v'], dtype=float)
    w = np.asarray(Data['w'], dtype=float)
    Grid = np.asarray(Data['Grid'], dtype=float)  # shape (nx, ny, nz, 3)
    Density = np.asarray(Data['Density'], dtype=float)
    connections = np.asarray(Data.get('connections', []))

    # The keyword arguments seed the defaults, so they are no longer silently
    # overridden. Entries in ``view`` take precedence. 'cmap' used to be the
    # literal "'viridis'" (embedded quotes), which is not a colormap name.
    settings = {
        'pos_iso_frac': 0.10,
        'neg_iso_frac': 0.10,
        'pos_color': (1.0, 0.6, 0.0),
        'neg_color': (0.2, 0.45, 0.95),
        'pos_alpha': 0.8,
        'neg_alpha': 0.6,
        'quiver_length': quiver_length,
        'quiver_normalize': quiver_normalize,
        'atom_size_factor': 200.0,
        'bond_size': bond_size,
        'bond_color': (0.8, 0.7, 0.0),
        'depthshade': True,
        'cmap': cmap,
        'grid': False,
        'figsize': figsize,
        'atom_kwargs': {},
        'quiver_args': {},
    }
    if view:
        settings.update({str(key): val for key, val in view.items()})

    # 1) Figure / axes
    fig = plt.figure(figsize=settings['figsize'])
    ax = fig.add_subplot(111, projection='3d')
    ax.set_title("psir_plrn (Matplotlib)")

    # 2) Per-atom colours and size factors
    n_atoms = len(x)
    if len(mat) == n_atoms:
        atom_colors = _compute_matn(mat)
        atom_sizes = _compute_psize(mat)
    else:
        # Labels do not line up with the atoms. Colour by index through the
        # colormap and give every atom the same size; an index ramp here would
        # make atom 0 invisible (size 0).
        atom_colors = np.arange(n_atoms, dtype=float)
        atom_sizes = np.ones(n_atoms)

    # 3) Bonds
    segments = _bond_segments(connections, x, y, z, Grid)
    if segments:
        ax.add_collection3d(
            Line3DCollection(segments, colors=settings['bond_color'], linewidths=settings['bond_size'])
        )

    # 4) Atoms
    ax.scatter(
        x, y, z,
        c=atom_colors,
        s=atom_sizes * settings['atom_size_factor'],
        cmap=settings['cmap'],
        depthshade=settings['depthshade'],
        **settings['atom_kwargs'],
    )

    # 5) Isosurface. marching_cubes works in index space; vertices are then
    #    mapped to real space as origin + i*ex + j*ey + k*ez.
    if iso_level is None:
        iso_level = _auto_iso_level_from_density(Density)
    vol = _prepare_volume(Density, Grid, order_hint="F")  # XSF-like data are often Fortran-ordered
    verts_idx, faces, _, _ = measure.marching_cubes(vol, level=iso_level, spacing=(1.0, 1.0, 1.0))
    step_matrix = np.vstack(_grid_axis_steps(Grid))  # rows: ex, ey, ez
    verts_real = Grid[0, 0, 0, :] + verts_idx @ step_matrix
    iso_poly = Poly3DCollection(
        _faces_to_poly3d(verts_real, faces),
        facecolor=settings['pos_color'],
        alpha=settings['pos_alpha'],
        linewidth=0.1,
    )
    iso_poly.set_edgecolor((0.2, 0.2, 0.2, 0.05))
    ax.add_collection3d(iso_poly)

    # 6) Arrows (u, v, w) at (x, y, z)
    ax.quiver(
        x, y, z, u, v, w,
        length=settings['quiver_length'],
        normalize=settings['quiver_normalize'],
        **settings['quiver_args'],
    )

    # 7) Axis limits and outline box from the Grid extents
    xmin, xmax = float(np.min(Grid[..., 0])), float(np.max(Grid[..., 0]))
    ymin, ymax = float(np.min(Grid[..., 1])), float(np.max(Grid[..., 1]))
    zmin, zmax = float(np.min(Grid[..., 2])), float(np.max(Grid[..., 2]))
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    ax.set_zlim(zmin, zmax)
    _draw_outline_box(ax, xmin, xmax, ymin, ymax, zmin, zmax, color='k', lw=1.0)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    ax.grid(settings['grid'])

    fig.tight_layout()
    if show:
        plt.show()
    return fig, ax