import numpy as np
import urllib.request as urlreq
from EPWpy.utilities.EPW_util import get_connections
from EPWpy.default.default_dicts import *
from EPWpy.error_handling import error_handler
from EPWpy.plotting.plot_structure_matplotlib import *
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.animation as animation
try:
from mayavi import mlab
from mayavi.modules.iso_surface import IsoSurface
from tvtk.api import tvtk
except ImportError:
error_handler.error_mayavi()
# print('The mayavi package not found\nTo visualize crystals in EPWpy, install mayavi')
try:
import plotly
import plotly.graph_objs as go
from dash import Dash, dcc, html, Input, Output, callback
import dash_bio as dashbio
from dash_bio.utils import xyz_reader
except ImportError:
error_handler.error_dash()
# print('Dash-bio not found\nDash-bio not needed unless you want to use molecular view (not-recommended)')
[docs]
def get_preset_view():
    """Return the default view settings for the dash-bio Speck viewer.

    A new dictionary is built on every call, so callers may modify the
    result freely.

    Returns:
        dict: Speck view options (resolution, ambient occlusion, outline,
        atom scaling and bond display).
    """
    return dict(
        resolution=400,
        ao=0.1,
        outline=1,
        atomScale=0.25,
        relativeAtomScale=0.33,
        bonds=True,
    )
[docs]
def read_xyz():
    """Return the complete text of the file named ``xyz``.

    The path is relative, so the file is looked up in the current working
    directory.

    Returns:
        str: The unmodified file contents (newlines included).
    """
    with open('xyz', 'r') as handle:
        return handle.read()
[docs]
def display_atoms(atom_pos):
    """Show atoms as a 3D marker scatter plot inside a notebook via plotly.

    Args:
        atom_pos: 2-D array supporting ``atom_pos[:, k]``; columns 0, 1 and 2
            are used as the x, y and z coordinates.
    """
    plotly.offline.init_notebook_mode()

    # One marker per atom.
    marker_style = {
        'size': 10,
        'opacity': 0.8,
    }
    points = go.Scatter3d(
        x=atom_pos[:, 0],
        y=atom_pos[:, 1],
        z=atom_pos[:, 2],
        mode='markers',
        marker=marker_style,
    )

    # Zero margins so the plot fills the output cell.
    borderless = go.Layout(margin={'l': 0, 'r': 0, 'b': 0, 't': 0})

    figure = go.Figure(data=[points], layout=borderless)
    plotly.offline.iplot(figure)
[docs]
def display_molecule(view=None):
    """Build a Dash app that shows the molecule stored in the local ``xyz`` file.

    The layout holds a preset dropdown and a dash-bio ``Speck`` viewer. The
    app is only constructed here; starting the server is left to the caller.

    Args:
        view (dict, optional): Speck view options. When omitted or empty, the
            defaults from ``get_preset_view()`` are used.

    Returns:
        The configured ``Dash`` application.
    """
    plotly.offline.init_notebook_mode()
    app = Dash(__name__)
    xyz_text = read_xyz()

    # Fall back to the preset only when the caller supplied nothing.
    use_preset = view is None or len(view) == 0
    set_view = get_preset_view() if use_preset else view

    atoms = xyz_reader.read_xyz(datapath_or_datastring=xyz_text, is_datafile=False)

    preset_dropdown = dcc.Dropdown(
        id='default-speck-preset-views',
        options=[
            {'label': 'Ball and stick', 'value': 'stickball'},
            {'label': 'Default', 'value': 'default'}
        ],
        value='stickball'
    )
    viewer = dashbio.Speck(
        id='default-speck',
        data=atoms,
        view=set_view
    )
    app.layout = html.Div([preset_dropdown, viewer])
    return app
