# Source code for scm.plams.tools.plot

from typing import List, Optional, Tuple, Union, TYPE_CHECKING

import numpy as np
from scm.plams.core.errors import MissingOptionalPackageError
from scm.plams.core.functions import requires_optional_package
from scm.plams.interfaces.adfsuite.ams import AMSJob
from scm.plams.mol.molecule import Molecule

if TYPE_CHECKING:
    import matplotlib.pyplot as plt
    from os import PathLike
    from PIL import Image as PilImage

# Public plotting API of this module; these are the names exported by ``from ... import *``.
__all__ = [
    "plot_band_structure",
    "plot_phonons_band_structure",
    "plot_phonons_dos",
    "plot_phonons_thermodynamic_properties",
    "plot_molecule",
    "plot_correlation",
    "plot_msd",
    "plot_work_function",
    "plot_grid_molecules",
]


@requires_optional_package("matplotlib")
def plot_band_structure(x, y_spin_up, y_spin_down=None, labels=None, fermi_energy=None, zero=None, ax=None):
    """
    Plots an electronic band structure from DFTB, BAND, or QuantumEspresso engines with matplotlib.

    To control the appearance of the plot you need to call ``plt.ylim(bottom, top)``, ``plt.title(title)``, etc.
    manually outside this function.

    x: list of float
        Returned by AMSResults.get_band_structure()

    y_spin_up: 2D numpy array of float
        Returned by AMSResults.get_band_structure()

    y_spin_down: 2D numpy array of float. If None, the spin down bands are not plotted.
        Returned by AMSResults.get_band_structure()

    labels: list of str
        Returned by AMSResults.get_band_structure(). The passed list is not modified.

    fermi_energy: float
        Returned by AMSResults.get_band_structure(). Should have the same unit as ``y``.

    zero: None or float or one of 'fermi', 'vbmax', 'cbmin'
        Shift the curves so that y=0 is at the specified value. If None, no shift is performed.
        'fermi', 'vbmax', and 'cbmin' require that the ``fermi_energy`` is not None.
        The short forms 'vbm' and 'cbm' are also accepted ('cbmax' is kept as a legacy alias of 'cbmin').

        Note: 'vbmax' and 'cbmin' calculate the zero as the highest (lowest) eigenvalue smaller (greater) than or
        equal to ``fermi_energy``. This is NOT necessarily equal to the valence band maximum or conduction band
        minimum as calculated by the compute engine.

    Additional parameters:

    ``ax``: matplotlib axis
        The axis. If None, one will be created

    Returns: matplotlib axis

    Raises: ValueError if ``zero`` is a string that is not one of the options above.
    """
    import matplotlib.pyplot as plt

    if zero is None:
        zero = 0
    elif zero == "fermi":
        assert fermi_energy is not None, "zero='fermi' requires fermi_energy"
        zero = fermi_energy
    elif zero in ["vbm", "vbmax"]:
        assert fermi_energy is not None, "zero='vbmax' requires fermi_energy"
        zero = y_spin_up[y_spin_up <= fermi_energy].max()
        if y_spin_down is not None:
            zero = max(zero, y_spin_down[y_spin_down <= fermi_energy].max())
    elif zero in ["cbm", "cbmin", "cbmax"]:
        assert fermi_energy is not None, "zero='cbmin' requires fermi_energy"
        # Lowest eigenvalue at or above the Fermi energy, considering both spin channels.
        zero = y_spin_up[y_spin_up >= fermi_energy].min()
        if y_spin_down is not None:
            zero = min(zero, y_spin_down[y_spin_down >= fermi_energy].min())
    elif isinstance(zero, str):
        raise ValueError(f"Unknown value for zero: {zero!r}. Use None, a number, 'fermi', 'vbmax' or 'cbmin'.")

    # Work on a copy so the caller's label list is not modified in place.
    labels = [] if labels is None else list(labels)
    for i, label in enumerate(labels):
        if label:
            label = (
                label.replace("GAMMA", "\\Gamma")
                .replace("DELTA", "\\Delta")
                .replace("LAMBDA", "\\Lambda")
                .replace("SIGMA", "\\Sigma")
            )
            labels[i] = f"${label}$"

    if ax is None:
        _, ax = plt.subplots()

    ax.plot(x, y_spin_up - zero, "-")
    if y_spin_down is not None:
        ax.plot(x, y_spin_down - zero, "--")

    tick_x: List[float] = []
    tick_labels: List[str] = []
    for xx, ll in zip(x, labels):
        if not ll:
            continue
        if tick_x and np.isclose(xx, tick_x[-1]):
            # Labels at (numerically) the same x position share one tick, e.g. "$X$,$K$".
            if ll != tick_labels[-1]:
                tick_labels[-1] += f",{ll}"
        else:
            tick_x.append(xx)
            tick_labels.append(ll)

    for xx in tick_x:
        ax.axvline(xx)

    if fermi_energy is not None:
        ax.axhline(fermi_energy - zero, linestyle="--")

    ax.set_xticks(ticks=tick_x, labels=tick_labels)

    return ax
