# Source code for EPWpy.utilities.display_struct

import numpy as np
import urllib.request as urlreq
from EPWpy.utilities.EPW_util import get_connections
from EPWpy.default.default_dicts import *
from EPWpy.error_handling import error_handler
from EPWpy.plotting.plot_structure_matplotlib import *
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation

try:
    from mayavi import mlab
    from mayavi.modules.iso_surface import IsoSurface
    from tvtk.api import tvtk
except ImportError:
    error_handler.error_mayavi()
#    print('The mayavi package not found\nTo visualize crystals in EPWpy, install mayavi')
try:
    import plotly
    import plotly.graph_objs as go
    from dash import Dash, dcc, html, Input, Output, callback
    import dash_bio as dashbio
    from dash_bio.utils import xyz_reader
except ImportError:
    error_handler.error_dash()
#    print('Dash-bio not found\nDash-bio not needed unless you want to use molecular view (not-recommended)')



def get_preset_view():
    """
    Build the default dash-bio Speck view settings.

    ``display_molecule`` falls back to these settings when it is called
    without a (non-empty) ``view``.

    Returns:
        dict: A fresh dictionary of Speck view options on every call.
    """
    return dict(
        resolution=400,
        ao=0.1,
        outline=1,
        atomScale=0.25,
        relativeAtomScale=0.33,
        bonds=True,
    )
def read_xyz(path='xyz'):
    """
    Return the complete text of an XYZ structure file.

    Args:
        path (str or os.PathLike, optional): File to read. Defaults to
            ``'xyz'`` in the current working directory, which is the file
            this function always read before ``path`` was added.

    Returns:
        str: The file contents, unmodified (newlines included).
    """
    # Explicit encoding so the result does not depend on the platform locale.
    with open(path, 'r', encoding='utf-8') as handle:
        return handle.read()
def display_atoms(atom_pos):
    """
    Show atoms as an interactive plotly 3D scatter plot inside a notebook.

    Args:
        atom_pos: Array indexed as ``atom_pos[:, 0]``, ``atom_pos[:, 1]``
            and ``atom_pos[:, 2]`` for the x, y and z coordinate of each
            atom (e.g. an (N, 3) numpy array).

    Returns:
        None. The figure is rendered with ``plotly.offline.iplot``.
    """
    plotly.offline.init_notebook_mode()

    marker_style = {'size': 10, 'opacity': 0.8}
    atoms_trace = go.Scatter3d(
        x=atom_pos[:, 0],
        y=atom_pos[:, 1],
        z=atom_pos[:, 2],
        mode='markers',
        marker=marker_style,
    )
    # Zero margins so the 3D scene uses the whole output area.
    zero_margin_layout = go.Layout(margin=dict(l=0, r=0, b=0, t=0))
    figure = go.Figure(data=[atoms_trace], layout=zero_margin_layout)
    plotly.offline.iplot(figure)
    # Alternative kept from the original (disabled): figure.show()
def display_molecule(view=None):
    """
    Build a Dash app that renders the ``xyz`` file with the dash-bio Speck viewer.

    The structure text comes from :func:`read_xyz` and is parsed with
    ``dash_bio.utils.xyz_reader``. The preset dropdown is laid out next to
    the viewer, but no callback connects the two (the callback is disabled
    in this module).

    Args:
        view (dict, optional): Speck view settings. When ``None`` or empty,
            :func:`get_preset_view` is used.

    Returns:
        Dash: The configured application; it is not started here.
    """
    plotly.offline.init_notebook_mode()
    app = Dash(__name__)
    raw_xyz = read_xyz()

    # An empty (or missing) view means "use the preset".
    speck_view = view if view else get_preset_view()

    atoms = xyz_reader.read_xyz(datapath_or_datastring=raw_xyz, is_datafile=False)

    preset_dropdown = dcc.Dropdown(
        id='default-speck-preset-views',
        options=[
            {'label': 'Ball and stick', 'value': 'stickball'},
            {'label': 'Default', 'value': 'default'},
        ],
        value='stickball',
    )
    speck_viewer = dashbio.Speck(
        id='default-speck',
        data=atoms,
        view=speck_view,
    )
    app.layout = html.Div([preset_dropdown, speck_viewer])
    return app
