# Matplotlib-only reimplementation of plot_psir_plrn (no Mayavi)
# Requires: numpy, matplotlib, scikit-image (plus EPWpy for the default colour/radius tables)
# pip install numpy matplotlib scikit-image
from typing import Tuple, Optional
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D # noqa: F401
from mpl_toolkits.mplot3d.art3d import Poly3DCollection, Line3DCollection
from skimage import measure
from EPWpy.default.default_dicts import *
def _prepare_volume(Density, Grid, order_hint="C"):
"""
Return a 3D numpy array vol shaped like Grid[...,0].shape.
Accepts Density shaped (nx,ny,nz), (nx,ny,nz,1), or flat (nx*ny*nz,).
order_hint: "C" (row-major) or "F" (column-major) for reshaping flat arrays.
"""
vol = np.asarray(Density)
vol = np.real(vol) # in case of complex
vol = np.squeeze(vol) # drop singleton dims
# Already 3D?
if vol.ndim == 3:
return vol
# Try to infer shape from Grid
grid_shape = np.asarray(Grid)[..., 0].shape # (nx,ny,nz)
nvox = int(np.prod(grid_shape))
# Handle flat volume
if vol.ndim == 1 and vol.size == nvox:
try:
return vol.reshape(grid_shape, order=order_hint)
except Exception:
# fallback: try the other memory order
other = "F" if order_hint == "C" else "C"
return vol.reshape(grid_shape, order=other)
# Handle (nx,ny,nz,1)
if vol.ndim == 4 and vol.shape[-1] == 1 and vol.shape[:3] == grid_shape:
return vol[..., 0]
raise ValueError(
f"Density cannot be coerced to 3D. Got shape {vol.shape}; expected "
f"(nx,ny,nz) compatible with Grid[...,0].shape = {grid_shape}."
)
def _compute_mat(mat):
"""Replicate your 'matn' grouping logic."""
mat = list(mat)
if not mat:
return np.array([])
prev = mat[0]
n = 2
out = []
for m in mat:
if m == prev:
out.append(n)
else:
prev = m
n += 1
out.append(n)
return np.array(out, dtype=float)
def _compute_matn(mat):
    """
    Map each label in ``mat`` to a colour from the module-level ``jmol_colors``.

    Labels are presumably atomic numbers (the lookup key was named ``Znum``);
    any label absent from ``jmol_colors`` falls back to mid-grey
    ``(0.5, 0.5, 0.5)``. Returns a NumPy array with one colour entry per label.
    """
    fallback = (0.5, 0.5, 0.5)
    return np.array([jmol_colors.get(label, fallback) for label in list(mat)])