@requires_optional_package("matplotlib")
def plot_phonons_band_structure(x, y, labels=None, zero=None, ax=None):
    """
    Plots a phonons band structure from DFTB, BAND or QuantumEspresso engines with matplotlib.

    To control the appearance of the plot you need to call ``plt.ylim(bottom, top)``, ``plt.title(title)``, etc.
    manually outside this function.

    x: list of float
        Returned by AMSResults.get_phonons_band_structure()

    y: 2D numpy array of float
        Returned by AMSResults.get_phonons_band_structure()

    labels: list of str
        Returned by AMSResults.get_phonons_band_structure(). The passed list is not modified.

    zero: None or float
        Shift the curves so that y=0 is at the specified value. If None, no shift is performed.

    Additional parameters:

    ``ax``: matplotlib axis
        The axis. If None, one will be created

    Returns: matplotlib axis
    """
    import matplotlib.pyplot as plt

    if zero is None:
        zero = 0

    # Work on a copy so the caller's label list is not modified in place.
    labels = [] if labels is None else list(labels)
    for i, label in enumerate(labels):
        if label:
            label = (
                label.replace("GAMMA", "\\Gamma")
                .replace("DELTA", "\\Delta")
                .replace("LAMBDA", "\\Lambda")
                .replace("SIGMA", "\\Sigma")
            )
            labels[i] = f"${label}$"

    if ax is None:
        _, ax = plt.subplots()

    ax.plot(x, y - zero, "-")

    tick_x: List[float] = []
    tick_labels: List[str] = []
    for xx, ll in zip(x, labels):
        if not ll:
            continue
        if tick_x and np.isclose(xx, tick_x[-1]):
            # Labels at (numerically) the same x position share one tick, e.g. "$X$,$K$".
            if ll != tick_labels[-1]:
                tick_labels[-1] += f",{ll}"
        else:
            tick_x.append(xx)
            tick_labels.append(ll)

    for xx in tick_x:
        ax.axvline(xx, dashes=[2, 2], color="gray")

    ax.set_xticks(ticks=tick_x, labels=tick_labels)

    return ax
@requires_optional_package("matplotlib")
def plot_phonons_dos(energy, total_dos, dos_per_species, dos_per_atom, dos_type="total", ax=None):
    """
    Plots the phonons DOS from DFTB, BAND or QuantumEspresso engines with matplotlib.

    To control the appearance of the plot you need to call ``plt.ylim(bottom, top)``, ``plt.title(title)``, etc.
    manually outside this function.

    energy: list of float
        Returned by AMSResults.get_phonons_thermodynamic_properties()

    total_dos: list of float
        Returned by AMSResults.get_phonons_thermodynamic_properties()

    dos_per_species: dictionary of list of float
        Returned by AMSResults.get_phonons_thermodynamic_properties()

    dos_per_atom: dictionary of list of float
        Returned by AMSResults.get_phonons_thermodynamic_properties()

    dos_type: str
        Specifies the kind of plot to show. Possible options:

        - "total": Total DOS.
        - "species": DOS decomposed by species.
        - "atom": DOS decomposed by atom ("atoms" is accepted as well).

    Additional parameters:

    ``ax``: matplotlib axis
        The axis. If None, one will be created

    Returns: matplotlib axis

    Raises: ValueError for an unknown ``dos_type`` (checked before any figure is created).
    """
    import matplotlib.pyplot as plt

    # Validate first so an invalid dos_type does not leave an empty figure behind.
    if dos_type == "total":
        partial_dos = None
    elif dos_type == "species":
        partial_dos = dos_per_species
    elif dos_type in ("atom", "atoms"):
        partial_dos = dos_per_atom
    else:
        raise ValueError("Invalid dos_type. Must be 'total', 'species', or 'atom'.")

    if ax is None:
        _, ax = plt.subplots()

    if partial_dos is None:
        ax.plot(energy, total_dos, color="black", label="Total DOS", linestyle="-", zorder=1)
    else:
        # Total DOS drawn behind the partial contributions.
        ax.plot(energy, total_dos, color="black", label="Total DOS", linestyle="-", zorder=-1)
        for i, (name, values) in enumerate(partial_dos.items()):
            ax.plot(energy, values, label=f"pDOS {name}", dashes=[3, i + 1, 2], zorder=i)

    # Attach the legend to the axis that was drawn on; plt.legend() would use the *current* axis,
    # which need not be ``ax`` when an axis is passed in.
    ax.legend()

    return ax