# NOTE(review): the Dash callback that would connect the preset dropdown to
# the Speck viewer is disabled. The original (commented-out) registration was
#     @callback(Output('default-speck', 'presetView'),
#               Input('default-speck-preset-views', 'value##'))
# wrapped in ``try/except NameError`` for when Dash is not installed.
# ``'value##'`` looks like a typo for ``'value'`` -- confirm before re-enabling.
def update_preset_view(preset_name):
    """
    Return the selected Speck preset name unchanged.

    Intended as the body of the (currently disabled) dropdown callback.
    """
    selected_preset = preset_name
    return selected_preset
def display_crystal(data, view=None, bond_length=3.5):
    """
    Display a crystal structure with mayavi: atoms as spheres, bonds as tubes.

    Args:
        data (dict): Structure data. Must provide ``'positions'`` (a numpy
            array indexed as ``positions[:, k]``, one row ``(x, y, z)`` per
            atom) and ``'mat'`` (element symbol of each atom, same order).
        view (dict, optional): Rendering options. Recognised keys:
            ``'bond_color'`` (RGB tuple, default ``(0.8, 0.8, 0)``),
            ``'bond_radius'`` (float, default ``0.15``),
            ``'in_notebook'`` (its presence enables mayavi notebook mode) and
            ``'backend'`` (passed to ``mlab.init_notebook`` in notebook mode).
            Defaults to an empty dictionary.
        bond_length (float, optional): Distance threshold forwarded to
            ``get_connections`` to decide which atom pairs are bonded.
            Defaults to 3.5.

    Returns:
        In notebook mode, the mayavi points source that carries the bond
        lines; otherwise ``None`` once ``mlab.show()`` returns.

    Example::

        display_crystal({'positions': pos, 'mat': ['Si', 'Si']},
                        view={'bond_color': (1, 0, 0)}, bond_length=4.0)
    """
    if view is None:
        view = {}

    atomic_positions = data['positions']
    mat = data['mat']

    connections, _ = get_connections(atomic_positions[:, 0],
                                     atomic_positions[:, 1],
                                     atomic_positions[:, 2],
                                     bond_length)

    bond_color = view.get('bond_color', (0.8, 0.8, 0))
    bond_radius = view.get('bond_radius', 0.15)

    # Initialise mayavi for notebooks (optionally with a specific backend).
    in_notebook = 'in_notebook' in view
    if in_notebook:
        if 'backend' in view:
            mlab.init_notebook(backend=view['backend'])
        else:
            mlab.init_notebook()

    mlab.figure('EPWpy', bgcolor=(1, 1, 1), size=(800, 600))

    # One sphere per atom, coloured with jmol colours and sized by atomic radius.
    X, Y, Z = [], [], []
    for i, atom in enumerate(atomic_positions):
        x, y, z = atom
        atom_number = atomic_number.get(mat[i], 12)
        color = jmol_colors.get(atom_number, (0.5, 0.5, 0.5))  # gray if unknown
        radius = atomic_radii.get(atom_number, 50 * pm2bohr)  # 50 pm if unknown
        size = radius / 2  # scale size for visibility
        X.append(x)
        Y.append(y)
        Z.append(z)
        mlab.points3d(x, y, z, color=color, scale_factor=size)

    # Invisible (scale 0) points at every atom; they only serve as vertices
    # for the bond lines, which are then rendered as tubes.
    src = mlab.points3d(X, Y, Z, color=(0, 0, 0), scale_factor=0.0)
    src.mlab_source.dataset.lines = np.array(connections)
    tube = mlab.pipeline.tube(src, tube_radius=bond_radius)
    tube.filter.radius_factor = 1.
    mlab.pipeline.surface(tube, color=bond_color)

    if in_notebook:
        return src
    mlab.show()
