from typing import (
List,
Optional,
Tuple,
Union,
TYPE_CHECKING,
Dict,
Any,
Literal,
cast,
Sequence,
)
import numpy as np
from scm.plams.core.errors import MissingOptionalPackageError
from scm.plams.core.functions import requires_optional_package
from scm.plams.interfaces.adfsuite.ams import AMSJob
from scm.plams.mol.molecule import Molecule
try:
from scm.base import ChemicalSystem
_has_scm_chemsys = True
except ImportError:
_has_scm_chemsys = False
if TYPE_CHECKING:
import matplotlib.pyplot as plt
import ase
from os import PathLike
from PIL import Image as PilImage
from scm.plams.recipes.md.trajectoryanalysis import AMSMSDJob
# Public API of this module, in the order it has always been exported.
# NOTE(review): ``linear_fit_extrapolate_to_0`` is not defined in the visible part of this module;
# confirm it is defined further down, otherwise ``from <module> import *`` raises AttributeError.
__all__ = [
    "linear_fit_extrapolate_to_0",
    # band structures and phonon properties
    "plot_band_structure", "plot_phonons_band_structure", "plot_phonons_dos",
    "plot_phonons_thermodynamic_properties",
    # molecules and image grids
    "plot_molecule", "plot_image_grid",
    # analysis of job results
    "plot_correlation", "plot_msd", "plot_work_function",
    # RDKit molecule grids
    "plot_grid_molecules",
]
@requires_optional_package("matplotlib")
def plot_band_structure(
    x: List[float],
    y_spin_up: np.ndarray,
    y_spin_down: Optional[np.ndarray] = None,
    labels: Optional[List[str]] = None,
    fermi_energy: Optional[float] = None,
    zero: Optional[Union[Literal["fermi", "vbm", "vbmax", "cbm", "cbmin"], float]] = None,
    ax: Optional["plt.Axes"] = None,
) -> "plt.Axes":
    """
    Plots an electronic band structure from DFTB, BAND, or QuantumEspresso engines with matplotlib.

    To control the appearance of the plot you need to call ``plt.ylim(bottom, top)``, ``plt.title(title)``, etc.
    manually outside this function.

    x: list of float
        Returned by AMSResults.get_band_structure()
    y_spin_up: 2D numpy array of float
        Returned by AMSResults.get_band_structure()
    y_spin_down: 2D numpy array of float. If None, the spin down bands are not plotted.
        Returned by AMSResults.get_band_structure()
    labels: list of str
        Returned by AMSResults.get_band_structure(). The list passed in is not modified.
    fermi_energy: float
        Returned by AMSResults.get_band_structure(). Should have the same unit as ``y_spin_up``.
    zero: None or float or one of 'fermi', 'vbm', 'vbmax', 'cbm', 'cbmin'
        Shift the curves so that y=0 is at the specified value. If None, no shift is performed.
        If a number is given, that number is used as the zero.
        'fermi', 'vbm'/'vbmax', and 'cbm'/'cbmin' require that ``fermi_energy`` is not None.
        Note: 'vbm'/'vbmax' and 'cbm'/'cbmin' calculate the zero as the highest (lowest) eigenvalue
        smaller (greater) than or equal to ``fermi_energy``, over both spin channels. This is NOT
        necessarily equal to the valence band maximum or conduction band minimum as calculated by
        the compute engine.

    Additional parameters:

    ``ax``: matplotlib axis
        The axis. If None, one will be created

    Returns: the matplotlib axis

    Raises ValueError if ``zero`` is neither None, a number, nor one of the accepted strings, or if a
    Fermi-energy based ``zero`` is requested while ``fermi_energy`` is None.
    """
    import matplotlib.pyplot as plt

    invalid_zero_msg = (
        f"When specified, zero must be a float or one of: 'fermi', 'vbm', 'vbmax', 'cbm', 'cbmin'; but was '{zero}'"
    )
    if zero is None:
        zero_value = 0.0
    elif isinstance(zero, str):
        if zero not in ("fermi", "vbm", "vbmax", "cbm", "cbmin"):
            raise ValueError(invalid_zero_msg)
        # explicit check instead of ``assert``, which is stripped under ``python -O``
        if fermi_energy is None:
            raise ValueError(f"zero='{zero}' requires that fermi_energy is not None")
        if zero == "fermi":
            zero_value = fermi_energy
        elif zero in ("vbm", "vbmax"):
            # highest eigenvalue at or below the Fermi energy, over both spin channels
            zero_value = y_spin_up[y_spin_up <= fermi_energy].max()
            if y_spin_down is not None:
                zero_value = max(zero_value, y_spin_down[y_spin_down <= fermi_energy].max())
        else:
            # lowest eigenvalue at or above the Fermi energy, over both spin channels
            zero_value = y_spin_up[y_spin_up >= fermi_energy].min()
            if y_spin_down is not None:
                zero_value = min(zero_value, y_spin_down[y_spin_down >= fermi_energy].min())
    else:
        # numeric zero, as documented
        try:
            zero_value = float(zero)
        except (TypeError, ValueError) as err:
            raise ValueError(invalid_zero_msg) from err

    def to_tex(label: str) -> str:
        for name, tex in (
            ("GAMMA", "\\Gamma"),
            ("DELTA", "\\Delta"),
            ("LAMBDA", "\\Lambda"),
            ("SIGMA", "\\Sigma"),
        ):
            label = label.replace(name, tex)
        return f"${label}$"

    # build a new list so that the caller's ``labels`` are not modified (and not re-wrapped on a second call)
    tex_labels = [to_tex(label) if label else label for label in (labels or [])]

    if ax is None:
        _, ax = plt.subplots()

    ax.plot(x, y_spin_up - zero_value, "-")
    if y_spin_down is not None:
        ax.plot(x, y_spin_down - zero_value, "--")

    tick_x: List[float] = []
    tick_labels: List[str] = []
    for xx, ll in zip(x, tex_labels):
        if not ll:
            continue
        if tick_x and np.isclose(xx, tick_x[-1]):
            # several labels at (almost) the same x: merge them into a single tick label
            if ll != tick_labels[-1]:
                tick_labels[-1] += f",{ll}"
        else:
            tick_x.append(xx)
            tick_labels.append(ll)

    for xx in tick_x:
        ax.axvline(xx)

    if fermi_energy is not None:
        ax.axhline(fermi_energy - zero_value, linestyle="--")

    ax.set_xticks(ticks=tick_x, labels=tick_labels)

    return ax
