# Source code for EPWpy.plotting.plot_wannier_matplotlib

# plot_isosurfaces_matplotlib.py
# Requires: numpy, matplotlib, scikit-image
# pip install numpy matplotlib scikit-image

import warnings
from typing import Optional, Tuple

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d.art3d import Poly3DCollection, Line3DCollection
from skimage import measure

from EPWpy.default.default_dicts import *

def _grid_coordinates_from_axes(origin, axis_x, axis_y, axis_z, grid_shape):
    """
    Builds 3D arrays X,Y,Z of coordinates for the regular grid described by axes and origin.
    axis_x, axis_y, axis_z are 3-component vectors that multiply integer indices.
    grid_shape is (nx, ny, nz).
    """
    nx, ny, nz = grid_shape
    ix = np.arange(nx, dtype=float)
    iy = np.arange(ny, dtype=float)
    iz = np.arange(nz, dtype=float)
    # Using broadcasting: X[ix,iy,iz] = origin + ix*axis_x + iy*axis_y + iz*axis_z
    X = origin[0] + ix[:, None, None]*axis_x[0] + iy[None, :, None]*axis_y[0] + iz[None, None, :]*axis_z[0]
    Y = origin[1] + ix[:, None, None]*axis_x[1] + iy[None, :, None]*axis_y[1] + iz[None, None, :]*axis_z[1]
    Z = origin[2] + ix[:, None, None]*axis_x[2] + iy[None, :, None]*axis_y[2] + iz[None, None, :]*axis_z[2]
    return X, Y, Z

def _faces_to_poly3d(verts: np.ndarray, faces: np.ndarray):
    """
    Convert marching_cubes vertices and faces to list-of-polygons for Poly3DCollection.
    verts: (M,3) ; faces: (K,3) indices
    Return: list of (3,3) float polygons.
    """
    return [verts[tri] for tri in faces]

def _parse_connections(connections, natoms):
    """
    Accepts connections in several formats:
      - None (no bonds)
      - list/ndarray shape (M,2) with integer pairs (i,j) 0-based
      - list/ndarray shape (M,5) with (i,j,dx,dy,dz) image offsets
    Returns list of tuples (i,j,dx,dy,dz)
    """
    if connections is None:
        return []
    arr = np.asarray(connections)
    if arr.size == 0:
        return []
    if arr.ndim == 1 and arr.shape[0] == 2:
        return [(int(arr[0]), int(arr[1]), 0,0,0)]
    if arr.ndim == 2 and arr.shape[1] == 2:
        return [(int(a), int(b), 0, 0, 0) for a,b in arr]
    if arr.ndim == 2 and arr.shape[1] >= 5:
        return [(int(a), int(b), int(dx), int(dy), int(dz)) for a,b,dx,dy,dz in arr[:, :5]]
    raise ValueError("connections format not recognized. Expect (M,2) or (M,5) or None.")

def _isosurface_polygons(volume, level, origin, cell):
    """
    Triangulate the ``level`` isosurface of ``volume`` and return it in real space.

    Parameters
    ----------
    volume : 3D ndarray
        Scalar field indexed as (i, j, k).
    level : float
        Iso value passed to ``skimage.measure.marching_cubes``.
    origin : (3,) ndarray
        Real-space position of grid point (0, 0, 0).
    cell : (3, 3) ndarray
        Rows are the step vectors applied to the i, j and k indices.

    Returns
    -------
    list of (3, 3) ndarrays, one per triangle. The list is empty, and a
    warning is issued, when ``level`` is not strictly inside the data range.
    At such levels marching_cubes either rejects the level or finds no surface.
    """
    lo = np.nanmin(volume)
    hi = np.nanmax(volume)
    if not (lo < level < hi):
        warnings.warn(
            f"iso level {level:.6g} is outside the data range "
            f"({lo:.6g}, {hi:.6g}); skipping this isosurface",
            stacklevel=3,
        )
        return []
    verts_idx, faces, _, _ = measure.marching_cubes(
        volume, level=level, spacing=(1.0, 1.0, 1.0)
    )
    # marching_cubes returns vertices in fractional index space (i, j, k);
    # the row-vector product maps each to origin + i*axis_x + j*axis_y + k*axis_z.
    verts_real = origin + verts_idx @ cell
    return _faces_to_poly3d(verts_real, faces)


