import re
from copy import copy
from datetime import datetime
from pathlib import Path
from typing import Union
from ...plams.core.settings import Settings
# Public API of this module: only the checkpointing settings container is exported.
__all__ = ("CheckpointingControl",)
class CheckpointingControl:
    """Class to set up and control the checkpointing behaviour of the :class:`.GloMPOManager`.

    This class has limited functionality and is mainly a container for various settings. The initialisation arguments
    match the class attributes of the same name.

    :Attributes:

        checkpoint_at_conv : bool
            If ``True`` a checkpoint is built just before the manager exits.

        checkpoint_at_init : bool
            If ``True`` a checkpoint is built at the very start of the optimization. This can make starting duplicate
            jobs easier.

        checkpoint_iter_interval : float
            Number of function evaluations between checkpoints being saved to disk during an optimization. Function
            call based checkpointing is not performed if this parameter is not provided.

        checkpoint_time_interval : float
            Number of seconds between checkpoints being saved to disk during an optimization. Time based
            checkpointing is not performed if this parameter is not provided.

        checkpointing_dir : Union[pathlib.Path, str]
            Directory in which checkpoints are saved. Defaults to ``'checkpoints'``.

            .. important::

               This path is always converted to an absolute path. If a relative path is provided it will be relative
               to the current working directory when this object is created. There is no relation to
               :class:`GloMPOManager.working_dir <.GloMPOManager>`.

        count : int
            Counter for checkpoint naming patterns which rely on incrementing filenames. Every call to
            :meth:`get_name` restarts the count at one more than the largest index found among the entries of
            ``checkpointing_dir`` matching ``naming_format``, or at zero if there is no match or the directory does
            not exist. Formatted to at least 3 digits.

        force_task_save : bool
            Some tasks may pickle successfully but fail to load properly. If this is an issue, setting this parameter
            to ``True`` will cause the manager to bypass the pickle task step and immediately attempt the
            :meth:`~.BaseFunction.checkpoint_save` method.

        keep_past : int
            The number of checkpoints retained when a new checkpoint is made. Any older ones are deleted. Default is
            -1 which performs no deletion. ``keep_past = 0`` retains no previous results; only the newly constructed
            checkpoint will exist.

            .. note::

               #. GloMPO will only count the directories in ``checkpointing_dir`` that match the supplied
                  ``naming_format``.
               #. Existing checkpoints will only be deleted if the new checkpoint is successfully constructed.

        naming_format : str
            Convention used to name the checkpoints. Special keys that can be used:

            =================== ===========================================================
            Naming Format Key   Checkpoint Name Result
            =================== ===========================================================
            ``'%(date)'``       Current calendar date in YYYYMMDD format
            ``'%(year)'``       Year formatted to YYYY
            ``'%(yr)'``         Year formatted to YY
            ``'%(month)'``      Numerical month formatted to MM
            ``'%(day)'``        Calendar day of the month formatted to DD
            ``'%(time)'``       Current calendar time formatted to HHMMSS (24-hour style)
            ``'%(hour)'``       Hour formatted to HH (24-hour style)
            ``'%(min)'``        Minutes formatted to MM
            ``'%(sec)'``        Seconds formatted to SS
            ``'%(count)'``      Index count of the checkpoints constructed (at least 3 digits)
            =================== ===========================================================

        raise_checkpoint_fail : bool
            If ``True`` a failed checkpoint will cause the manager to end the optimization in error. Note that GloMPO
            will always write out some data when it terminates. This can be a way of preserving data if the
            checkpoint fails. If ``False`` an error in constructing a checkpoint will simply raise a warning and pass.
    """

    # Splits a naming format into literal text and the special keys (captured so re.split keeps them).
    _PLACEHOLDER_RE = re.compile(r"(%\((?:date|year|yr|month|day|time|hour|min|sec|count)\))")

    # Time-based keys -> strftime directive used by get_name.
    _STRFTIME_CODES = {
        "%(date)": "%Y%m%d",
        "%(year)": "%Y",
        "%(yr)": "%y",
        "%(month)": "%m",
        "%(day)": "%d",
        "%(time)": "%H%M%S",
        "%(hour)": "%H",
        "%(min)": "%M",
        "%(sec)": "%S",
    }

    # Time-based keys -> exact number of digits the directive above produces.
    _STRFTIME_DIGITS = {
        "%(date)": 8,
        "%(year)": 4,
        "%(yr)": 2,
        "%(month)": 2,
        "%(day)": 2,
        "%(time)": 6,
        "%(hour)": 2,
        "%(min)": 2,
        "%(sec)": 2,
    }

    def __init__(
        self,
        checkpoint_time_interval: float = float("inf"),
        checkpoint_iter_interval: float = float("inf"),
        checkpoint_at_init: bool = False,
        checkpoint_at_conv: bool = False,
        raise_checkpoint_fail: bool = False,
        force_task_save: bool = False,
        keep_past: int = -1,
        naming_format: str = "glompo_checkpoint_%(date)_%(time)",
        checkpointing_dir: Union[Path, str] = "checkpoints",
    ):
        self.checkpoint_time_interval = checkpoint_time_interval
        self.checkpoint_iter_interval = checkpoint_iter_interval
        self.checkpoint_at_init = checkpoint_at_init
        self.checkpoint_at_conv = checkpoint_at_conv
        # Resolved immediately so later changes of working directory do not move the checkpoints.
        self.checkpointing_dir = Path(checkpointing_dir).resolve()
        self.raise_checkpoint_fail = bool(raise_checkpoint_fail)
        self.force_task_save = bool(force_task_save)
        self.keep_past = keep_past
        self.naming_format = naming_format
        self.count = None
        # Pattern string recognising names produced by get_name. It is applied with re.match (anchored at the start
        # only), so trailing text after a generated name, e.g. an extension, is tolerated.
        self._naming_format_re = self._build_naming_regex(naming_format)

    @classmethod
    def _build_naming_regex(cls, naming_format: str) -> str:
        """Translate ``naming_format`` into a regular expression string matching the names it generates.

        Literal text is escaped with :func:`re.escape`. Time keys become fixed-width digit runs. The first
        ``%(count)`` becomes the named group ``index``; any later ones must repeat the same digits.
        """
        pieces = []
        has_index = False
        # With a capturing split pattern, even positions are literal text and odd positions are special keys.
        for position, part in enumerate(cls._PLACEHOLDER_RE.split(naming_format)):
            if position % 2 == 0:
                pieces.append(re.escape(part))
            elif part == "%(count)":
                # 3 digits minimum (f"{count:03}"), but more once the count reaches 1000.
                pieces.append("(?P=index)" if has_index else "(?P<index>[0-9]{3,})")
                has_index = True
            else:
                pieces.append(f"[0-9]{{{cls._STRFTIME_DIGITS[part]}}}")
        return "".join(pieces)

    def __amssettings__(self, s: "Settings") -> "Settings":
        """Write the checkpointing options into ``s.input.ams.CheckpointControl`` and return ``s``.

        Infinite (disabled) intervals are written as ``-1``.
        """
        iter_freq = self.checkpoint_iter_interval if self.checkpoint_iter_interval < float("inf") else -1
        time_freq = self.checkpoint_time_interval if self.checkpoint_time_interval < float("inf") else -1
        s.input.ams.CheckpointControl.AtEnd = self.checkpoint_at_conv
        s.input.ams.CheckpointControl.AtInitialisation = self.checkpoint_at_init
        s.input.ams.CheckpointControl.EveryFunctionCalls = iter_freq
        s.input.ams.CheckpointControl.EverySeconds = time_freq
        s.input.ams.CheckpointControl.RaiseFail = self.raise_checkpoint_fail
        s.input.ams.CheckpointControl.KeepPast = self.keep_past
        s.input.ams.CheckpointControl.NamingFormat = self.naming_format
        s.input.ams.CheckpointControl.CheckpointingDirectory = self.checkpointing_dir
        return s

    @property
    def any_true(self) -> bool:
        """Returns ``True`` if at least one of the four checkpointing conditions is set to produce checkpoints."""
        return (
            self.checkpoint_at_init
            or self.checkpoint_at_conv
            or self.checkpoint_iter_interval < float("inf")
            or self.checkpoint_time_interval < float("inf")
        )

    def get_name(self) -> str:
        """Returns a new name for a checkpoint matching the naming format.

        Time keys are filled from the current local time. If ``checkpointing_dir`` exists, ``count`` is reset to one
        more than the largest ``%(count)`` index found among its entries (or zero if none match). Otherwise it is
        reset to zero. The count is then substituted and incremented.
        """
        now = datetime.now()
        name = self.naming_format
        for key, directive in self._STRFTIME_CODES.items():
            name = name.replace(key, now.strftime(directive))

        if self.checkpointing_dir.exists():
            pattern = re.compile(self._naming_format_re)
            max_index = -1
            # Without a %(count) key there is no index to recover, so the directory scan can be skipped.
            if "index" in pattern.groupindex:
                for entry in self.checkpointing_dir.iterdir():
                    match = pattern.match(entry.name)
                    if match:
                        max_index = max(max_index, int(match.group("index")))
            self.count = max_index + 1
        else:
            self.count = 0

        name = name.replace("%(count)", f"{self.count:03}")
        self.count += 1
        return name