@requires_optional_package("matplotlib")
def plot_phonons_band_structure(
    x: List[float],
    y: np.ndarray,
    labels: Optional[List[str]] = None,
    zero: Optional[float] = None,
    ax: Optional["plt.Axes"] = None,
) -> "plt.Axes":
    """
    Plots a phonons band structure from DFTB, BAND or QuantumEspresso engines with matplotlib.

    To control the appearance of the plot you need to call ``plt.ylim(bottom, top)``, ``plt.title(title)``, etc.
    manually outside this function.

    x: list of float
        Returned by AMSResults.get_phonons_band_structure()
    y: 2D numpy array of float
        Returned by AMSResults.get_phonons_band_structure()
    labels: list of str
        Returned by AMSResults.get_phonons_band_structure(). The list passed in is not modified.
    zero: None or float
        Shift the curves so that y=0 is at the specified value. If None, no shift is performed.

    Additional parameters:

    ``ax``: matplotlib axis
        The axis. If None, one will be created

    Returns: the matplotlib axis
    """
    import matplotlib.pyplot as plt

    if zero is None:
        zero = 0

    def to_tex(label: str) -> str:
        for name, tex in (
            ("GAMMA", "\\Gamma"),
            ("DELTA", "\\Delta"),
            ("LAMBDA", "\\Lambda"),
            ("SIGMA", "\\Sigma"),
        ):
            label = label.replace(name, tex)
        return f"${label}$"

    # build a new list so that the caller's ``labels`` are not modified (and not re-wrapped on a second call)
    tex_labels = [to_tex(label) if label else label for label in (labels or [])]

    if ax is None:
        _, ax = plt.subplots()

    ax.plot(x, y - zero, "-")

    tick_x: List[float] = []
    tick_labels: List[str] = []
    for xx, ll in zip(x, tex_labels):
        if not ll:
            continue
        if tick_x and np.isclose(xx, tick_x[-1]):
            # several labels at (almost) the same x: merge them into a single tick label
            if ll != tick_labels[-1]:
                tick_labels[-1] += f",{ll}"
        else:
            tick_x.append(xx)
            tick_labels.append(ll)

    for xx in tick_x:
        ax.axvline(xx, dashes=[2, 2], color="gray")

    ax.set_xticks(ticks=tick_x, labels=tick_labels)

    return ax