def _split_atomic_positions(atomic_positions):
    """
    Split atom rows ``[Z, charge, x, y, z, ...]`` into atomic numbers and coordinates.

    A single flat row is treated as one atom.

    Returns
    -------
    numbers : list of int
        Column 0 of every row.
    coords : (N, 3) float ndarray
        Columns 2-4 of every row (N may be 0).

    Raises
    ------
    ValueError
        If the rows have fewer than five columns.
    """
    rows = np.asarray(atomic_positions)
    if rows.size == 0:
        return [], np.zeros((0, 3))
    if rows.ndim == 1:
        rows = rows[None, :]
    if rows.ndim != 2 or rows.shape[1] < 5:
        raise ValueError("atomic_positions rows must look like [Z, charge, x, y, z]")
    numbers = [int(z) for z in rows[:, 0]]
    coords = np.asarray(rows[:, 2:5], dtype=float)
    return numbers, coords


def _add_isosurface(ax, polygons, facecolor, alpha, edgecolor):
    """Add one triangulated surface to the 3D axes ``ax``; do nothing for an empty surface."""
    if not polygons:
        return
    surface = Poly3DCollection(polygons, facecolor=facecolor, linewidth=0.1, alpha=alpha)
    surface.set_edgecolor(edgecolor)
    ax.add_collection3d(surface)