@requires_optional_package("matplotlib")
def plot_phonons_thermodynamic_properties(temperature, properties, units, ax=None):
    """
    Plots the phonons thermodynamic properties from DFTB, BAND or QuantumEspresso engines with matplotlib.

    To control the appearance of the plot you need to call ``plt.ylim(bottom, top)``, ``plt.title(title)``, etc.
    manually outside this function.

    temperature: list of float
        Returned by AMSResults.get_phonons_thermodynamic_properties()

    properties: dictionary of list of float
        Returned by AMSResults.get_phonons_thermodynamic_properties()

    units: dictionary of str
        Returned by AMSResults.get_phonons_thermodynamic_properties(). Must contain a unit for every key
        of ``properties``.

    Additional parameters:

    ``ax``: matplotlib axis
        The axis. If None, one will be created

    Returns: matplotlib axis
    """
    import matplotlib.pyplot as plt

    if ax is None:
        _, ax = plt.subplots()

    for label, prop in properties.items():
        ax.plot(temperature, prop, label=label + " (" + units[label] + ")", linestyle="-", lw=2, zorder=1)

    # Attach the legend to the axis that was drawn on; plt.legend() would use the *current* axis,
    # which need not be ``ax`` when an axis is passed in.
    ax.legend()

    return ax
@requires_optional_package("matplotlib")
@requires_optional_package("ase")
def plot_molecule(molecule, figsize=None, ax=None, keep_axis: bool = False, **kwargs):
    """Show a molecule in a Jupyter notebook

    molecule: Molecule or object accepted by ``ase.visualize.plot.plot_atoms``
        A PLAMS ``Molecule`` is converted with ``toASE`` first; anything else is passed on unchanged.

    figsize: tuple, optional
        Figure size used only when a new axis is created; defaults to (2, 2).

    ax: matplotlib axis, optional
        The axis to draw on. If None, a new figure and axis are created.

    keep_axis: bool
        If False (default), the axis frame and ticks are switched off.

    kwargs:
        Forwarded to ``plot_atoms``.

    Returns: matplotlib axis
    """
    import matplotlib.pyplot as plt
    from ase.visualize.plot import plot_atoms
    from scm.plams.interfaces.molecule.ase import toASE

    atoms = toASE(molecule) if isinstance(molecule, Molecule) else molecule

    if ax is None:
        ax = plt.subplots(figsize=figsize or (2, 2))[1]

    plot_atoms(atoms, ax=ax, **kwargs)

    if not keep_axis:
        ax.axis("off")

    return ax
