from dataclasses import dataclass
from decimal import Decimal
from typing import Any, Dict, List, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# A ternary composition (x1, x2, x3).
Point3 = Tuple[float, float, float]

# Input columns written by the row builders below, in DataFrame column order.
_INPUT_COLUMNS: Tuple[str, ...] = (
    "Method",
    "Temperature",
    "ncomp",
    "sys_idx",
    "mix_idx",
    "s1",
    "s2",
    "s3",
    "x1",
    "x2",
    "x3",
)
# Result columns: initialised to pd.NA here; presumably filled in by a
# stability / LLE calculation elsewhere (not visible in this module).
_RESULT_COLUMNS: Tuple[str, ...] = (
    "unstable",
    "converged",
    "llle_detected",
    "xI1",
    "xI2",
    "xI3",
    "xII1",
    "xII2",
    "xII3",
)
_ROW_COLUMNS: Tuple[str, ...] = _INPUT_COLUMNS + _RESULT_COLUMNS


@dataclass
class Compound:
    """One component of a mixture.

    Attributes:
        name: Display name; copied into the ``s1``/``s2``/``s3`` columns by the
            DataFrame builders (and from there used as plot axis labels).
        coskf: Not used in this module; presumably the path to the compound's
            ``.coskf`` file -- confirm against callers.
    """

    name: str
    coskf: str


def build_simplex_grid_points(
    n_components: int,
    min_frac: float = 0.1,
    step: float = 0.1,
) -> List[Tuple[float, ...]]:
    """Enumerate compositions on a regular grid inside the simplex.

    Each returned tuple has ``n_components`` fractions summing to 1, with every
    component >= ``min_frac``. The first ``n_components - 1`` components take
    the values ``min_frac + k * step``; the last component is the remainder and
    is kept only when it is >= ``min_frac`` and a multiple of ``step``. Points
    are produced in lexicographic order of their leading components.

    All arithmetic is done on integers (units of ``10**-precision``) so the
    grid is free of floating-point drift.

    Args:
        n_components: Number of components (>= 2).
        min_frac: Lower bound for every component (> 0).
        step: Grid spacing (> 0).

    Returns:
        A list of tuples of floats.

    Raises:
        ValueError: If ``n_components < 2``, ``min_frac``/``step`` are not
            positive, ``n_components * min_frac > 1``, or a value is not exactly
            representable at the derived decimal precision.
    """
    # Decimal(str(x)) recovers the short literal the caller wrote (0.1, not
    # 0.1000000000000000055...), which makes exact integer scaling possible.
    min_dec = Decimal(str(min_frac))
    step_dec = Decimal(str(step))

    if n_components < 2:
        raise ValueError("n_components must be >= 2.")
    if min_dec <= 0 or step_dec <= 0:
        raise ValueError("min_frac and step must be > 0.")
    if n_components * min_dec > 1:
        raise ValueError("n_components * min_frac must be <= 1.")

    # Scale so that both min_frac and step become integers.
    precision = max(-min_dec.as_tuple().exponent, -step_dec.as_tuple().exponent)
    scale = 10**precision
    total_units = scale
    min_units = int(min_dec * scale)
    step_units = int(step_dec * scale)

    if Decimal(min_units) / scale != min_dec:
        raise ValueError("min_frac is incompatible with the chosen decimal precision.")
    if Decimal(step_units) / scale != step_dec:
        raise ValueError("step is incompatible with the chosen decimal precision.")

    def recurse(remaining_components: int, remaining_units: int) -> List[Tuple[int, ...]]:
        """Return every integer split of ``remaining_units`` over the remaining components."""
        if remaining_components == 1:
            # The last component takes whatever is left.
            # NOTE(review): it is accepted only if it is a multiple of `step`,
            # whereas the other components lie on `min_frac + k * step`; when
            # min_frac is not a multiple of step the grid is therefore not
            # symmetric under permutation of components.
            if remaining_units >= min_units and remaining_units % step_units == 0:
                return [(remaining_units,)]
            return []

        points: List[Tuple[int, ...]] = []
        # Leave at least min_units for each of the components still to place.
        max_units = remaining_units - min_units * (remaining_components - 1)
        for units in range(min_units, max_units + 1, step_units):
            for tail in recurse(remaining_components - 1, remaining_units - units):
                points.append((units,) + tail)
        return points

    integer_points = recurse(n_components, total_units)
    return [tuple(units / scale for units in point) for point in integer_points]