def display_crystal_matplotlib(
    data,
    view=None,
    point_size=100,
    wrap=False,
    show_labels=False,
    bond_length=None,
    scale=1.20,
    min_dist=0.10,
    depthshade=False,
    bond_size=3.0,
    show_bond_length=False,
    grid=False,
    show=False,
):
    """
    Draw a crystal structure (unit cell, atoms, bonds) on a matplotlib 3D axes.

    Args:
        data (dict): Structure data with keys ``'lattice_vec'`` (the three
            lattice vectors), ``'atomic_species'`` (element symbols; one
            legend entry each), ``'positions'`` (Cartesian coordinates, one
            ``(x, y, z)`` row per atom) and ``'mat'`` (element symbol of
            each atom).
        view (dict, optional): Per-call overrides. A key present here wins
            over the keyword argument of the same name. Recognised keys:
            ``grid``, ``show_labels``, ``bond_size``, ``depthshade``,
            ``atom_size_factor`` (default 1000), ``bond_color`` (default
            ``(0.8, 0.8, 0)``), ``wrap``, ``show_bond_length``,
            ``bond_length``, ``min_dist``.
        point_size: Currently unused. The marker area is
            ``atomic radius * atom_size_factor``. Kept for API compatibility.
        wrap (bool): Wrap atoms back into the unit cell before drawing.
        show_labels (bool): Write the element symbol on each atom, for
            elements with at most 50 atoms.
        bond_length, min_dist: Forwarded to ``get_connections`` to select
            bonded pairs.
        scale: Currently unused; kept for API compatibility.
        depthshade (bool): Passed to ``Axes3D.scatter``.
        bond_size (float): Line width of the bonds.
        show_bond_length (bool): Print each bond's distance at its midpoint.
        grid (bool): Show the axes grid.
        show (bool): Call ``plt.show()`` and return ``None`` instead of
            returning the axes.

    Returns:
        The 3D ``Axes`` when ``show`` is False, otherwise ``None``.

    Raises:
        ValueError: If wrapping is requested and the lattice is singular.
    """
    if view is None:
        view = {}

    lattice = data['lattice_vec']
    elems = data['atomic_species']
    positions = data['positions']
    tags = data['mat']

    grid = view.get('grid', grid)
    show_labels = view.get('show_labels', show_labels)
    bond_size = view.get('bond_size', bond_size)
    depthshade = view.get('depthshade', depthshade)
    atom_size_factor = view.get('atom_size_factor', 1000)
    bond_color = view.get('bond_color', (0.8, 0.8, 0))
    wrap = view.get('wrap', wrap)
    show_bond_length = view.get('show_bond_length', show_bond_length)
    bond_length = view.get('bond_length', bond_length)
    # Previously view['min_dist'] overwrote bond_length by mistake.
    min_dist = view.get('min_dist', min_dist)

    def _plot_carts(pos):
        """Copy Cartesian positions, wrapping them into the cell when ``wrap`` is set."""
        carts = [p[:] for p in pos]
        if not wrap:
            return carts
        # Columns of A are the lattice vectors, so A @ frac == cart.
        A = [[lattice[0][0], lattice[1][0], lattice[2][0]],
             [lattice[0][1], lattice[1][1], lattice[2][1]],
             [lattice[0][2], lattice[1][2], lattice[2][2]]]
        detA = mat_det_3(A)
        if abs(detA) < 1e-12:
            raise ValueError("Singular lattice for wrapping.")
        (a, b, c), (d, e, f), (g, h, i) = A
        # Adjugate of A; adj(A) / det(A) is A^-1 (Cramer's rule).
        adj = [[e*i - f*h, c*h - b*i, b*f - c*e],
               [f*g - d*i, a*i - c*g, c*d - a*f],
               [d*h - e*g, b*g - a*h, a*e - b*d]]

        def to_frac(r):
            return [(row[0]*r[0] + row[1]*r[1] + row[2]*r[2]) / detA for row in adj]

        return [frac_to_cart(wrap_frac(to_frac(r)), lattice) for r in carts]

    carts = _plot_carts(positions)

    # Detect bonds on the coordinates that are actually drawn, so lines and
    # distance labels stay consistent when wrapping moves atoms.
    cart_arr = np.asarray(carts, dtype=float).reshape(-1, 3)
    bonds, distance = get_connections(cart_arr[:, 0], cart_arr[:, 1], cart_arr[:, 2],
                                      bond_length, min_dist)

    # Points used for the axis limits: cell vertices + atoms.
    all_pts = unit_cell_vertices(lattice) + carts

    fig = plt.figure(figsize=(9, 8))
    ax = fig.add_subplot(111, projection='3d')
    draw_unit_cell(ax, lattice, lw=1.5)

    # Group atoms by element so each element gets one colour / legend entry.
    by_elem = {}
    for pos, el in zip(carts, tags):
        by_elem.setdefault(el, []).append(pos)

    for el in elems:
        pts = by_elem.get(el, [])
        if not pts:
            continue
        atom_number = atomic_number.get(el, 12)
        color = jmol_colors.get(atom_number, (0.5, 0.5, 0.5))  # gray if unknown
        size = atomic_radii.get(atom_number, 0.8) * atom_size_factor
        ax.scatter([p[0] for p in pts], [p[1] for p in pts], [p[2] for p in pts],
                   s=size, depthshade=depthshade, label=f"{el} ({len(pts)})",
                   marker='o', color=color)
        if show_labels and len(pts) <= 50:
            for p in pts:
                ax.text(p[0], p[1], p[2], el, fontsize=8, ha='center', va='center')

    for (i, j), dist in zip(bonds, distance):
        p_i, p_j = carts[i], carts[j]
        ax.plot([p_i[0], p_j[0]], [p_i[1], p_j[1]], [p_i[2], p_j[2]],
                linewidth=bond_size, color=bond_color, zorder=0)
        if show_bond_length:
            xm = 0.5 * (p_i[0] + p_j[0])
            ym = 0.5 * (p_i[1] + p_j[1])
            zm = 0.5 * (p_i[2] + p_j[2])
            ax.text(xm, ym, zm, f"{dist:.3f}Å", color='red', fontsize=7,
                    ha='center', va='center')

    ax.grid(grid)
    auto_equal_limits(ax, all_pts)
    ax.set_xlabel("x (Å)")
    ax.set_ylabel("y (Å)")
    ax.set_zlabel("z (Å)")
    ax.legend(loc="upper left", bbox_to_anchor=(0.02, 0.98))
    plt.tight_layout()
    if show:
        plt.show()
        return None
    return ax