@requires_optional_package("matplotlib")
def plot_phonons_dos(
    energy: List[float],
    total_dos: List[float],
    dos_per_species: Dict[str, List[float]],
    dos_per_atom: Dict[str, List[float]],
    dos_type: str = "total",
    ax: Optional["plt.Axes"] = None,
) -> "plt.Axes":
    """
    Plots the phonons DOS from DFTB, BAND or QuantumEspresso engines with matplotlib.

    To control the appearance of the plot you need to call ``plt.ylim(bottom, top)``, ``plt.title(title)``, etc.
    manually outside this function.

    energy: list of float
        Returned by AMSResults.get_phonons_thermodynamic_properties()
    total_dos: list of float
        Returned by AMSResults.get_phonons_thermodynamic_properties()
    dos_per_species: dictionary of list of float
        Returned by AMSResults.get_phonons_thermodynamic_properties()
    dos_per_atom: dictionary of list of float
        Returned by AMSResults.get_phonons_thermodynamic_properties()
    dos_type: str
        Specifies the kind of plot to show. Possible options:

        - "total": Total DOS.
        - "species": DOS decomposed by species.
        - "atom" (or "atoms"): DOS decomposed by atom.

    Additional parameters:

    ``ax``: matplotlib axis
        The axis. If None, one will be created

    Returns: the matplotlib axis

    Raises ValueError for an unknown ``dos_type`` (before any figure is created).
    """
    import matplotlib.pyplot as plt

    # "atoms" is kept for backward compatibility; "atom" is the documented spelling
    if dos_type not in ("total", "species", "atom", "atoms"):
        raise ValueError("Invalid dos_type. Must be 'total', 'species', or 'atom'.")

    if ax is None:
        _, ax = plt.subplots()

    if dos_type == "total":
        ax.plot(energy, total_dos, color="black", label="Total DOS", linestyle="-", zorder=1)
    else:
        # total DOS drawn underneath the partial DOS curves
        ax.plot(energy, total_dos, color="black", label="Total DOS", linestyle="-", zorder=-1)
        partial_dos = dos_per_species if dos_type == "species" else dos_per_atom
        for i, (name, values) in enumerate(partial_dos.items()):
            ax.plot(energy, values, label=f"pDOS {name}", dashes=[3, i + 1, 2], zorder=i)

    # draw the legend on the axis that was plotted on, not on pyplot's "current" axis
    ax.legend()
    return ax
@requires_optional_package("matplotlib")
def plot_phonons_thermodynamic_properties(
    temperature: List[float],
    properties: Dict[str, List[float]],
    units: Dict[str, str],
    ax: Optional["plt.Axes"] = None,
) -> "plt.Axes":
    """
    Plots the phonons thermodynamic properties from DFTB, BAND or QuantumEspresso engines with matplotlib.

    To control the appearance of the plot you need to call ``plt.ylim(bottom, top)``, ``plt.title(title)``, etc.
    manually outside this function.

    temperature: list of float
        Returned by AMSResults.get_phonons_thermodynamic_properties()
    properties: dictionary of list of float
        Returned by AMSResults.get_phonons_thermodynamic_properties()
    units: dictionary of str
        Returned by AMSResults.get_phonons_thermodynamic_properties(). Must contain every key of ``properties``.

    Additional parameters:

    ``ax``: matplotlib axis
        The axis. If None, one will be created

    Returns: the matplotlib axis
    """
    import matplotlib.pyplot as plt

    if ax is None:
        _, ax = plt.subplots()

    for label, prop in properties.items():
        ax.plot(
            temperature,
            prop,
            label=f"{label} ({units[label]})",
            linestyle="-",
            lw=2,
            zorder=1,
        )

    # draw the legend on the axis that was plotted on, not on pyplot's "current" axis
    ax.legend()
    return ax