def _require_ternary(compounds: Sequence[Compound]) -> None:
    """Raise ValueError unless at least three compounds are supplied (extras are ignored)."""
    if len(compounds) < 3:
        raise ValueError("compounds must contain at least 3 entries for a ternary system.")


def _ternary_row(
    compounds: Sequence[Compound],
    mix_idx: int,
    x1: float,
    x2: float,
    x3: float,
    temperature: float,
    method: str,
) -> Dict[str, Any]:
    """Build one ternary input row whose result columns are all pd.NA."""
    row: Dict[str, Any] = {
        "Method": method,
        "Temperature": temperature,
        "ncomp": 3,
        "sys_idx": 0,
        "mix_idx": mix_idx,
        "s1": compounds[0].name,
        "s2": compounds[1].name,
        "s3": compounds[2].name,
        "x1": x1,
        "x2": x2,
        "x3": x3,
    }
    row.update(dict.fromkeys(_RESULT_COLUMNS, pd.NA))
    return row


def build_ternary_grid_df(
    compounds: List[Compound],
    min_frac: float = 0.1,
    step: float = 0.1,
    temperature: float = 298.15,
    method: str = "COSMORS",
) -> pd.DataFrame:
    """Return one row per ternary grid point from :func:`build_simplex_grid_points`.

    Args:
        compounds: At least three compounds; the first three become s1, s2, s3.
        min_frac: Minimum mole fraction of every component.
        step: Grid spacing.
        temperature: Copied into the ``Temperature`` column.
        method: Copied into the ``Method`` column.

    Returns:
        A DataFrame with the input columns filled and result columns set to
        pd.NA. It always has the full column set, even when the grid is empty.

    Raises:
        ValueError: For invalid grid parameters (see
            :func:`build_simplex_grid_points`) or fewer than three compounds.
    """
    points = build_simplex_grid_points(3, min_frac=min_frac, step=step)
    _require_ternary(compounds)
    rows = [
        _ternary_row(compounds, mix_idx, x1, x2, x3, temperature, method)
        for mix_idx, (x1, x2, x3) in enumerate(points)
    ]
    # Explicit columns keep the schema stable when `rows` is empty.
    return pd.DataFrame(rows, columns=list(_ROW_COLUMNS))


def build_ternary_lle_df(
    compounds: List[Compound],
    s1_ratio: float,
    s2_ratio: float,
    s3_min: float,
    s3_max: float,
    num_points: int,
    temperature: float = 298.15,
    method: str = "COSMORS",
) -> pd.DataFrame:
    """Return feeds along a line of fixed s1:s2 ratio with x3 swept linearly.

    ``x3`` takes ``num_points`` evenly spaced values in ``[s3_min, s3_max]``.
    The remainder ``1 - x3`` is split between components 1 and 2 in the ratio
    ``s1_ratio : s2_ratio``.

    Args:
        compounds: At least three compounds; the first three become s1, s2, s3.
        s1_ratio: Relative amount of component 1 (> 0).
        s2_ratio: Relative amount of component 2 (> 0).
        s3_min: Smallest x3 (>= 0).
        s3_max: Largest x3 (<= 1, >= s3_min).
        num_points: Number of feeds (>= 2).
        temperature: Copied into the ``Temperature`` column.
        method: Copied into the ``Method`` column.

    Returns:
        A DataFrame with the input columns filled and result columns set to pd.NA.

    Raises:
        ValueError: On non-positive or NaN ratios, ``num_points < 2``, an
            out-of-range or NaN x3 interval, or fewer than three compounds.
    """
    total = 1.0
    # Negated positive conditions so that NaN inputs are rejected as well.
    if not (s1_ratio > 0 and s2_ratio > 0):
        raise ValueError("s1_ratio and s2_ratio must be > 0.")
    if num_points < 2:
        raise ValueError("num_points must be >= 2.")
    if not (0 <= s3_min <= s3_max <= total):
        raise ValueError("Require 0 <= s3_min <= s3_max <= 1.")
    _require_ternary(compounds)

    ratio_total = s1_ratio + s2_ratio
    s1_share = s1_ratio / ratio_total
    s2_share = s2_ratio / ratio_total

    rows = []
    for mix_idx, x3 in enumerate(np.linspace(s3_min, s3_max, num_points)):
        remaining = total - x3
        x1 = remaining * s1_share
        x2 = remaining * s2_share
        # Sanity check with a small tolerance for floating-point round-off.
        if any(value < -1e-12 or value > total + 1e-12 for value in (x1, x2, x3)):
            raise ValueError("Generated composition is outside the [0, 1] interval.")
        rows.append(
            _ternary_row(
                compounds, mix_idx, float(x1), float(x2), float(x3), temperature, method
            )
        )
    return pd.DataFrame(rows, columns=list(_ROW_COLUMNS))