def display_crystal_phonon(
    data,
    view=None,
    point_size=120,
    wrap=True,
    show_labels=True,
    bond_length=None,
    frame_interval=10,
    scale=1.20,
    min_dist=0.10,
    show_bond_length=True,
):
    """
    Draw a crystal structure with one arrow per atom along ``data['forces']``.

    Args:
        data (dict): Structure data with keys ``'lattice_vec'``,
            ``'atomic_species'`` (element symbols; one legend entry each),
            ``'positions'`` (Cartesian, one ``(x, y, z)`` row per atom),
            ``'mat'`` (element symbol per atom) and ``'forces'`` (numpy
            array, one ``(fx, fy, fz)`` row per atom, drawn as arrows).
        view (dict, optional): Per-call overrides. A key present here wins
            over the keyword argument of the same name. Recognised keys:
            ``show_labels``, ``atom_size_factor`` (default 100),
            ``bond_color`` (default ``(0.8, 0.8, 0)``), ``wrap``,
            ``show_bond_length``, ``bond_length``, ``min_dist``,
            ``quiver_length`` (default 15.0).
        point_size, frame_interval, scale: Currently unused; kept for API
            compatibility.
        wrap (bool): Wrap atoms back into the unit cell before drawing.
        show_labels (bool): Write the element symbol on each atom, for
            elements with at most 50 atoms.
        bond_length, min_dist: Forwarded to ``get_connections``.
        show_bond_length (bool): Print each bond's distance at its midpoint.

    Returns:
        The 3D ``Axes`` holding the plot.

    Raises:
        ValueError: If wrapping is requested and the lattice is singular.
    """
    if view is None:
        view = {}

    lattice = data['lattice_vec']
    elems = data['atomic_species']
    positions = data['positions']
    tags = data['mat']
    forces = data['forces']

    show_labels = view.get('show_labels', show_labels)
    atom_size_factor = view.get('atom_size_factor', 100)
    bond_color = view.get('bond_color', (0.8, 0.8, 0))
    wrap = view.get('wrap', wrap)
    show_bond_length = view.get('show_bond_length', show_bond_length)
    bond_length = view.get('bond_length', bond_length)
    # Previously view['min_dist'] overwrote bond_length by mistake.
    min_dist = view.get('min_dist', min_dist)
    quiver_length = view.get('quiver_length', 15.0)

    def _plot_carts(pos):
        """Copy Cartesian positions, wrapping them into the cell when ``wrap`` is set."""
        carts = [p[:] for p in pos]
        if not wrap:
            return carts
        # Columns of A are the lattice vectors, so A @ frac == cart.
        A = [[lattice[0][0], lattice[1][0], lattice[2][0]],
             [lattice[0][1], lattice[1][1], lattice[2][1]],
             [lattice[0][2], lattice[1][2], lattice[2][2]]]
        detA = mat_det_3(A)
        if abs(detA) < 1e-12:
            raise ValueError("Singular lattice for wrapping.")
        (a, b, c), (d, e, f), (g, h, i) = A
        # Adjugate of A; adj(A) / det(A) is A^-1 (Cramer's rule).
        adj = [[e*i - f*h, c*h - b*i, b*f - c*e],
               [f*g - d*i, a*i - c*g, c*d - a*f],
               [d*h - e*g, b*g - a*h, a*e - b*d]]

        def to_frac(r):
            return [(row[0]*r[0] + row[1]*r[1] + row[2]*r[2]) / detA for row in adj]

        return [frac_to_cart(wrap_frac(to_frac(r)), lattice) for r in carts]

    carts = _plot_carts(positions)
    # Bonds and arrows use the drawn (possibly wrapped) coordinates so they
    # stay attached to the atoms shown on screen.
    cart_arr = np.asarray(carts, dtype=float).reshape(-1, 3)
    bonds, distance = get_connections(cart_arr[:, 0], cart_arr[:, 1], cart_arr[:, 2],
                                      bond_length, min_dist)

    # Points used for the axis limits: cell vertices + atoms.
    all_pts = unit_cell_vertices(lattice) + carts

    fig = plt.figure(figsize=(9, 8))
    ax = fig.add_subplot(111, projection='3d')
    draw_unit_cell(ax, lattice, lw=1.5)

    by_elem = {}
    for pos, el in zip(carts, tags):
        by_elem.setdefault(el, []).append(pos)

    for el in elems:
        pts = by_elem.get(el, [])
        if not pts:
            continue
        atom_number = atomic_number.get(el, 12)
        color = jmol_colors.get(atom_number, (0.5, 0.5, 0.5))  # gray if unknown
        size = atomic_radii.get(atom_number, 0.8) * atom_size_factor
        ax.scatter([p[0] for p in pts], [p[1] for p in pts], [p[2] for p in pts],
                   s=size, depthshade=True, label=f"{el} ({len(pts)})",
                   marker='o', color=color)
        if show_labels and len(pts) <= 50:
            for p in pts:
                ax.text(p[0], p[1], p[2], el, fontsize=8, ha='center', va='center')

    ax.quiver(cart_arr[:, 0], cart_arr[:, 1], cart_arr[:, 2],
              forces[:, 0], forces[:, 1], forces[:, 2], length=quiver_length)

    for (i, j), dist in zip(bonds, distance):
        p_i, p_j = carts[i], carts[j]
        ax.plot([p_i[0], p_j[0]], [p_i[1], p_j[1]], [p_i[2], p_j[2]],
                linewidth=3.0, color=bond_color, zorder=0)
        if show_bond_length:
            xm = 0.5 * (p_i[0] + p_j[0])
            ym = 0.5 * (p_i[1] + p_j[1])
            zm = 0.5 * (p_i[2] + p_j[2])
            ax.text(xm, ym, zm, f"{dist:.3f}Å", color='red', fontsize=7,
                    ha='center', va='center')

    auto_equal_limits(ax, all_pts)
    ax.set_xlabel("x (Å)")
    ax.set_ylabel("y (Å)")
    ax.set_zlabel("z (Å)")
    ax.legend(loc="upper left", bbox_to_anchor=(0.02, 0.98))
    return ax