@requires_optional_package("rdkit")
def plot_grid_molecules(
    molecules: List[Molecule],
    legends: Optional[List[str]] = None,
    molsPerRow: int = 2,
    subImgSize: Tuple[int, int] = (200, 200),
    ax: Optional["plt.Axes"] = None,
    save_svg_path: Optional[Union[str, "PathLike"]] = None,
    **kwargs,
) -> Union["PilImage.Image", "plt.Axes", str]:
    """Plot series of molecules in a grid using RDKit

    :param molecules: the PLAMS molecules to draw. Molecules without bonds get bonds guessed on a copy;
        the passed molecules are not modified.
    :param legends: optional text shown under each molecule, defaults to None
    :param molsPerRow: number of molecules per row, defaults to 2
    :param subImgSize: size of each individual image, defaults to (200, 200)
    :param ax: if provided molecules are plotted in these axes, note that the quality of the image might reduce,
        defaults to None
    :param save_svg_path: pathlike of the file, with formats .svg to save it as image (written as UTF-8),
        it returns the svg string, defaults to None
    :param kwargs: forwarded to ``rdkit.Chem.Draw.MolsToGridImage``
    :return: an image of the molecules
    :rtype: PIL.Image.Image or plt.Axes or string
    :raises TypeError: if an SVG file is requested but RDKit did not return an SVG string
    """
    from rdkit.Chem import Draw, rdchem
    from rdkit.Chem.Draw import IPythonConsole
    from scm.plams.interfaces.molecule.rdkit import _rdmol_for_image

    # Bonds are needed so they are included in the RDKit molecule. Guess them on a copy so that
    # plotting does not modify the caller's molecules as a side effect.
    prepared = []
    for mol in molecules:
        if len(mol.bonds) == 0:
            mol = mol.copy()
            mol.guess_bonds()
        prepared.append(mol)
    rdmols = [_rdmol_for_image(m, remove_hydrogens=False) for m in prepared]

    # NOTE(review): the IPython renderer state appears to influence what MolsToGridImage returns
    # (see the TypeError below); it is switched off when a raster image or SVG string is needed.
    if ax is not None or save_svg_path is not None:
        if hasattr(rdchem.Mol, "_repr_svg_"):
            IPythonConsole.UninstallIPythonRenderer()
    else:
        if not hasattr(rdchem.Mol, "_repr_svg_"):
            IPythonConsole.InstallIPythonRenderer()
        IPythonConsole.ipython_useSVG = True

    if ax is not None:
        IPythonConsole.ipython_useSVG = False
        kwargs["useSVG"] = False
    if save_svg_path is not None:
        kwargs["useSVG"] = True

    img = Draw.MolsToGridImage(
        mols=rdmols,
        molsPerRow=molsPerRow,  # Number of molecules per row
        subImgSize=subImgSize,  # Size of each individual image
        legends=legends,
        **kwargs,
    )

    if save_svg_path is not None:
        if not isinstance(img, str):
            raise TypeError(
                f"{type(img)=} but expected str, most likely it is due to previously using ipy_useSVG=True in a notebook"
            )
        # SVG is XML whose default encoding is UTF-8; be explicit so non-ASCII legends work on every platform.
        with open(save_svg_path, "w", encoding="utf-8") as f:
            f.write(img)
        return img

    if ax is not None:
        ax.imshow(np.array(img, dtype=np.int32))
        return ax

    return img
def get_correlation_xy(
    job1: Union["AMSJob", List["AMSJob"]],
    job2: Union["AMSJob", List["AMSJob"]],
    section: str,
    variable: str,
    alt_section: Optional[str] = None,
    alt_variable: Optional[str] = None,
    file: str = "ams",
    multiplier: float = 1.0,
) -> Tuple:
    """
    Collect paired data from two (lists of) jobs, e.g. for a correlation plot.

    For every pair of jobs the quantity is read with ``results.readrkf(section, variable, file=file)``; if that
    raises ``KeyError`` it is read with ``results.get_history_property(variable, history_section=section)``
    instead. Each result is flattened and multiplied by ``multiplier``, and the values of all pairs are
    concatenated.

    job1: AMSJob or list/tuple of AMSJob
        Job(s) providing the x data
    job2: AMSJob or list/tuple of AMSJob
        Job(s) providing the y data; must contain as many jobs as ``job1``
    section, variable: str
        What to read for ``job1``
    alt_section, alt_variable: str, optional
        What to read for ``job2``; default to ``section`` and ``variable``
    file: str
        File passed to ``readrkf``, defaults to "ams"
    multiplier: float
        Factor applied to all values, defaults to 1.0

    Returns: tuple of two 1D numpy arrays ``(x_data, y_data)``

    Raises: ValueError if ``job1`` and ``job2`` contain a different number of jobs (previously the
        surplus jobs were silently ignored).
    """

    def tolist(x):
        if isinstance(x, (list, tuple)):
            return list(x)
        return [x]

    def read(job, sec, var):
        try:
            values = job.results.readrkf(sec, var, file=file)
        except KeyError:
            values = job.results.get_history_property(var, history_section=sec)
        return np.ravel(values) * multiplier

    jobs1 = tolist(job1)
    jobs2 = tolist(job2)
    if len(jobs1) != len(jobs2):
        raise ValueError(
            f"job1 and job2 must contain the same number of jobs, got {len(jobs1)} and {len(jobs2)}"
        )

    alt_section = alt_section or section
    alt_variable = alt_variable or variable

    chunks1 = []
    chunks2 = []
    for j1, j2 in zip(jobs1, jobs2):
        chunks1.append(read(j1, section, variable))
        chunks2.append(read(j2, alt_section, alt_variable))

    if not chunks1:
        return np.array([]), np.array([])
    return np.concatenate(chunks1), np.concatenate(chunks2)