#try:
# @callback(
# Output('default-speck', 'presetView'),
# Input('default-speck-preset-views', 'value##'))
#except NameError:
# pass
[docs]
def update_preset_view(preset_name):
    """Return ``preset_name`` unchanged.

    NOTE(review): the commented-out ``@callback`` decorator elsewhere in this
    file suggests this was meant to forward the dropdown value to the Speck
    viewer's ``presetView`` -- confirm before re-enabling it.

    Args:
        preset_name: The selected preset identifier.

    Returns:
        The same object that was passed in.
    """
    return preset_name
[docs]
def display_crystal(data, view=None, bond_length=3.5):
    """
    Displays a crystal structure (atoms and bonds) with mayavi.

    Every atom is drawn as a sphere coloured by its jmol colour and sized by
    half its tabulated atomic radius; bonds returned by ``get_connections``
    are drawn as tubes.

    Args:
        data (dict): Structure data. Must contain ``'positions'`` (N x 3
            coordinates, plotted as given) and ``'mat'`` (element label of
            each atom, indexed by atom number).
        view (dict, optional): Display options. Recognised keys:
            ``'bond_color'`` (RGB tuple, default ``(0.8, 0.8, 0)``),
            ``'bond_radius'`` (default ``0.15``), ``'in_notebook'`` (its mere
            presence enables notebook mode, whatever its value) and
            ``'backend'`` (passed to ``mlab.init_notebook``).
            Defaults to an empty dictionary.
        bond_length (float, optional): Cutoff forwarded to
            ``get_connections`` when detecting bonds. Defaults to 3.5.

    Returns:
        The mayavi point source carrying the bond lines when
        ``'in_notebook'`` is present in ``view``; otherwise ``None`` after the
        blocking ``mlab.show()`` call returns.

    Example:
        >>> display_crystal(my_crystal_data, view={'bond_radius': 0.2}, bond_length=4.0)
    """
    if view is None:
        view = {}

    atomic_positions = np.asarray(data['positions'])
    mat = data['mat']
    xs = atomic_positions[:, 0]
    ys = atomic_positions[:, 1]
    zs = atomic_positions[:, 2]
    connections, _ = get_connections(xs, ys, zs, bond_length)

    bond_color = view.get('bond_color', (0.8, 0.8, 0))
    bond_radius = view.get('bond_radius', 0.15)
    # Key presence (not truthiness) selects notebook mode, as before.
    in_notebook = 'in_notebook' in view

    # Initializing mayavi with a backend
    if in_notebook:
        if 'backend' in view:
            mlab.init_notebook(backend=view['backend'])
        else:
            mlab.init_notebook()

    # Initialize figure
    mlab.figure('EPWpy', bgcolor=(1, 1, 1), size=(800, 600))

    # One sphere per atom, so each can have its own colour and size.
    for i, (x, y, z) in enumerate(atomic_positions):
        # Unknown labels fall back to atomic number 12.
        atom_number = atomic_number.get(mat[i], 12)
        # Default to gray if the element has no jmol colour.
        color = jmol_colors.get(atom_number, (0.5, 0.5, 0.5))
        # Default to 50 pm if the element has no tabulated radius.
        radius = atomic_radii.get(atom_number, 50 * pm2bohr)
        mlab.points3d(x, y, z, color=color, scale_factor=radius / 2)

    # Invisible (zero-size) points holding all atoms: they only serve as the
    # vertices that the bond lines connect.
    src = mlab.points3d(xs, ys, zs, color=(0, 0, 0), scale_factor=0.0)
    src.mlab_source.dataset.lines = np.array(connections)
    tube = mlab.pipeline.tube(src, tube_radius=bond_radius)
    tube.filter.radius_factor = 1.
    mlab.pipeline.surface(tube, color=bond_color)

    if in_notebook:
        return src
    mlab.show()