def _compute_psize(mat):
"""Replicate your 'matn' grouping logic."""
mat = list(mat)
colr = []
for Znum in mat:
radius = atomic_radii.get(Znum, 50*pm2bohr) # Default to 50 pm if not in the map
size = (radius / 2.0)
colr.append(size)
print(min(colr),max(colr))
return np.array(colr)
def _faces_to_poly3d(verts: np.ndarray, faces: np.ndarray):
"""Convert marching-cubes vertices and faces to Poly3D polygons."""
return [verts[f] for f in faces]
def _auto_iso_level_from_density(Density: np.ndarray) -> float:
"""Pick a reasonable iso level if none is specified."""
dmin, dmax = float(np.nanmin(Density)), float(np.nanmax(Density))
if not np.isfinite(dmin) or not np.isfinite(dmax) or dmin == dmax:
return 0.0
# If data has both signs, iso at 0; else choose mid-value
if dmin < 0.0 < dmax:
return 0.0
return 0.5 * (dmin + dmax)
def _draw_outline_box(ax, xmin, xmax, ymin, ymax, zmin, zmax, color='k', lw=1.0):
"""Draw a rectangular box outline given extents."""
# 8 corners
p000 = np.array([xmin, ymin, zmin])
p100 = np.array([xmax, ymin, zmin])
p010 = np.array([xmin, ymax, zmin])
p110 = np.array([xmax, ymax, zmin])
p001 = np.array([xmin, ymin, zmax])
p101 = np.array([xmax, ymin, zmax])
p011 = np.array([xmin, ymax, zmax])
p111 = np.array([xmax, ymax, zmax])
edges = [
(p000, p100), (p000, p010), (p000, p001),
(p100, p110), (p100, p101),
(p010, p110), (p010, p011),
(p001, p101), (p001, p011),
(p110, p111), (p101, p111), (p011, p111),
]
for a, b in edges:
ax.plot([a[0], b[0]], [a[1], b[1]], [a[2], b[2]], color=color, lw=lw)
def _grid_frame(Grid):
    """
    Return ``(origin, ex, ey, ez)`` describing the index-to-real-space map of ``Grid``.

    ``origin`` is ``Grid[0, 0, 0, :]``. ``ex``, ``ey`` and ``ez`` are the first
    differences along index axes 0, 1 and 2 (zero vectors for length-1 axes).
    For a uniform, possibly non-orthogonal grid this gives
    ``r(i, j, k) ~ origin + i*ex + j*ey + k*ez``.
    """
    origin = Grid[0, 0, 0, :]
    steps = []
    for axis in range(3):
        if Grid.shape[axis] > 1:
            neighbour = [0, 0, 0]
            neighbour[axis] = 1
            steps.append(Grid[tuple(neighbour)] - origin)
        else:
            steps.append(np.zeros(3))
    ex, ey, ez = steps
    return origin, ex, ey, ez


def _bond_segments(conns, x, y, z, ex, ey, ez):
    """
    Build ``(start, end)`` point pairs for every usable row of ``conns``.

    Rows are ``(i, j)`` or ``(i, j, dx, dy, dz)``. Rows whose atom indices fall
    outside ``[0, len(x))`` are skipped. For rows with 5+ columns the end point
    is shifted by ``dx*ex + dy*ey + dz*ez``.
    """
    conns = np.asarray(conns)
    if conns.size == 0:
        return []
    if conns.ndim == 1:
        conns = conns.reshape(1, -1)
    n_atoms = len(x)
    segments = []
    for row in conns:
        i, j = int(row[0]), int(row[1])
        if not (0 <= i < n_atoms and 0 <= j < n_atoms):
            continue
        if row.shape[0] >= 5:
            # NOTE(review): ex/ey/ez are single *voxel* steps of Grid, not cell
            # vectors, so an offset of 1 moves the end point by one voxel. If
            # dx/dy/dz are periodic-image indices, the cell vectors (step times
            # the number of intervals per axis) are probably intended. Confirm
            # the Grid convention before changing this.
            shift = int(row[2]) * ex + int(row[3]) * ey + int(row[4]) * ez
        else:
            shift = np.zeros(3)
        start = np.array([x[i], y[i], z[i]])
        end = np.array([x[j], y[j], z[j]]) + shift
        segments.append((start, end))
    return segments


def _add_isosurface(ax, vol, level, origin, ex, ey, ez, color, alpha):
    """
    Add the ``level`` isosurface of ``vol`` to ``ax`` in real-space coordinates.

    Marching cubes runs in voxel-index space (unit spacing). The vertices are
    then mapped through ``origin + i*ex + j*ey + k*ez``. If marching cubes
    raises ValueError or RuntimeError (scikit-image does so, e.g., for a level
    outside the data range), a RuntimeWarning is issued and None is returned,
    so the rest of the plot still renders. Otherwise the added
    Poly3DCollection is returned.
    """
    import warnings

    try:
        verts_idx, faces, _, _ = measure.marching_cubes(
            vol, level=level, spacing=(1.0, 1.0, 1.0)
        )
    except (ValueError, RuntimeError) as err:
        warnings.warn(
            f"Isosurface at level {level!r} skipped: {err}",
            RuntimeWarning,
            stacklevel=3,
        )
        return None
    # Rows of `basis` are the per-index steps, so `verts_idx @ basis` applies
    # i*ex + j*ey + k*ez to every vertex at once.
    basis = np.vstack([ex, ey, ez])
    verts_real = origin + verts_idx @ basis
    iso_poly = Poly3DCollection(
        _faces_to_poly3d(verts_real, faces),
        facecolor=color,
        alpha=alpha,
        linewidth=0.1,
    )
    iso_poly.set_edgecolor((0.2, 0.2, 0.2, 0.05))
    ax.add_collection3d(iso_poly)
    return iso_poly