@requires_optional_package("matplotlib")
def plot_correlation(
    job1: Union[AMSJob, List[AMSJob]],
    job2: Union[AMSJob, List[AMSJob]],
    section: str,
    variable: str,
    alt_section: Optional[str] = None,
    alt_variable: Optional[str] = None,
    file: str = "ams",
    multiplier: float = 1.0,
    unit: Optional[str] = None,
    save_txt: Optional[str] = None,
    ax=None,
    show_xy: bool = True,
    show_linear_fit: bool = True,
    show_mad: bool = True,
    show_rmsd: bool = True,
    xlabel: Optional[str] = None,
    ylabel: Optional[str] = None,
):
    """
    Plot a correlation plot from AMS .rkf files

    job1: AMSJob or List[AMSJob]
        Job(s) plotted on x-axis

    job2: AMSJob or List[AMSJob]
        Job(s) plotted on y-axis

    section: str
        Section to read on .rkf files

    variable: str
        Variable to read

    alt_section: str, optional
        Section to read on .rkf files for job2. If not specified it will be the same as ``section``

    alt_variable: str, optional
        Variable to read for job2. If not specified it will be the same as ``variable``.

    file: str, optional
        "ams" or "engine", defaults to "ams"

    multiplier: float, optional
        Numbers will be multiplied by this number, defaults to 1.0

    unit: str, optional
        Unit that will be shown in the plot, defaults to None

    save_txt: str, optional
        If not None, save the xy data to this text file, defaults to None

    ax: matplotlib axis, optional
        The axis to plot on. If None, one will be created. Defaults to None

    show_xy: bool, optional
        Whether to show the y=x line, defaults to True

    show_linear_fit: bool, optional
        Whether to perform and show a linear fit (requires scipy), defaults to True

    show_mad: bool, optional
        Whether to show the mean absolute deviation, defaults to True

    show_rmsd: bool, optional
        Whether to show the root-mean-square deviation, defaults to True

    xlabel: str, optional
        The x-label. If not given it will be a list of job names, defaults to None

    ylabel: str, optional
        The y-label. If not given it will be a list of job names, defaults to None

    Returns: A matplotlib axis
    """
    import matplotlib.pyplot as plt

    def tolist(x):
        if isinstance(x, (list, tuple)):
            return list(x)
        return [x]

    job1 = tolist(job1)
    job2 = tolist(job2)
    alt_section = alt_section or section
    alt_variable = alt_variable or variable

    data1, data2 = get_correlation_xy(job1, job2, section, variable, alt_section, alt_variable, file, multiplier)

    def add_unit(s: str):
        if unit is not None:
            return f"{s} ({unit})"
        return s

    if ax is None:
        _, ax = plt.subplots()

    complete_data = np.stack((data1, data2), axis=1)
    min_max = np.array([np.min(complete_data), np.max(complete_data)])

    # Keep explicit handles: ax.legend(labels) alone would label the first N artists on the axis,
    # which are not ours when a pre-populated ``ax`` is passed in.
    handles = []
    legend = []
    title = [f"{section}%{variable}"]

    if show_xy:
        (line,) = ax.plot(min_max, min_max, "-")
        handles.append(line)
        legend.append("y=x")

    stats_title = ""
    if show_mad:
        mad = np.mean(np.abs(data2 - data1))
        stats_title += add_unit(f" MAD: {mad:.5f}")

    if show_rmsd:
        rmsd = np.sqrt(np.mean((data2 - data1) ** 2))
        stats_title += add_unit(f" RMSD: {rmsd:.5f}")

    linear_fit_title = None
    if show_linear_fit:
        try:
            from scipy.stats import linregress
        except ImportError:
            raise MissingOptionalPackageError("scipy")

        result = linregress(data1, data2)
        min_max_linear_fit = result.slope * min_max + result.intercept
        r2 = result.rvalue**2
        (line,) = ax.plot(min_max, min_max_linear_fit, "-")
        handles.append(line)
        legend.append("Fit")
        stats_title += f" R^2: {r2:.3f}"
        linear_fit_title = f"Linear fit slope={result.slope:.3f} intercept={result.intercept:.3f}"

    if stats_title:
        title.append(stats_title)
    if linear_fit_title:
        title.append(linear_fit_title)

    (line,) = ax.plot(data1, data2, ".")
    handles.append(line)
    legend.append("data")

    if xlabel is None:
        xlabel = ", ".join(x.name for x in job1)
        if len(xlabel) > 40:
            xlabel = xlabel[:35] + "..."
    if ylabel is None:
        ylabel = ", ".join(x.name for x in job2)
        if len(ylabel) > 40:
            ylabel = ylabel[:35] + "..."

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title("\n".join(title))
    ax.legend(handles, legend)
    ax.set_box_aspect(1)
    ax.set_xlim(*min_max)
    ax.set_ylim(*min_max)

    if save_txt is not None:
        np.savetxt(save_txt, complete_data, header=f"{xlabel} {ylabel}")

    return ax