def _collect_phase_points(df: pd.DataFrame, prefix: str) -> List[Point3]:
    """Return ``(<prefix>1, <prefix>2, <prefix>3)`` for each row of *df*, in row order."""
    return list(zip(df[f"{prefix}1"], df[f"{prefix}2"], df[f"{prefix}3"]))


def _collect_feed_points(df: pd.DataFrame) -> List[Point3]:
    """Return the feed composition ``(x1, x2, x3)`` of each row of *df*."""
    return _collect_phase_points(df, "x")


def _collect_tie_lines(df: pd.DataFrame) -> List[Tuple[Point3, Point3]]:
    """Return ``(phase I point, phase II point)`` pairs for each row of *df*."""
    return list(zip(_collect_phase_points(df, "xI"), _collect_phase_points(df, "xII")))


def _flag_mask(series: pd.Series) -> np.ndarray:
    """Return a bool array: truthy cells -> True, missing cells (NA/NaN/None) -> False.

    Same result as ``series.fillna(False).astype(bool)``. It avoids the pandas
    FutureWarning about silently downcasting object columns (such as those
    initialised with pd.NA above) in ``fillna``.
    """
    return np.array(
        [False if pd.isna(value) else bool(value) for value in series], dtype=bool
    )


def configure_tax(tax, labels: List[str], title: str, fontsize: float = 10) -> None:
    """Apply the shared styling to a python-ternary axes object *tax*.

    Draws the boundary, 0.05-spaced gridlines and 0.1-spaced ticks. It labels
    the bottom, right and left axes with ``labels[0]``, ``labels[1]`` and
    ``labels[2]`` (suffixed "(1)", "(2)", "(3)"). It hides the underlying
    matplotlib frame and ticks, then sets *title*.

    Args:
        tax: A python-ternary ``TernaryAxesSubplot``.
        labels: At least three component names.
        title: Axes title.
        fontsize: Font size for axis labels and tick labels.
    """
    tax.boundary(linewidth=2.0)
    tax.gridlines(color="black", multiple=0.05, linewidth=0.5, alpha=0.4)

    tax.bottom_axis_label(f"{labels[0]}(1)", fontsize=fontsize, offset=0.1)
    tax.right_axis_label(f"{labels[1]}(2)", fontsize=fontsize, offset=0.1)
    tax.left_axis_label(f"{labels[2]}(3)", fontsize=fontsize, offset=0.1)

    tax.ticks(
        axis="lbr",
        multiple=0.1,
        linewidth=1,
        tick_formats="%.1f",
        offset=0.015,
        fontsize=fontsize,
    )
    # Remove the rectangular matplotlib decorations behind the triangle.
    tax.clear_matplotlib_ticks()
    raw_ax = tax.get_axes()
    for spine in raw_ax.spines.values():
        spine.set_visible(False)
    raw_ax.set_xticks([])
    raw_ax.set_yticks([])

    tax.set_title(title)