[docs]
def display_crystal_matplotlib(
    data,
    view=None,
    point_size=100,
    wrap=False,
    show_labels=False,
    bond_length=None,
    scale=1.20,
    min_dist=0.10,
    depthshade=False,
    bond_size=3.0,
    show_bond_length=False,
    grid=False,
    show=False
):
    """Draw a crystal structure (unit cell, atoms, bonds) on a matplotlib 3D axis.

    An entry in ``view`` overrides the keyword argument of the same name.

    Args:
        data (dict): Structure data. Keys used: ``'lattice_vec'`` (three
            lattice vectors, indexed ``[vector][component]``),
            ``'atomic_species'`` (element labels, in legend order),
            ``'positions'`` (N x 3 coordinates, treated as Cartesian) and
            ``'mat'`` (element label of each atom).
        view (dict, optional): Overrides. Recognised keys: ``'grid'``,
            ``'show_labels'``, ``'bond_size'``, ``'depthshade'``,
            ``'atom_size_factor'`` (default 1000), ``'bond_color'`` (default
            ``(0.8, 0.8, 0)``), ``'wrap'``, ``'show_bond_length'``,
            ``'bond_length'`` and ``'min_dist'``.
        point_size: Kept for backward compatibility; it has no effect.
            Marker size is the tabulated atomic radius times
            ``view['atom_size_factor']``.
        wrap (bool): Fold atoms back into the unit cell before drawing.
        show_labels (bool): Write the element label on every atom (only for
            species with at most 50 atoms).
        bond_length: Cutoff forwarded to ``get_connections``.
        scale: Kept for backward compatibility; it has no effect.
        min_dist: Minimum distance forwarded to ``get_connections``.
        depthshade (bool): Passed to ``Axes3D.scatter``.
        bond_size (float): Line width of the bonds.
        show_bond_length (bool): Annotate each bond with its length.
        grid (bool): Passed to ``ax.grid``.
        show (bool): Call ``plt.show()`` instead of returning the axis.

    Returns:
        The 3D axis when ``show`` is false; ``None`` when ``show`` is true.

    Raises:
        ValueError: If ``wrap`` is requested and the lattice is singular.
    """
    if view is None:
        view = {}

    lattice = data['lattice_vec']
    elems = data['atomic_species']
    positions = np.asarray(data['positions'], dtype=float)
    tags = data['mat']

    grid = view.get('grid', grid)
    show_labels = view.get('show_labels', show_labels)
    bond_size = view.get('bond_size', bond_size)
    depthshade = view.get('depthshade', depthshade)
    atom_size_factor = view.get('atom_size_factor', 1000)
    bond_color = view.get('bond_color', (0.8, 0.8, 0))
    wrap = view.get('wrap', wrap)
    show_bond_length = view.get('show_bond_length', show_bond_length)
    bond_length = view.get('bond_length', bond_length)
    # Previously this override was written into bond_length by mistake.
    min_dist = view.get('min_dist', min_dist)

    # Input coordinates are Cartesian; optionally fold them into the cell.
    carts = [p[:] for p in positions]
    if wrap:
        # Columns of A are the lattice vectors, so A @ frac = cart.
        A = [[lattice[0][k], lattice[1][k], lattice[2][k]] for k in range(3)]
        detA = mat_det_3(A)
        if abs(detA) < 1e-12:
            raise ValueError("Singular lattice for wrapping.")
        # Adjugate of A, computed once: inverse(A) = adj / detA.
        (a, b, c), (d, e, f), (g, h, i) = A
        adj = [
            [e*i - f*h, c*h - b*i, b*f - c*e],
            [f*g - d*i, a*i - c*g, c*d - a*f],
            [d*h - e*g, b*g - a*h, a*e - b*d],
        ]

        def cart_to_frac_local(r):
            return [
                (row[0]*r[0] + row[1]*r[1] + row[2]*r[2]) / detA
                for row in adj
            ]

        carts = [
            frac_to_cart(wrap_frac(cart_to_frac_local(r)), lattice)
            for r in carts
        ]

    # NOTE(review): bonds are detected on the input positions, not on the
    # wrapped coordinates that are drawn -- confirm this is intended.
    bonds, distance = get_connections(
        positions[:, 0], positions[:, 1], positions[:, 2], bond_length, min_dist
    )

    # Points that define the axis limits: cell vertices + atoms
    verts = unit_cell_vertices(lattice)
    all_pts = verts + carts

    fig = plt.figure(figsize=(9, 8))
    ax = fig.add_subplot(111, projection='3d')
    draw_unit_cell(ax, lattice, lw=1.5)

    # Group atoms by element so each species gets one legend entry/colour.
    by_elem = {}
    for pos, el in zip(carts, tags):
        by_elem.setdefault(el, []).append(pos)

    for el in elems:
        pts = by_elem.get(el, [])
        if not pts:
            continue
        xs = [p[0] for p in pts]
        ys = [p[1] for p in pts]
        zs = [p[2] for p in pts]
        # Unknown labels fall back to atomic number 12, gray, radius 0.8.
        atom_number = atomic_number.get(el, 12)
        color = jmol_colors.get(atom_number, (0.5, 0.5, 0.5))
        size = atomic_radii.get(atom_number, 0.8) * atom_size_factor
        ax.scatter(xs, ys, zs, s=size, depthshade=depthshade,
                   label=f"{el} ({len(pts)})", marker='o', color=color)
        if show_labels and len(pts) <= 50:
            for p in pts:
                ax.text(p[0], p[1], p[2], el, fontsize=8, ha='center', va='center')

    # Bonds: bonds[t] is an (i, j) atom-index pair, distance[t] its length.
    for t, (i, j) in enumerate(bonds):
        p_i = carts[i]
        p_j = carts[j]
        ax.plot([p_i[0], p_j[0]], [p_i[1], p_j[1]], [p_i[2], p_j[2]],
                linewidth=bond_size, color=bond_color, zorder=0)
        if show_bond_length:
            dist = distance[t]
            xm = 0.5*(p_i[0] + p_j[0])
            ym = 0.5*(p_i[1] + p_j[1])
            zm = 0.5*(p_i[2] + p_j[2])
            ax.text(xm, ym, zm, f"{dist:.3f}Å", color='red', fontsize=7,
                    ha='center', va='center')

    ax.grid(grid)
    auto_equal_limits(ax, all_pts)
    ax.set_xlabel("x (Å)")
    ax.set_ylabel("y (Å)")
    ax.set_zlabel("z (Å)")
    ax.legend(loc="upper left", bbox_to_anchor=(0.02, 0.98))
    plt.tight_layout()

    if show:
        plt.show()
    else:
        return ax