@requires_optional_package("matplotlib")
def plot_msd(job, start_time_fit_fs=None, ax=None):
    """
    Plot the mean square displacement of an MSD job together with its linear fit.

    job: AMSMSDJob
        The job for which to plot the results

    start_time_fit_fs: float
        The start time (in fs) for which to perform the linear fit

    ax: matplotlib axis
        The axis. If None, one will be created

    Returns: matplotlib axis
    """
    import matplotlib.pyplot as plt

    if ax is None:
        ax = plt.subplots()[1]

    results = job.results
    time, msd = results.get_msd()
    fit, fit_x, fit_y = results.get_linear_fit(start_time_fit_fs=start_time_fit_fs)
    # Diffusion coefficient in m^2/s; it could equivalently be derived as fit.slope / 6 (ang^2/fs).
    diffusion = results.get_diffusion_coefficient(start_time_fit_fs=start_time_fit_fs)

    ax.plot(time, msd, label="MSD")
    ax.plot(fit_x, fit_y, label=f"Linear fit slope={fit.slope:.5f} ang^2/fs")
    ax.legend()
    ax.set_xlabel("Correlation time (fs)")
    ax.set_ylabel("Mean square displacement (ang^2)")
    ax.set_title(f"MSD: Diffusion coefficient = {diffusion:.2e} m^2/s")

    return ax