@requires_optional_package("matplotlib")
@requires_optional_package("ase")
def plot_molecule(
    molecule: Union["ase.Atoms", Molecule],
    figsize: Optional[Union[float, Tuple[float, float]]] = None,
    ax: Optional["plt.Axes"] = None,
    keep_axis: bool = False,
    **kwargs: Any,
) -> "plt.Axes":
    """Show a molecule in a Jupyter notebook

    :param molecule: an ``ase.Atoms`` or a PLAMS ``Molecule`` (converted with ``toASE`` first)
    :param figsize: figure size used only when ``ax`` is None; either a ``(width, height)`` tuple or a
        single number for a square figure, defaults to ``(2, 2)``
    :param ax: axis to draw on; if None, a new figure and axis are created
    :param keep_axis: if False, the axis is switched off with ``ax.axis("off")``
    :param kwargs: passed on to ``ase.visualize.plot.plot_atoms``
    :return: the matplotlib axis
    """
    import matplotlib.pyplot as plt
    from ase.visualize.plot import plot_atoms
    from scm.plams.interfaces.molecule.ase import toASE

    if isinstance(molecule, Molecule):
        molecule = toASE(molecule)

    if ax is None:
        if not figsize:
            figsize = (2, 2)
        elif isinstance(figsize, (int, float)):
            # a single number (as allowed by the type hint) means a square figure;
            # matplotlib itself requires a (width, height) pair
            figsize = (figsize, figsize)
        _, ax = plt.subplots(figsize=figsize)

    plot_atoms(molecule, ax=ax, **kwargs)
    if not keep_axis:
        ax.axis("off")
    return ax
@requires_optional_package("rdkit")
def plot_grid_molecules(
    molecules: List[Union[Molecule, "ChemicalSystem"]],
    legends: Optional[List[str]] = None,
    molsPerRow: int = 2,
    subImgSize: Tuple[int, int] = (200, 200),
    ax: Optional["plt.Axes"] = None,
    save_svg_path: Optional[Union[str, "PathLike"]] = None,
    **kwargs: Any,
) -> Union["PilImage.Image", "plt.Axes", str]:
    """Plot series of molecules in a grid using RDKit

    Molecules that have no bonds get ``guess_bonds()`` called on them first, so those input
    molecules are modified in place.

    :param molecules: molecules to draw
    :param legends: optional legend strings, one per molecule, defaults to None
    :param molsPerRow: number of molecules per row, defaults to 2
    :param subImgSize: size of each individual molecule image, defaults to (200, 200)
    :param ax: if provided molecules are plotted in these axes, note that the quality of the image might reduce, defaults to None
    :param save_svg_path: pathlike of the file, with formats .svg to save it as image, it returns the svg string, defaults to None
    :param kwargs: extra keyword arguments passed to ``rdkit.Chem.Draw.MolsToGridImage``
    :raises TypeError: if ``save_svg_path`` is given but RDKit did not return an SVG string
    :return: an image of the molecules
    :rtype: pil.Image or plt.Axes or string
    """
    from rdkit.Chem import Draw, rdchem
    from rdkit.Chem.Draw import IPythonConsole  # type: ignore[attr-defined]
    from scm.plams.interfaces.molecule.rdkit import _rdmol_for_image

    # guess bonds, the bonds will be included in the RDKit molecule
    for m in molecules:
        if len(m.bonds) == 0:
            m.guess_bonds()

    rdmols = [_rdmol_for_image(m, remove_hydrogens=False) for m in molecules]

    # ``hasattr(rdchem.Mol, "_repr_svg_")`` is used as the probe for "IPython renderer installed".
    # When drawing onto an axis or into a file the renderer is uninstalled, presumably so that
    # MolsToGridImage returns a raw PIL image / SVG string instead of an IPython display object.
    if ax is not None or save_svg_path is not None:
        if hasattr(rdchem.Mol, "_repr_svg_"):
            IPythonConsole.UninstallIPythonRenderer()
    else:
        if not hasattr(rdchem.Mol, "_repr_svg_"):
            IPythonConsole.InstallIPythonRenderer()
        IPythonConsole.ipython_useSVG = True

    if ax is not None:
        # imshow needs a raster image
        IPythonConsole.ipython_useSVG = False
        kwargs["useSVG"] = False

    if save_svg_path is not None:
        # takes precedence over ``ax``: an SVG string is requested
        kwargs["useSVG"] = True

    img = Draw.MolsToGridImage(
        mols=rdmols,
        molsPerRow=molsPerRow,  # Number of molecules per row
        subImgSize=subImgSize,  # Size of each individual image
        legends=legends,
        **kwargs,
    )

    if save_svg_path is not None:
        if not isinstance(img, str):
            raise TypeError(
                f"{type(img)=} but expected str, most likely it is due to previously using ipy_useSVG=True in a notebook"
            )
        # explicit encoding: the platform default (e.g. cp1252) can fail on non-ASCII legends
        with open(save_svg_path, "w", encoding="utf-8") as f:
            f.write(img)
        return img

    if ax is not None:
        image_data = np.array(img, dtype=np.int32)
        ax.imshow(image_data)
        return ax

    return img