def plot_psir_plrn_matplotlib(
    Data: dict,
    *,
    view: Optional[dict] = None,
    iso_level: Optional[float] = None,  # if None, pick automatically
    iso_alpha: float = 0.35,
    iso_color=(0.9, 0.8, 0.0),
    point_size: float = 40.0,
    cmap: str = "viridis",
    quiver_length: float = 1.0,
    quiver_normalize: bool = False,
    figsize=(9, 8),
    show: bool = True,
    bond_size: float = 2.0,
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Plot atoms, bonds, per-atom vectors and one density isosurface with Matplotlib.

    Matplotlib-only counterpart of the Mayavi ``plot_psir_plrn`` view.

    Expected keys in ``Data``:
      - 'mat' : per-atom labels. When its length matches ``x``, they select
        atom colours (``_compute_matn``) and marker sizes
        (``_compute_psize``). Otherwise atoms are colour-mapped by index and
        share one size.
      - 'x', 'y', 'z' : 1-D atom coordinates (same length).
      - 'u', 'v', 'w' : 1-D vector components drawn as arrows at (x, y, z).
      - 'Grid' : 4-D array (nx, ny, nz, 3) of real-space coordinates,
        ``Grid[ix, iy, iz, 0/1/2] = x/y/z``.
      - 'Density' : scalar field on ``Grid``. It may be 3-D or any layout
        ``_prepare_volume`` accepts; flat data are unflattened in Fortran
        order. Complex values are reduced to their real part.
      - 'connections' (optional) : rows of ``(i, j)`` or ``(i, j, dx, dy, dz)``.
      - 'Dense', 'pts', 'in_notebook', 'backend' : ignored here.

    Parameters
    ----------
    view : dict, optional
        Style overrides; each key replaces the default of the same name:
        'pos_color' / 'pos_alpha' (isosurface colour/alpha),
        'atom_size_factor' (multiplies the per-atom sizes), 'bond_size',
        'bond_color', 'depthshade', 'cmap', 'grid', 'figsize',
        'quiver_length', 'quiver_normalize', 'atom_kwargs' (extra
        ``scatter`` kwargs, which may override c/s/cmap/depthshade) and
        'quiver_args' (extra ``quiver`` kwargs). 'pos_iso_frac',
        'neg_iso_frac', 'neg_color' and 'neg_alpha' are accepted but currently
        unused, since only one isosurface is drawn.
    iso_level : float, optional
        Isosurface level. If None, it is chosen by
        ``_auto_iso_level_from_density``.
    cmap, quiver_length, quiver_normalize, figsize, bond_size :
        Defaults for the ``view`` keys of the same name; a ``view`` entry wins.
    iso_alpha, iso_color, point_size :
        Kept for signature compatibility only. They are superseded by
        ``view['pos_alpha']``, ``view['pos_color']`` and
        ``view['atom_size_factor']`` (defaults 0.8, (1.0, 0.6, 0.0), 200.0).
    show : bool
        Call ``plt.show()`` before returning.

    Returns
    -------
    (fig, ax)
        The created figure and its 3-D axes.
    """
    mat = Data['mat']
    x = np.asarray(Data['x'], dtype=float)
    y = np.asarray(Data['y'], dtype=float)
    z = np.asarray(Data['z'], dtype=float)
    u = np.asarray(Data['u'], dtype=float)
    v = np.asarray(Data['v'], dtype=float)
    w = np.asarray(Data['w'], dtype=float)
    Grid = np.asarray(Data['Grid'], dtype=float)  # shape (..., 3)
    # Take the real part explicitly. asarray(..., dtype=float) on complex data
    # would warn and drop the imaginary part anyway.
    Density = np.asarray(np.real(Data['Density']), dtype=float)
    connections = Data.get('connections', [])

    settings = {
        'pos_iso_frac': 0.10,             # currently unused
        'neg_iso_frac': 0.10,             # currently unused
        'pos_color': (1.0, 0.6, 0.0),
        'neg_color': (0.2, 0.45, 0.95),   # currently unused
        'pos_alpha': 0.8,
        'neg_alpha': 0.6,                 # currently unused
        'quiver_length': quiver_length,
        'quiver_normalize': quiver_normalize,
        'atom_size_factor': 200.0,
        'bond_size': bond_size,
        'bond_color': (0.8, 0.7, 0.0),
        'depthshade': True,
        'cmap': cmap,
        'grid': False,
        'figsize': figsize,
        'atom_kwargs': {},
        'quiver_args': {},
    }
    settings.update({f'{key}': value for key, value in (view or {}).items()})

    iso_alpha = settings['pos_alpha']
    iso_color = settings['pos_color']
    point_size = settings['atom_size_factor']
    cmap = settings['cmap']
    grid = settings['grid']
    bond_size = settings['bond_size']
    bond_color = settings['bond_color']
    depthshade = settings['depthshade']
    atom_kwargs = settings['atom_kwargs']
    quiver_length = settings['quiver_length']
    quiver_normalize = settings['quiver_normalize']
    figsize = settings['figsize']
    quiver_args = settings['quiver_args']

    # 1) Figure/Axes
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection='3d')
    ax.set_title("psir_plrn (Matplotlib)")

    origin, ex, ey, ez = _grid_frame(Grid)

    # 2) Connections (bonds/edges)
    segments = _bond_segments(connections, x, y, z, ex, ey, ez)
    if segments:
        ax.add_collection3d(
            Line3DCollection(segments, colors=bond_color, linewidths=bond_size)
        )

    # 3) Atoms: per-label colours/sizes, or index colour map + uniform size
    n_atoms = len(x)
    if len(mat) == n_atoms:
        colors = _compute_matn(mat)
        psize = _compute_psize(mat)
    else:
        colors = np.arange(n_atoms, dtype=float)
        psize = np.ones(n_atoms)
    scatter_kwargs = {'c': colors, 's': psize * point_size, 'depthshade': depthshade}
    if np.ndim(colors) == 1:
        # Only scalar colour values are colour-mapped. Passing cmap alongside
        # explicit RGB(A) rows just makes Matplotlib warn that it is ignored.
        scatter_kwargs['cmap'] = cmap
    scatter_kwargs.update(atom_kwargs)
    ax.scatter(x, y, z, **scatter_kwargs)

    # 4) Isosurface from Density + Grid
    if iso_level is None:
        iso_level = _auto_iso_level_from_density(Density)
    # XSF-like data are often Fortran-ordered
    vol = _prepare_volume(Density, Grid, order_hint="F")
    _add_isosurface(ax, vol, iso_level, origin, ex, ey, ez, iso_color, iso_alpha)

    # 5) Quiver (u, v, w) at (x, y, z)
    ax.quiver(x, y, z, u, v, w, length=quiver_length,
              normalize=quiver_normalize, **quiver_args)

    # 6) Axes limits and outline box from Grid extents
    xmin, xmax = float(np.min(Grid[..., 0])), float(np.max(Grid[..., 0]))
    ymin, ymax = float(np.min(Grid[..., 1])), float(np.max(Grid[..., 1]))
    zmin, zmax = float(np.min(Grid[..., 2])), float(np.max(Grid[..., 2]))
    ax.set_xlim(xmin, xmax)
    ax.set_ylim(ymin, ymax)
    ax.set_zlim(zmin, zmax)
    _draw_outline_box(ax, xmin, xmax, ymin, ymax, zmin, zmax, color='k', lw=1.0)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    ax.grid(grid)
    fig.tight_layout()
    if show:
        plt.show()
    return fig, ax