[docs]
def display_crystal_phonon(
    data,
    view=None,
    point_size=120,
    wrap=True,
    show_labels=True,
    bond_length=None,
    frame_interval=10,
    scale=1.20,
    min_dist=0.10,
    show_bond_length=True
):
    """Draw a crystal structure with one arrow per atom on a matplotlib 3D axis.

    Same picture as ``display_crystal_matplotlib`` plus a quiver plot of
    ``data['forces']`` anchored at the input positions. An entry in ``view``
    overrides the keyword argument of the same name.

    Args:
        data (dict): Structure data. Keys used: ``'lattice_vec'`` (three
            lattice vectors, indexed ``[vector][component]``),
            ``'atomic_species'`` (element labels, in legend order),
            ``'positions'`` (N x 3 coordinates, treated as Cartesian),
            ``'mat'`` (element label of each atom) and ``'forces'`` (N x 3
            vectors drawn as arrows).
        view (dict, optional): Overrides. Recognised keys: ``'show_labels'``,
            ``'atom_size_factor'`` (default 100), ``'bond_color'`` (default
            ``(0.8, 0.8, 0)``), ``'wrap'``, ``'show_bond_length'``,
            ``'bond_length'``, ``'min_dist'`` and ``'quiver_length'``
            (default 15.0).
        point_size: Kept for backward compatibility; it has no effect.
        wrap (bool): Fold atoms back into the unit cell before drawing.
        show_labels (bool): Write the element label on every atom (only for
            species with at most 50 atoms).
        bond_length: Cutoff forwarded to ``get_connections``.
        frame_interval: Kept for backward compatibility; it has no effect.
        scale: Kept for backward compatibility; it has no effect.
        min_dist: Minimum distance forwarded to ``get_connections``.
        show_bond_length (bool): Annotate each bond with its length.

    Returns:
        The 3D axis holding the drawing (nothing is shown automatically).

    Raises:
        ValueError: If ``wrap`` is requested and the lattice is singular.
    """
    if view is None:
        view = {}

    lattice = data['lattice_vec']
    elems = data['atomic_species']
    positions = np.asarray(data['positions'], dtype=float)
    tags = data['mat']
    forces = np.asarray(data['forces'], dtype=float)

    show_labels = view.get('show_labels', show_labels)
    atom_size_factor = view.get('atom_size_factor', 100)
    bond_color = view.get('bond_color', (0.8, 0.8, 0))
    wrap = view.get('wrap', wrap)
    show_bond_length = view.get('show_bond_length', show_bond_length)
    bond_length = view.get('bond_length', bond_length)
    # Previously this override was written into bond_length by mistake.
    min_dist = view.get('min_dist', min_dist)
    quiver_length = view.get('quiver_length', 15.0)

    # Input coordinates are Cartesian; optionally fold them into the cell.
    carts = [p[:] for p in positions]
    if wrap:
        # Columns of A are the lattice vectors, so A @ frac = cart.
        A = [[lattice[0][k], lattice[1][k], lattice[2][k]] for k in range(3)]
        detA = mat_det_3(A)
        if abs(detA) < 1e-12:
            raise ValueError("Singular lattice for wrapping.")
        # Adjugate of A, computed once: inverse(A) = adj / detA.
        (a, b, c), (d, e, f), (g, h, i) = A
        adj = [
            [e*i - f*h, c*h - b*i, b*f - c*e],
            [f*g - d*i, a*i - c*g, c*d - a*f],
            [d*h - e*g, b*g - a*h, a*e - b*d],
        ]

        def cart_to_frac_local(r):
            return [
                (row[0]*r[0] + row[1]*r[1] + row[2]*r[2]) / detA
                for row in adj
            ]

        carts = [
            frac_to_cart(wrap_frac(cart_to_frac_local(r)), lattice)
            for r in carts
        ]

    # NOTE(review): bonds and arrows use the input positions, not the wrapped
    # coordinates that the atoms are drawn at -- confirm this is intended.
    bonds, distance = get_connections(
        positions[:, 0], positions[:, 1], positions[:, 2], bond_length, min_dist
    )

    # Points that define the axis limits: cell vertices + atoms
    verts = unit_cell_vertices(lattice)
    all_pts = verts + carts

    fig = plt.figure(figsize=(9, 8))
    ax = fig.add_subplot(111, projection='3d')
    draw_unit_cell(ax, lattice, lw=1.5)

    # Group atoms by element so each species gets one legend entry/colour.
    by_elem = {}
    for pos, el in zip(carts, tags):
        by_elem.setdefault(el, []).append(pos)

    for el in elems:
        pts = by_elem.get(el, [])
        if not pts:
            continue
        xs = [p[0] for p in pts]
        ys = [p[1] for p in pts]
        zs = [p[2] for p in pts]
        # Unknown labels fall back to atomic number 12, gray, radius 0.8.
        atom_number = atomic_number.get(el, 12)
        color = jmol_colors.get(atom_number, (0.5, 0.5, 0.5))
        size = atomic_radii.get(atom_number, 0.8) * atom_size_factor
        ax.scatter(xs, ys, zs, s=size, depthshade=True,
                   label=f"{el} ({len(pts)})", marker='o', color=color)
        if show_labels and len(pts) <= 50:
            for p in pts:
                ax.text(p[0], p[1], p[2], el, fontsize=8, ha='center', va='center')

    ax.quiver(positions[:, 0], positions[:, 1], positions[:, 2],
              forces[:, 0], forces[:, 1], forces[:, 2], length=quiver_length)

    # Bonds: bonds[t] is an (i, j) atom-index pair, distance[t] its length.
    for t, (i, j) in enumerate(bonds):
        p_i = carts[i]
        p_j = carts[j]
        ax.plot([p_i[0], p_j[0]], [p_i[1], p_j[1]], [p_i[2], p_j[2]],
                linewidth=3.0, color=bond_color, zorder=0)
        if show_bond_length:
            dist = distance[t]
            xm = 0.5*(p_i[0] + p_j[0])
            ym = 0.5*(p_i[1] + p_j[1])
            zm = 0.5*(p_i[2] + p_j[2])
            ax.text(xm, ym, zm, f"{dist:.3f}Å", color='red', fontsize=7,
                    ha='center', va='center')

    auto_equal_limits(ax, all_pts)
    ax.set_xlabel("x (Å)")
    ax.set_ylabel("y (Å)")
    ax.set_zlabel("z (Å)")
    ax.legend(loc="upper left", bbox_to_anchor=(0.02, 0.98))
    return ax