@requires_optional_package("matplotlib")
def plot_image_grid(
    images: Dict[str, "PilImage.Image"],
    rows: Optional[int] = None,
    cols: Optional[int] = None,
    figsize: Optional[Tuple[float, float]] = None,
    show_labels: bool = True,
    save_path: Optional[Union[str, "PathLike"]] = None,
) -> np.ndarray:
    """Lay out a dictionary of images on a matplotlib grid of axes.

    :param images: mapping label -> image; the mapping's iteration order fills the grid row by row
    :param rows: number of grid rows; when ``None`` it is derived from ``cols`` and the number of images
    :param cols: number of grid columns; when ``None`` it is derived from ``rows`` and the number of images
    :param figsize: matplotlib figure size; when ``None`` each cell gets 4 x 4 inches
    :param show_labels: put each dictionary key as a title above its image
    :param save_path: if given, the figure is written there with matplotlib ``savefig``
    :raises ValueError: for an empty mapping, a non-positive ``rows``/``cols``, or a grid too small for all images
    :return: 2D numpy array of matplotlib axes with shape ``(rows, cols)``
    :rtype: np.ndarray
    """
    import matplotlib.pyplot as plt

    entries = list(images.items())
    count = len(entries)

    if not count:
        raise ValueError("images must contain at least one image")
    if rows is not None and rows <= 0:
        raise ValueError(f"rows must be a positive integer when provided, but got {rows}")
    if cols is not None and cols <= 0:
        raise ValueError(f"cols must be a positive integer when provided, but got {cols}")

    # Resolve the grid shape: derive whichever dimension(s) were not given.
    if rows is None and cols is None:
        n_cols = int(np.ceil(np.sqrt(count)))
        n_rows = int(np.ceil(count / n_cols))
    elif rows is None:
        n_cols = cast(int, cols)
        n_rows = int(np.ceil(count / n_cols))
    elif cols is None:
        n_rows = rows
        n_cols = int(np.ceil(count / n_rows))
    else:
        n_rows, n_cols = rows, cols

    capacity = n_rows * n_cols
    if count > capacity:
        raise ValueError(f"Grid of shape ({n_rows}, {n_cols}) can hold at most {capacity} images, but got {count}")

    if figsize is None:
        figsize = (4.0 * n_cols, 4.0 * n_rows)

    # squeeze=False always yields a 2D (n_rows, n_cols) object array of axes
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)

    # blank every cell first so that unused trailing cells stay empty
    for cell in axes.flat:
        cell.axis("off")

    # axes.flat walks the grid row-major, matching the image order
    for (label, image), cell in zip(entries, axes.flat):
        cell.imshow(image)
        if show_labels:
            cell.set_title(label)

    if save_path is not None:
        fig.savefig(save_path)

    return axes