def plot_isosurfaces_matplotlib(
    Data: dict, view: Optional[dict] = None
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot two isosurfaces of a 3D scalar field, plus atoms and bonds, with matplotlib 3D.

    The positive surface is drawn at ``+pos_iso_frac * max(scalar_data)`` and the
    negative one at ``-neg_iso_frac * max(scalar_data)``. A surface whose level
    is not strictly inside the data range is skipped with a warning instead of
    raising. An example is the negative surface of a field that has no
    sufficiently negative values.

    Parameters
    ----------
    Data : dict
        - scalar_data : 3D array of shape (nx, ny, nz) (required)
        - grid_shape : (nx, ny, nz); defaults to ``scalar_data.shape``
        - spacing : (sx, sy, sz); only used to build default axis vectors
        - origin : 3-vector, default (0, 0, 0)
        - axis_x, axis_y, axis_z : 3-vectors multiplied by the grid indices
        - atomic_positions : iterable of rows ``[Z, charge, x, y, z]``
        - connections : optional bonds in any format ``_parse_connections`` accepts
        - verbosity : if truthy, print the iso levels
    view : dict, optional
        Overrides for the display options:
        - isosurfaces: pos_iso_frac, neg_iso_frac, pos_color, neg_color,
          pos_alpha, neg_alpha
        - atoms and bonds: atom_size_factor, bond_size, bond_color,
          depthshade, atom_kwargs
        - figure: show, grid
        Unrecognised keys are accepted and ignored.

    Returns
    -------
    (fig, ax)
        The matplotlib figure and its 3D axes.

    Raises
    ------
    KeyError
        If ``Data`` has no 'scalar_data'.
    ValueError
        If the maximum of ``scalar_data`` is zero or not finite, or
        ``atomic_positions`` / ``connections`` are malformed.
    """
    options = {
        'pos_iso_frac': 0.10,
        'neg_iso_frac': 0.10,
        'pos_color': (1.0, 0.6, 0.0),
        'neg_color': (0.2, 0.45, 0.95),
        'pos_alpha': 0.8,
        'neg_alpha': 0.6,
        'atom_size_factor': 200.0,
        'bond_size': 2.0,
        'bond_color': (0.8, 0.7, 0.0),
        'depthshade': True,
        'show': True,
        'grid': False,
        'atom_kwargs': {},
    }
    if view is not None:
        options.update({str(key): value for key, value in view.items()})

    scalar_data = np.asarray(Data['scalar_data'])
    grid_shape = tuple(Data.get('grid_shape', scalar_data.shape))
    spacing = Data.get('spacing', (1.0, 1.0, 1.0))
    origin = np.asarray(Data.get('origin', (0.0, 0.0, 0.0)), dtype=float)
    axis_x = np.asarray(Data.get('axis_x', (spacing[0], 0.0, 0.0)), dtype=float)
    axis_y = np.asarray(Data.get('axis_y', (0.0, spacing[1], 0.0)), dtype=float)
    axis_z = np.asarray(Data.get('axis_z', (0.0, 0.0, spacing[2])), dtype=float)
    # Rows are the step vectors: (index row vector) @ cell is a real-space offset.
    cell = np.vstack([axis_x, axis_y, axis_z])
    verbosity = Data.get('verbosity', 0)

    maxval = np.nanmax(scalar_data)
    if maxval == 0 or not np.isfinite(maxval):
        raise ValueError("scalar_data has no finite non-zero values")
    iso_pos_value = options['pos_iso_frac'] * maxval
    iso_neg_value = -options['neg_iso_frac'] * maxval
    if verbosity:
        print(f"iso pos {iso_pos_value:.6g}, iso neg {iso_neg_value:.6g}, max {maxval:.6g}")

    # Do all input parsing and meshing before creating the figure, so a failure
    # here does not leave an orphaned figure behind.
    polys_pos = _isosurface_polygons(scalar_data, iso_pos_value, origin, cell)
    polys_neg = _isosurface_polygons(scalar_data, iso_neg_value, origin, cell)
    atom_numbers, atom_coords = _split_atomic_positions(Data.get('atomic_positions', []))
    natoms = atom_coords.shape[0]
    bonds = _parse_connections(Data.get('connections', None), natoms)

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    ax.set_box_aspect((1, 1, 1))

    # --- bonds ---
    segments = []
    for i, j, dx, dy, dz in bonds:
        if not (0 <= i < natoms and 0 <= j < natoms):
            continue  # skip bonds that reference atoms not present
        # NOTE(review): the image offset is scaled by the grid step vectors
        # axis_x/y/z rather than by lattice vectors; confirm this is intended.
        end_j = atom_coords[j] + dx * axis_x + dy * axis_y + dz * axis_z
        segments.append((atom_coords[i], end_j))
    if segments:
        ax.add_collection3d(
            Line3DCollection(
                segments,
                colors=options['bond_color'],
                linewidths=options['bond_size'],
            )
        )

    # --- atoms ---
    if natoms:
        fallback_color = (0.5, 0.5, 0.5)
        colors = [jmol_colors.get(z, fallback_color) for z in atom_numbers]
        sizes = [atomic_radii.get(z, 0.8) * options['atom_size_factor'] for z in atom_numbers]
        ax.scatter(
            atom_coords[:, 0], atom_coords[:, 1], atom_coords[:, 2],
            s=sizes, c=colors, depthshade=options['depthshade'],
            **options['atom_kwargs'],
        )

    # --- isosurfaces ---
    _add_isosurface(ax, polys_pos, options['pos_color'], options['pos_alpha'], (0.2, 0.2, 0.2, 0.05))
    _add_isosurface(ax, polys_neg, options['neg_color'], options['neg_alpha'], (0.1, 0.1, 0.1, 0.04))

    # --- limits ---
    # The grid is an affine image of the index box, so its extent is spanned by
    # the eight corner points; widen it to include any atoms lying outside.
    nx, ny, nz = grid_shape
    corner_idx = np.array(
        [[i, j, k] for i in (0, nx - 1) for j in (0, ny - 1) for k in (0, nz - 1)],
        dtype=float,
    )
    points = origin + corner_idx @ cell
    if natoms:
        points = np.vstack([points, atom_coords])
    lower = points.min(axis=0)
    upper = points.max(axis=0)
    ax.set_xlim(lower[0], upper[0])
    ax.set_ylim(lower[1], upper[1])
    ax.set_zlim(lower[2], upper[2])

    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    if not options['grid']:
        ax.grid(False)
    if options['show']:
        plt.tight_layout()
        plt.show()
    return fig, ax