@requires_optional_package("matplotlib")
def plot_work_function(
    coordinate: np.ndarray,
    planarAverage: np.ndarray,
    macroscopicAverage: np.ndarray,
    Efermi: float,
    Vbulk: float,
    Vvacuum: Tuple[float, float],
    WF: Tuple[float, float],
    ax=None,
):
    """
    Plots an Electrostatic Potential Profile from AMS-QE with matplotlib.

    To control the appearance of the plot you need to call ``plt.ylim(bottom, top)``, ``plt.title(title)``, etc.
    manually outside this function.

    ``coordinate``: 1D array of float.
        Returned by AMSResults.get_work_function_results().

    ``planarAverage``: 1D array of float.
        Returned by AMSResults.get_work_function_results().

    ``macroscopicAverage``: 1D array of float.
        Returned by AMSResults.get_work_function_results(). Should have the same unit as ``planarAverage``.

    ``Efermi``: float.
        Returned by AMSResults.get_work_function_results(). Should have the same unit as ``planarAverage``.

    ``Vbulk``: float.
        Returned by AMSResults.get_work_function_results(). Should have the same unit as ``planarAverage``.

    ``Vvacuum``: Tuple[float,float].
        Returned by AMSResults.get_work_function_results(). Should have the same unit as ``planarAverage``.

    ``WF``: Tuple[float,float].
        Returned by AMSResults.get_work_function_results(). Should have the same unit as ``planarAverage``.

    Additional parameters:

    ``ax``: matplotlib axis
        The axis. If None, one will be created

    Note: the text offsets (0.1) and arrow sizes (head length 0.4, head width 0.3) are fixed values in data units.

    Returns: matplotlib axis
    """
    import matplotlib.pyplot as plt

    if ax is None:
        _, ax = plt.subplots()

    ax.set_xlabel("Length", fontsize=13)
    ax.set_ylabel("Energy", fontsize=13)

    x0 = min(coordinate)
    y0 = min(planarAverage)
    x1 = max(coordinate)
    width = x1 - x0
    head_length = 0.4

    def draw_work_function(x_arrow, x_text, v_vacuum, wf, **text_kwargs):
        """Draw an arrow from the Fermi level to ``v_vacuum`` and label it with the work function ``wf``."""
        # The head is drawn beyond dy, so subtract its length for the tip to end exactly at v_vacuum.
        ax.arrow(
            x_arrow,
            Efermi,
            0.0,
            v_vacuum - Efermi - head_length,
            head_width=0.3,
            head_length=head_length,
            fc="black",
            ec="black",
        )
        ax.text(
            x_text,
            (v_vacuum + Efermi) / 2,
            "WF=" + "%.1f" % wf + " eV",
            fontsize=11,
            color="black",
            **text_kwargs,
        )

    ax.plot(coordinate, planarAverage, color="red", linestyle="-.", lw=2, zorder=1)
    ax.plot(coordinate, macroscopicAverage, color="blue", linestyle="-", lw=2, zorder=2)
    ax.text(x0 + 0.8 * width, y0, "Planar\nAverage", fontsize=11, color="red")
    ax.text(x0 + 0.0 * width, y0, "Macroscopic\nAverage", fontsize=11, color="blue")

    ax.axhline(y=Efermi, color="black", linestyle="dashed", linewidth=1)
    ax.text(x0 + 0.05 * width, Efermi + 0.1, "E. Fermi", fontsize=11, color="black")

    ax.axhline(y=Vbulk, color="black", linestyle="dashed", linewidth=1)
    ax.text(x0 + 0.05 * width, Vbulk + 0.1, "Pot. bulk", fontsize=11, color="black")

    if abs(Vvacuum[0] - Vvacuum[1]) < 1e-3:
        # Symmetric material: a single vacuum level and a single work function.
        ax.axhline(y=Vvacuum[0], color="black", linestyle="dashed", linewidth=1)
        ax.text(x0, Vvacuum[0] + 0.1, "Pot. vacuum", fontsize=11, color="black")
        draw_work_function(x0 + 1.0 * width, x0 + 0.98 * width, Vvacuum[1], WF[1], horizontalalignment="right")
    else:
        # Asymmetric material: separate vacuum levels / work functions on the left and right side.
        ax.plot([x0, x0 + 0.3 * width], [Vvacuum[0], Vvacuum[0]], color="black", linestyle="dashed", linewidth=1)
        ax.text(x0, Vvacuum[0] + 0.1, "Pot. vacuum", fontsize=11, color="black")
        ax.plot([x1, x1 - 0.3 * width], [Vvacuum[1], Vvacuum[1]], color="black", linestyle="dashed", linewidth=1)
        ax.text(x1 - 0.3 * width, Vvacuum[1] + 0.1, "Pot. vacuum", fontsize=11, color="black")
        draw_work_function(x0 + 0.0 * width, x0 + 0.02 * width, Vvacuum[0], WF[0])
        draw_work_function(x0 + 1.0 * width, x0 + 0.98 * width, Vvacuum[1], WF[1], horizontalalignment="right")

    return ax