def get_correlation_xy(
    job1: Union["AMSJob", List["AMSJob"]],
    job2: Union["AMSJob", List["AMSJob"]],
    section: str,
    variable: str,
    alt_section: Optional[str] = None,
    alt_variable: Optional[str] = None,
    file: str = "ams",
    multiplier: float = 1.0,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Collect paired x/y data for a correlation plot.

    For every pair of jobs, ``section%variable`` is read from the job in ``job1`` and
    ``alt_section%alt_variable`` from the job in ``job2`` via ``results.readrkf(..., file=file)``.
    If ``readrkf`` raises ``KeyError``, ``results.get_history_property(variable, history_section=section)``
    is used instead. Each value (scalar or list) is flattened and multiplied by ``multiplier``, and
    the values of all jobs are concatenated.

    job1, job2: AMSJob or list of AMSJob
        Jobs providing the x data and y data, respectively. A single job is treated as a one-element list.
    section, variable: str
        What to read from the ``job1`` jobs.
    alt_section, alt_variable: str, optional
        What to read from the ``job2`` jobs; default to ``section`` and ``variable``.
    file: str
        File passed to ``readrkf``, defaults to "ams".
    multiplier: float
        Factor applied to all values, defaults to 1.0.

    Returns: a tuple ``(x, y)`` of 1D numpy arrays.

    Raises ValueError if ``job1`` and ``job2`` contain a different number of jobs.
    """

    def tolist(x: Any) -> List:
        if isinstance(x, list):
            return x
        return [x]

    jobs1 = tolist(job1)
    jobs2 = tolist(job2)
    # zip() would silently drop the surplus jobs of the longer list
    if len(jobs1) != len(jobs2):
        raise ValueError(
            f"job1 and job2 must contain the same number of jobs, but got {len(jobs1)} and {len(jobs2)}"
        )

    alt_section = alt_section or section
    alt_variable = alt_variable or variable

    def read(job: Any, sec: str, var: str) -> np.ndarray:
        # read from the rkf file; fall back to the history section when the key is missing
        try:
            data = job.results.readrkf(sec, var, file=file)
        except KeyError:
            data = job.results.get_history_property(var, history_section=sec)
        return np.ravel(cast(Union[List[float], float], data)) * multiplier

    data1: List[float] = []
    data2: List[float] = []
    for j1, j2 in zip(jobs1, jobs2):
        data1.extend(read(j1, section, variable))
        data2.extend(read(j2, alt_section, alt_variable))

    return np.array(data1), np.array(data2)
@requires_optional_package("matplotlib")
def plot_correlation(
    job1: Union[AMSJob, List[AMSJob]],
    job2: Union[AMSJob, List[AMSJob]],
    section: str,
    variable: str,
    alt_section: Optional[str] = None,
    alt_variable: Optional[str] = None,
    file: str = "ams",
    multiplier: float = 1.0,
    unit: Optional[str] = None,
    save_txt: Optional[str] = None,
    ax: Optional["plt.Axes"] = None,
    show_xy: bool = True,
    show_linear_fit: bool = True,
    show_mad: bool = True,
    show_rmsd: bool = True,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
) -> "plt.Axes":
    """
    Plot a correlation plot from AMS .rkf files

    job1: AMSJob or List[AMSJob]
        Job(s) plotted on the x-axis
    job2: AMSJob or List[AMSJob]
        Job(s) plotted on the y-axis
    section: str
        Section to read on .rkf files
    variable: str
        Variable to read
    alt_section: str, optional
        Section to read on .rkf files for job2. If not specified it will be the same as ``section``
    alt_variable: str, optional
        Variable to read for job2. If not specified it will be the same as ``variable``.
    file: str, optional
        "ams" or "engine", defaults to "ams"
    multiplier: float, optional
        Numbers will be multiplied by this number, defaults to 1.0
    unit: str, optional
        Unit shown next to MAD and RMSD in the plot title, defaults to None
    save_txt: str, optional
        If not None, save the xy data to this text file, defaults to None
    ax: matplotlib axis, optional
        The axis. If None, one will be created
    show_xy: bool, optional
        Whether to show the y=x line, defaults to True
    show_linear_fit: bool, optional
        Whether to perform and show a linear fit (requires scipy), defaults to True
    show_mad: bool, optional
        Whether to show the mean absolute deviation, defaults to True
    show_rmsd: bool, optional
        Whether to show the root-mean-square deviation, defaults to True
    xlabel: str, optional
        The x-label. If not given it will be a comma-separated list of the job1 names
        (shortened to 35 characters + "..." when longer than 40), defaults to None
    ylabel: str, optional
        The y-label. If not given it will be a comma-separated list of the job2 names
        (shortened in the same way), defaults to None

    Returns: A matplotlib axis
    """
    import matplotlib.pyplot as plt

    def tolist(x: Any) -> List:
        if isinstance(x, list):
            return x
        return [x]

    job1 = tolist(job1)
    job2 = tolist(job2)

    # get_correlation_xy itself falls back to section/variable when alt_section/alt_variable are None
    data1, data2 = get_correlation_xy(job1, job2, section, variable, alt_section, alt_variable, file, multiplier)

    def add_unit(s: str) -> str:
        if unit is not None:
            return f"{s} ({unit})"
        return s

    if ax is None:
        _, ax = plt.subplots()

    complete_data = np.stack((data1, data2), axis=1)
    # common range for both axes so that y=x is the diagonal of the square plot
    min_max = np.array([np.min(complete_data), np.max(complete_data)])

    legend = []
    title = [f"{section}%{variable}"]

    if show_xy:
        ax.plot(min_max, min_max, "-")
        legend.append("y=x")

    stats_title = ""
    if show_mad:
        mad = np.mean(np.abs(data2 - data1))
        stats_title += add_unit(f" MAD: {mad:.5f}")

    if show_rmsd:
        rmsd = np.sqrt(np.mean((data2 - data1) ** 2))
        stats_title += add_unit(f" RMSD: {rmsd:.5f}")

    linear_fit_title = None
    if show_linear_fit:
        try:
            from scipy.stats import linregress
        except ImportError as err:
            raise MissingOptionalPackageError("scipy") from err

        result = linregress(data1, data2)
        min_max_linear_fit = result.slope * min_max + result.intercept
        r2 = result.rvalue**2
        ax.plot(min_max, min_max_linear_fit, "-")
        legend.append("Fit")
        stats_title += f" R^2: {r2:.3f}"
        linear_fit_title = f"Linear fit slope={result.slope:.3f} intercept={result.intercept:.3f}"

    if stats_title:
        title.append(stats_title)
    if linear_fit_title:
        title.append(linear_fit_title)

    ax.plot(data1, data2, ".")
    legend.append("data")

    if xlabel is None:
        xlabel = ", ".join(x.name for x in job1)
        if len(xlabel) > 40:
            xlabel = xlabel[:35] + "..."
    if ylabel is None:
        ylabel = ", ".join(x.name for x in job2)
        if len(ylabel) > 40:
            ylabel = ylabel[:35] + "..."

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title("\n".join(title))
    ax.legend(legend)
    ax.set_box_aspect(1)
    ax.set_xlim(*min_max)
    ax.set_ylim(*min_max)

    if save_txt is not None:
        np.savetxt(save_txt, complete_data, header=f"{xlabel} {ylabel}")

    return ax
@requires_optional_package("matplotlib")
def plot_msd(
    job: "AMSMSDJob",
    start_time_fit_fs: Optional[float] = None,
    ax: Optional["plt.Axes"] = None,
) -> "plt.Axes":
    """
    Plot the mean square displacement of an AMSMSDJob together with its linear fit.

    job: AMSMSDJob
        Job whose results provide the MSD curve, the linear fit and the diffusion coefficient
    start_time_fit_fs: float
        Start time (in fs) of the window used for the linear fit
    ax: matplotlib axis
        Axis to draw on. If None, one will be created

    Returns: matplotlib axis
    """
    import matplotlib.pyplot as plt

    if ax is None:
        _, ax = plt.subplots()

    times, msd_values = job.results.get_msd()
    fit, fit_times, fit_msd = job.results.get_linear_fit(start_time_fit_fs=start_time_fit_fs)
    # equivalently fit.slope / 6 in ang^2/fs; the results object reports it in m^2/s
    diffusion = job.results.get_diffusion_coefficient(start_time_fit_fs=start_time_fit_fs)

    ax.plot(times, msd_values, label="MSD")
    fit_label = f"Linear fit slope={fit.slope:.5f} ang^2/fs"
    ax.plot(fit_times, fit_msd, label=fit_label)
    ax.legend()

    ax.set_xlabel("Correlation time (fs)")
    ax.set_ylabel("Mean square displacement (ang^2)")
    ax.set_title(f"MSD: Diffusion coefficient = {diffusion:.2e} m^2/s")

    return ax
@requires_optional_package("matplotlib")
def plot_work_function(
    coordinate: np.ndarray,
    planarAverage: np.ndarray,
    macroscopicAverage: np.ndarray,
    Efermi: float,
    Vbulk: float,
    Vvacuum: Tuple[float, float],
    WF: Tuple[float, float],
    ax: Optional["plt.Axes"] = None,
) -> "plt.Axes":
    """
    Plots an Electrostatic Potential Profile from AMS-QE with matplotlib.

    To control the appearance of the plot you need to call ``plt.ylim(bottom, top)``, ``plt.title(title)``, etc.
    manually outside this function.

    ``coordinate``: 1D array of float.
        Returned by AMSResults.get_work_function_results().
    ``planarAverage``: 1D array of float.
        Returned by AMSResults.get_work_function_results().
    ``macroscopicAverage``: 1D array of float.
        Returned by AMSResults.get_work_function_results(). Should have the same unit as ``planarAverage``.
    ``Efermi``: float.
        Returned by AMSResults.get_work_function_results(). Should have the same unit as ``planarAverage``.
    ``Vbulk``: float.
        Returned by AMSResults.get_work_function_results(). Should have the same unit as ``planarAverage``.
    ``Vvacuum``: Tuple[float,float].
        Returned by AMSResults.get_work_function_results(). Should have the same unit as ``planarAverage``.
        If both values agree within 1e-3, a single vacuum level and one work-function arrow
        (on the right) are drawn; otherwise each side gets its own vacuum level and arrow.
    ``WF``: Tuple[float,float].
        Returned by AMSResults.get_work_function_results(). Should have the same unit as ``planarAverage``.

    Additional parameters:

    ``ax``: matplotlib axis
        The axis. If None, one will be created

    Returns: matplotlib axis
    """
    import matplotlib.pyplot as plt

    if ax is None:
        _, ax = plt.subplots()

    ax.set_xlabel("Length", fontsize=13)
    ax.set_ylabel("Energy", fontsize=13)

    x0 = np.min(coordinate)
    y0 = np.min(planarAverage)
    x1 = np.max(coordinate)
    span = x1 - x0

    ax.plot(coordinate, planarAverage, color="red", linestyle="-.", lw=2, zorder=1)
    ax.plot(coordinate, macroscopicAverage, color="blue", linestyle="-", lw=2, zorder=2)
    ax.text(x0 + 0.8 * span, y0, "Planar\nAverage", fontsize=11, color="red")
    ax.text(x0 + 0.0 * span, y0, "Macroscopic\nAverage", fontsize=11, color="blue")

    ax.axhline(y=Efermi, color="black", linestyle="dashed", linewidth=1)
    ax.text(x0 + 0.05 * span, Efermi + 0.1, "E. Fermi", fontsize=11, color="black")
    ax.axhline(y=Vbulk, color="black", linestyle="dashed", linewidth=1)
    ax.text(x0 + 0.05 * span, Vbulk + 0.1, "Pot. bulk", fontsize=11, color="black")

    head_length = 0.4

    def draw_wf_arrow(x_arrow: float, x_text: float, vacuum: float, wf: float, **text_kwargs: Any) -> None:
        # vertical arrow from the Fermi level up to the vacuum level, labelled with the work function
        ax.arrow(
            x_arrow,
            Efermi,
            0.0,
            vacuum - Efermi - head_length,
            head_width=0.3,
            head_length=head_length,
            fc="black",
            ec="black",
        )
        # NOTE(review): the label always says "eV" although the inputs are only documented to share
        # the unit of ``planarAverage`` — confirm eV is the only unit passed in.
        ax.text(
            x_text,
            (vacuum + Efermi) / 2,
            f"WF={wf:.1f} eV",
            fontsize=11,
            color="black",
            **text_kwargs,
        )

    # If the material is symmetric (both vacuum levels equal):
    if abs(Vvacuum[0] - Vvacuum[1]) < 1e-3:
        ax.axhline(y=Vvacuum[0], color="black", linestyle="dashed", linewidth=1)
        ax.text(x0, Vvacuum[0] + 0.1, "Pot. vacuum", fontsize=11, color="black")
        draw_wf_arrow(x0 + 1.0 * span, x0 + 0.98 * span, Vvacuum[1], WF[1], horizontalalignment="right")
    # Otherwise: separate vacuum levels and work functions on the left and right side
    else:
        ax.plot(
            [x0, x0 + 0.3 * span],
            [Vvacuum[0], Vvacuum[0]],
            color="black",
            linestyle="dashed",
            linewidth=1,
        )
        ax.text(x0, Vvacuum[0] + 0.1, "Pot. vacuum", fontsize=11, color="black")
        ax.plot(
            [x1, x1 - 0.3 * span],
            [Vvacuum[1], Vvacuum[1]],
            color="black",
            linestyle="dashed",
            linewidth=1,
        )
        ax.text(x1 - 0.3 * span, Vvacuum[1] + 0.1, "Pot. vacuum", fontsize=11, color="black")
        draw_wf_arrow(x0 + 0.0 * span, x0 + 0.02 * span, Vvacuum[0], WF[0])
        draw_wf_arrow(x0 + 1.0 * span, x0 + 0.98 * span, Vvacuum[1], WF[1], horizontalalignment="right")

    return ax