[docs]
def animate_crystal_phonon(
    data,
    view=None,
    alpha=2.0,
    point_size=120,
    wrap=True,
    show_labels=True,
    bond_length=None,
    scale=1.20,
    frame_interval=10,
    min_dist=0.10,
    show_bond_length=True
):
    """Animate atoms oscillating along ``data['forces']`` and save it as a GIF.

    Frame ``k`` places the atoms at ``positions + forces * alpha * sin(t_k)``
    with ``t_k`` evenly spaced over ``[0, 6.28]``; bonds are re-detected for
    every frame. The animation is written to
    ``crystal_phonon_with_bonds.gif`` in the current working directory. An
    entry in ``view`` overrides the keyword argument of the same name.

    The caller's ``data`` arrays are not modified.

    Args:
        data (dict): Structure data. Keys used: ``'lattice_vec'`` (three
            lattice vectors, indexed ``[vector][component]``),
            ``'atomic_species'`` (element labels, in legend order),
            ``'positions'`` (N x 3 coordinates, treated as Cartesian),
            ``'mat'`` (element label of each atom) and ``'forces'`` (N x 3
            displacement directions, also drawn as arrows).
        view (dict, optional): Overrides. Recognised keys: ``'tot_points'``
            (number of frames, default 20), ``'quiver_length'`` (default
            15.0), ``'atom_size_factor'`` (default 100), ``'interval'``
            (overrides ``frame_interval``), ``'show_labels'``,
            ``'bond_color'`` (default ``(0.8, 0.8, 0)``), ``'wrap'``,
            ``'show_bond_length'``, ``'bond_length'`` and ``'min_dist'``.
        alpha (float): Displacement amplitude multiplying ``forces``.
        point_size: Kept for backward compatibility; it has no effect.
        wrap (bool): Fold atoms back into the unit cell before drawing.
        show_labels (bool): Write the element label on every atom (only for
            species with at most 50 atoms).
        bond_length: Cutoff forwarded to ``get_connections``.
        scale: Kept for backward compatibility; it has no effect.
        frame_interval: Delay between frames in milliseconds, passed to
            ``FuncAnimation``.
        min_dist: Minimum distance forwarded to ``get_connections``.
        show_bond_length (bool): Annotate each bond with its length.

    Returns:
        tuple: ``(fig, ani)`` -- the matplotlib figure and the
        ``FuncAnimation`` object.

    Raises:
        ValueError: If ``wrap`` is requested and the lattice is singular.
    """
    if view is None:
        view = {}

    lattice = data['lattice_vec']
    elems = data['atomic_species']
    positions = np.asarray(data['positions'], dtype=float)
    tags = data['mat']
    forces = np.asarray(data['forces'], dtype=float)

    tot_points = view.get('tot_points', 20)
    quiver_length = view.get('quiver_length', 15.0)
    atom_size_factor = view.get('atom_size_factor', 100)
    # view['interval'] used to be read and then ignored.
    frame_interval = view.get('interval', frame_interval)
    show_labels = view.get('show_labels', show_labels)
    bond_color = view.get('bond_color', (0.8, 0.8, 0))
    wrap = view.get('wrap', wrap)
    show_bond_length = view.get('show_bond_length', show_bond_length)
    bond_length = view.get('bond_length', bond_length)
    # Previously this override was written into bond_length by mistake.
    min_dist = view.get('min_dist', min_dist)

    # Pre-compute all frames. Frame 0 is the undisplaced structure.
    sine_f = np.sin(np.linspace(0, 6.28, tot_points))
    new_positions = (
        positions[None, :, :] + forces[None, :, :] * alpha * sine_f[:, None, None]
    )
    new_positions[0, :, :] = positions
    q_length = [0.0] + [float(quiver_length * s) for s in sine_f[1:]]

    if wrap:
        # Columns of A are the lattice vectors, so A @ frac = cart.
        A = [[lattice[0][k], lattice[1][k], lattice[2][k]] for k in range(3)]
        detA = mat_det_3(A)
        if abs(detA) < 1e-12:
            raise ValueError("Singular lattice for wrapping.")
        # Adjugate of A, computed once: inverse(A) = adj / detA.
        (a, b, c), (d, e, f), (g, h, i) = A
        adj = [
            [e*i - f*h, c*h - b*i, b*f - c*e],
            [f*g - d*i, a*i - c*g, c*d - a*f],
            [d*h - e*g, b*g - a*h, a*e - b*d],
        ]

    def to_display_coords(frame_positions):
        """Return the per-atom coordinates to draw (wrapped if requested)."""
        if not wrap:
            return [p[:] for p in frame_positions]
        wrapped = []
        for r in frame_positions:
            frac = [
                (row[0]*r[0] + row[1]*r[1] + row[2]*r[2]) / detA
                for row in adj
            ]
            wrapped.append(frac_to_cart(wrap_frac(frac), lattice))
        return wrapped

    def group_by_element(carts):
        """Map each element label to the list of its atoms' coordinates."""
        grouped = {}
        for pos, el in zip(carts, tags):
            grouped.setdefault(el, []).append(pos)
        return grouped

    def detect_bonds(frame_positions):
        """Run get_connections on one frame's (unwrapped) positions."""
        return get_connections(
            frame_positions[:, 0], frame_positions[:, 1], frame_positions[:, 2],
            bond_length, min_dist
        )

    def draw_frame_artists(carts, by_elem, bonds, distance):
        """Draw labels and bonds for one frame; return the created artists."""
        artists = []
        if show_labels:
            for el in elems:
                pts = by_elem.get(el, [])
                if pts and len(pts) <= 50:
                    for p in pts:
                        artists.append(
                            ax.text(p[0], p[1], p[2], el, fontsize=8,
                                    ha='center', va='center')
                        )
        # bonds[t] is an (i, j) atom-index pair, distance[t] its length.
        for t, (i, j) in enumerate(bonds):
            p_i = carts[i]
            p_j = carts[j]
            line, = ax.plot([p_i[0], p_j[0]], [p_i[1], p_j[1]], [p_i[2], p_j[2]],
                            linewidth=3.0, color=bond_color, zorder=0)
            artists.append(line)
            if show_bond_length:
                dist = distance[t]
                xm = 0.5*(p_i[0] + p_j[0])
                ym = 0.5*(p_i[1] + p_j[1])
                zm = 0.5*(p_i[2] + p_j[2])
                artists.append(
                    ax.text(xm, ym, zm, f"{dist:.3f}Å", color='red', fontsize=7,
                            ha='center', va='center')
                )
        return artists

    # ---- initial frame -------------------------------------------------
    carts = to_display_coords(positions)
    bonds, distance = detect_bonds(positions)

    # Points that define the axis limits: cell vertices + atoms
    verts = unit_cell_vertices(lattice)
    all_pts = verts + carts

    fig = plt.figure(figsize=(9, 8))
    ax = fig.add_subplot(111, projection='3d')
    draw_unit_cell(ax, lattice, lw=1.5)

    by_elem = group_by_element(carts)
    # (element, artist) pairs; species without atoms are skipped, so the
    # artist must be looked up by element, not by species index.
    scatter = []
    for el in elems:
        pts = by_elem.get(el, [])
        if not pts:
            continue
        xs = [p[0] for p in pts]
        ys = [p[1] for p in pts]
        zs = [p[2] for p in pts]
        # Unknown labels fall back to atomic number 12, gray, radius 0.8.
        atom_number = atomic_number.get(el, 12)
        color = jmol_colors.get(atom_number, (0.5, 0.5, 0.5))
        size = atomic_radii.get(atom_number, 0.8) * atom_size_factor
        artist = ax.scatter(xs, ys, zs, s=size, depthshade=True,
                            label=f"{el} ({len(pts)})", marker='o', color=color)
        scatter.append((el, artist))

    quiv = ax.quiver(positions[:, 0], positions[:, 1], positions[:, 2],
                     forces[:, 0], forces[:, 1], forces[:, 2],
                     length=q_length[0])
    frame_artists = draw_frame_artists(carts, by_elem, bonds, distance)

    def update(frame):
        nonlocal quiv, frame_artists
        k = int(frame)
        # Work on the pre-computed frame; never write into the caller's data.
        frame_positions = new_positions[k]
        frame_carts = to_display_coords(frame_positions)
        frame_bonds, frame_distance = detect_bonds(frame_positions)
        frame_by_elem = group_by_element(frame_carts)

        # Move the existing scatter artists (private matplotlib attribute,
        # as in the previous implementation).
        for el, artist in scatter:
            pts = frame_by_elem[el]
            artist._offsets3d = (
                [p[0] for p in pts],
                [p[1] for p in pts],
                [p[2] for p in pts],
            )

        if quiv is not None:
            quiv.remove()
        quiv = ax.quiver(frame_positions[:, 0], frame_positions[:, 1],
                         frame_positions[:, 2],
                         forces[:, 0], forces[:, 1], forces[:, 2],
                         length=q_length[k])

        # Remove the previous frame's labels and bonds before redrawing,
        # otherwise texts pile up on top of each other.
        for artist in frame_artists:
            artist.remove()
        frame_artists = draw_frame_artists(
            frame_carts, frame_by_elem, frame_bonds, frame_distance
        )
        return [artist for _, artist in scatter] + [quiv] + frame_artists

    auto_equal_limits(ax, all_pts)
    ax.set_xlabel("x (Å)")
    ax.set_ylabel("y (Å)")
    ax.set_zlabel("z (Å)")
    ax.legend(loc="upper left", bbox_to_anchor=(0.02, 0.98))

    ani = animation.FuncAnimation(fig, update, frames=tot_points,
                                  interval=frame_interval, blit=False)
    ani.save("crystal_phonon_with_bonds.gif", writer="pillow")
    print("Animated crystal phonon visualization with bonds saved as crystal_phonon")
    return (fig, ani)
if __name__ == '__main__':
    # Script entry point: build the Dash molecule viewer for the local 'xyz'
    # file and serve it. The previous call targeted draw_molecule(), a name
    # that is not defined in this file.
    app = display_molecule()
    app.run(debug=True)