def animate_crystal_phonon(
    data,
    view=None,
    alpha=2.0,
    point_size=120,
    wrap=True,
    show_labels=True,
    bond_length=None,
    scale=1.20,
    frame_interval=10,
    min_dist=0.10,
    show_bond_length=True,
):
    """
    Animate atoms oscillating along per-atom vectors and save the result as a GIF.

    Frame ``k`` (``k = 0 .. tot_points - 1``) places every atom at
    ``positions + forces * alpha * sin(phase_k)``, with ``phase_k`` evenly
    spaced over ``[0, 6.28]``; frame 0 is the undisplaced structure. An arrow
    along ``forces`` is drawn at every atom, its length scaled by the same
    sine factor. Bonds are re-detected every frame. The animation is written
    to ``crystal_phonon_with_bonds.gif`` in the current working directory
    with the Pillow writer. ``data`` is not modified.

    Args:
        data (dict): Structure data with keys ``'lattice_vec'``,
            ``'atomic_species'`` (element symbols; one legend entry each),
            ``'positions'`` (Cartesian, one ``(x, y, z)`` row per atom),
            ``'mat'`` (element symbol per atom) and ``'forces'`` (same
            shape as ``positions``; displacement direction per atom).
        view (dict, optional): Per-call overrides. A key present here wins
            over the keyword argument of the same name. Recognised keys:
            ``tot_points`` (number of frames, default 20), ``quiver_length``
            (default 15.0), ``atom_size_factor`` (default 100),
            ``interval`` (ms between frames), ``show_labels``,
            ``bond_color`` (default ``(0.8, 0.8, 0)``), ``wrap``,
            ``show_bond_length``, ``bond_length``, ``min_dist``.
        alpha (float): Displacement amplitude multiplier.
        point_size, scale: Currently unused; kept for API compatibility.
        wrap (bool): Wrap atoms back into the unit cell before drawing.
        show_labels (bool): Write the element symbol on each atom, for
            elements with at most 50 atoms.
        bond_length, min_dist: Forwarded to ``get_connections``.
        frame_interval (int): Delay between frames in milliseconds, unless
            ``view['interval']`` is given.
        show_bond_length (bool): Print each bond's distance at its midpoint.

    Returns:
        tuple: ``(fig, ani)`` -- the figure and its ``FuncAnimation``.

    Raises:
        ValueError: If wrapping is requested and the lattice is singular.
    """
    if view is None:
        view = {}

    lattice = data['lattice_vec']
    elems = data['atomic_species']
    tags = data['mat']
    forces = np.asarray(data['forces'], dtype=float)
    # Private copy: the previous version overwrote data['positions'] every frame.
    positions = np.array(data['positions'], dtype=float)

    tot_points = view.get('tot_points', 20)
    quiver_length = view.get('quiver_length', 15.0)
    atom_size_factor = view.get('atom_size_factor', 100)
    # view['interval'] was previously read and then ignored.
    interval = view.get('interval', frame_interval)
    show_labels = view.get('show_labels', show_labels)
    bond_color = view.get('bond_color', (0.8, 0.8, 0))
    wrap = view.get('wrap', wrap)
    show_bond_length = view.get('show_bond_length', show_bond_length)
    bond_length = view.get('bond_length', bond_length)
    # Previously view['min_dist'] overwrote bond_length by mistake.
    min_dist = view.get('min_dist', min_dist)

    # Precompute all frames: frame k = positions + forces * alpha * sin(phase_k).
    sine_f = np.sin(np.linspace(0, 6.28, tot_points))
    new_positions = positions[None, :, :] + forces[None, :, :] * alpha * sine_f[:, None, None]
    new_positions[0] = positions
    q_length = [0.0] + [quiver_length * s for s in sine_f[1:]]

    def _plot_carts(pos):
        """Copy Cartesian positions, wrapping them into the cell when ``wrap`` is set."""
        carts = [p[:] for p in pos]
        if not wrap:
            return carts
        # Columns of A are the lattice vectors, so A @ frac == cart.
        A = [[lattice[0][0], lattice[1][0], lattice[2][0]],
             [lattice[0][1], lattice[1][1], lattice[2][1]],
             [lattice[0][2], lattice[1][2], lattice[2][2]]]
        detA = mat_det_3(A)
        if abs(detA) < 1e-12:
            raise ValueError("Singular lattice for wrapping.")
        (a, b, c), (d, e, f), (g, h, i) = A
        # Adjugate of A; adj(A) / det(A) is A^-1 (Cramer's rule).
        adj = [[e*i - f*h, c*h - b*i, b*f - c*e],
               [f*g - d*i, a*i - c*g, c*d - a*f],
               [d*h - e*g, b*g - a*h, a*e - b*d]]

        def to_frac(r):
            return [(row[0]*r[0] + row[1]*r[1] + row[2]*r[2]) / detA for row in adj]

        return [frac_to_cart(wrap_frac(to_frac(r)), lattice) for r in carts]

    def _connections(carts):
        """Detect bonds on the drawn coordinates; returns (bonds, distances)."""
        arr = np.asarray(carts, dtype=float).reshape(-1, 3)
        return get_connections(arr[:, 0], arr[:, 1], arr[:, 2], bond_length, min_dist)

    def _group_by_element(carts):
        """Map element symbol -> list of drawn positions of that element."""
        by_elem = {}
        for pos, el in zip(carts, tags):
            by_elem.setdefault(el, []).append(pos)
        return by_elem

    def _draw_quiver(carts, length):
        """Draw one ``forces`` arrow anchored at each drawn atom."""
        arr = np.asarray(carts, dtype=float).reshape(-1, 3)
        return ax.quiver(arr[:, 0], arr[:, 1], arr[:, 2],
                         forces[:, 0], forces[:, 1], forces[:, 2], length=length)

    def _redraw_labels(by_elem):
        """Replace the element-symbol annotations (they used to pile up every frame)."""
        for txt in label_texts:
            txt.remove()
        label_texts.clear()
        if not show_labels:
            return
        for el in elems:
            pts = by_elem.get(el, [])
            if 0 < len(pts) <= 50:
                for p in pts:
                    label_texts.append(
                        ax.text(p[0], p[1], p[2], el, fontsize=8, ha='center', va='center'))

    def _redraw_bonds(carts, bonds, distance):
        """Replace bond lines and bond-length labels with those of the current frame."""
        for artist in bond_artists:
            artist.remove()
        bond_artists.clear()
        for (i, j), dist in zip(bonds, distance):
            p_i, p_j = carts[i], carts[j]
            line, = ax.plot([p_i[0], p_j[0]], [p_i[1], p_j[1]], [p_i[2], p_j[2]],
                            linewidth=3.0, color=bond_color, zorder=0)
            bond_artists.append(line)
            if show_bond_length:
                xm = 0.5 * (p_i[0] + p_j[0])
                ym = 0.5 * (p_i[1] + p_j[1])
                zm = 0.5 * (p_i[2] + p_j[2])
                bond_artists.append(
                    ax.text(xm, ym, zm, f"{dist:.3f}Å", color='red', fontsize=7,
                            ha='center', va='center'))

    carts = _plot_carts(positions)
    bonds, distance = _connections(carts)
    # Points used for the axis limits: cell vertices + atoms.
    all_pts = unit_cell_vertices(lattice) + carts

    fig = plt.figure(figsize=(9, 8))
    ax = fig.add_subplot(111, projection='3d')
    draw_unit_cell(ax, lattice, lw=1.5)

    scatter = []           # scatter artists, returned from update()
    scatter_by_elem = {}   # element -> its scatter artist (robust to empty elements)
    label_texts = []
    bond_artists = []

    by_elem = _group_by_element(carts)
    for el in elems:
        pts = by_elem.get(el, [])
        if not pts:
            continue
        atom_number = atomic_number.get(el, 12)
        color = jmol_colors.get(atom_number, (0.5, 0.5, 0.5))  # gray if unknown
        size = atomic_radii.get(atom_number, 0.8) * atom_size_factor
        artist = ax.scatter([p[0] for p in pts], [p[1] for p in pts], [p[2] for p in pts],
                            s=size, depthshade=True, label=f"{el} ({len(pts)})",
                            marker='o', color=color)
        scatter.append(artist)
        scatter_by_elem[el] = artist
    _redraw_labels(by_elem)
    quiv = _draw_quiver(carts, q_length[0])
    _redraw_bonds(carts, bonds, distance)

    def update(frame):
        """Move atoms, arrows, labels and bonds to animation frame ``frame``."""
        nonlocal quiv
        frame = int(frame)
        frame_carts = _plot_carts(new_positions[frame])
        frame_by_elem = _group_by_element(frame_carts)
        for el, artist in scatter_by_elem.items():
            pts = frame_by_elem.get(el, [])
            # _offsets3d is a private matplotlib attribute (as in the original code).
            artist._offsets3d = ([p[0] for p in pts], [p[1] for p in pts], [p[2] for p in pts])
        _redraw_labels(frame_by_elem)
        if quiv is not None:
            quiv.remove()
        quiv = _draw_quiver(frame_carts, q_length[frame])
        frame_bonds, frame_distance = _connections(frame_carts)
        _redraw_bonds(frame_carts, frame_bonds, frame_distance)
        return scatter,

    auto_equal_limits(ax, all_pts)
    ax.set_xlabel("x (Å)")
    ax.set_ylabel("y (Å)")
    ax.set_zlabel("z (Å)")
    ax.legend(loc="upper left", bbox_to_anchor=(0.02, 0.98))

    ani = animation.FuncAnimation(fig, update, frames=tot_points, interval=interval, blit=False)
    ani.save("crystal_phonon_with_bonds.gif", writer="pillow")
    print("Animated crystal phonon visualization with bonds saved as crystal_phonon_with_bonds.gif")
    return fig, ani
if __name__ == '__main__':
    # ``draw_molecule`` does not exist in this module (NameError); build the
    # Dash app with ``display_molecule`` (reads ``xyz`` from the working
    # directory) and serve it.
    app = display_molecule()
    app.run(debug=True)