def plot_demo_figure(df_stability: pd.DataFrame, df_lle: pd.DataFrame):
    """Draw a two-panel ternary figure and return the matplotlib Figure.

    Left panel: feed compositions from *df_stability*, split by their
    ``unstable`` flag. Right panel: rows of *df_lle* that are ``converged``
    without ``llle_detected``, showing the feed, both phase compositions and
    the tie line joining them. Axis labels come from the first row of
    *df_stability*. The figure is neither shown nor closed; the caller owns it.

    Raises:
        ImportError: If python-ternary (``ternary``) is not installed.
        ValueError: If *df_stability* has no rows, or no *df_lle* row is
            converged without LLLE. Both are checked before a figure is created.
    """
    import ternary  # optional plotting dependency, only needed here

    # Validate first so an error does not leave an orphaned pyplot figure open.
    if len(df_stability) == 0:
        raise ValueError("df_stability has no rows; cannot determine compound labels.")
    lle_plot_df = df_lle.loc[
        _flag_mask(df_lle["converged"]) & ~_flag_mask(df_lle["llle_detected"])
    ]
    if lle_plot_df.empty:
        raise ValueError("No converged non-LLLE LLE rows are available for this ternary system.")

    first_row = df_stability.iloc[0]
    labels = [str(first_row["s1"]), str(first_row["s2"]), str(first_row["s3"])]

    fig, axes = plt.subplots(1, 2, figsize=(13, 5.5))

    # --- Stability panel ---------------------------------------------------
    _, tax_stab = ternary.figure(scale=1.0, ax=axes[0])
    configure_tax(tax_stab, labels, "STABILITY: stable vs unstable grid points")

    # NOTE(review): rows whose 'unstable' flag is missing (e.g. not yet
    # computed) are drawn as "Stable feed".
    unstable_mask = _flag_mask(df_stability["unstable"])
    for subset, colour, label in (
        (df_stability.loc[~unstable_mask], "#1b5e20", "Stable feed"),
        (df_stability.loc[unstable_mask], "#c04a00", "Unstable feed"),
    ):
        points = _collect_feed_points(subset)
        if points:
            tax_stab.scatter(
                points,
                s=42,
                marker="o",
                facecolors="none",
                edgecolors=colour,
                linewidths=1.2,
                label=label,
            )

    # --- LLE panel ----------------------------------------------------------
    _, tax_lle = ternary.figure(scale=1.0, ax=axes[1])
    configure_tax(tax_lle, labels, "LLE: phase boundary")

    tax_lle.scatter(
        _collect_feed_points(lle_plot_df),
        s=44,
        marker="^",
        facecolors="none",
        edgecolors="black",
        linewidths=1.0,
        label="Feed",
    )
    tax_lle.scatter(
        _collect_phase_points(lle_plot_df, "xI"),
        s=34,
        marker="o",
        facecolors="none",
        edgecolors="tab:blue",
        linewidths=1.2,
        label="Phase I boundary",
    )
    tax_lle.scatter(
        _collect_phase_points(lle_plot_df, "xII"),
        s=34,
        marker="o",
        facecolors="none",
        edgecolors="tab:orange",
        linewidths=1.2,
        label="Phase II boundary",
    )
    for index, (point_i, point_ii) in enumerate(_collect_tie_lines(lle_plot_df)):
        # Label only the first tie line so the legend has a single entry.
        tax_lle.line(
            point_i,
            point_ii,
            color="0.65",
            linewidth=0.9,
            alpha=0.85,
            label="Tie line" if index == 0 else None,
        )

    for tax in (tax_stab, tax_lle):
        handles, _ = tax.get_axes().get_legend_handles_labels()
        if handles:
            tax.legend(loc="upper right", frameon=False)
        # Private python-ternary API (hence the hasattr guard); presumably
        # positions the axis labels before tight_layout runs.
        if hasattr(tax, "_redraw_labels"):
            tax._redraw_labels()

    fig.tight_layout()
    return fig