# Source code for scm.glompo.stoppers.validationworsening

import numpy as np

from .basestopper import BaseStopper
from ..core.optimizerlogger import BaseLogger
from ...plams.core.settings import Settings

# Public API of this module: only the stopper class is exported by ``import *``.
__all__ = (
    "ValidationWorsening",
)


class ValidationWorsening(BaseStopper):
    """Considers loss function values of the validation set.

    Triggers if the value of the last evaluation of the validation set cost function is worse than the average of
    the last several calls. Only iterations at which the validation set was actually evaluated (non-NaN entries of
    the ``'validation_set'`` history) are considered.

    :Parameters:

    calls
        Number of validation set evaluations, prior to the latest one, included in the reference average. Increase
        this number to capture long-term increases in the loss value rather than sudden spurious fluctuations.
    tol
        Tolerance fraction applied to the reference average. The average (f) is modified by: f * (1 + ``tol``). This
        can also be used to ignore the effects of mild fluctuations.

    :Returns:

    bool
        ``True`` if the latest validation loss is strictly greater than ``(1 + tol)`` times the mean of the last
        ``calls + 1`` evaluated values (the latest value itself is part of this window). With ``tol = 0`` (and
        ``calls >= 1``) this is equivalent to comparing against the mean of the ``calls`` evaluations before it.
        ``False`` while fewer than ``calls + 1`` evaluated values are available.

    .. note::

        The tolerance is multiplicative, so a positive ``tol`` loosens the threshold only when the average is
        positive; for a negative average it tightens it.
    """

    def __init__(self, calls: int = 1, tol: float = 0):
        super().__init__()
        self.calls = calls
        # Size of the comparison window: the latest value plus the ``calls`` values evaluated before it.
        self.ix = calls + 1
        self.tol = tol

    def __call__(self, log: BaseLogger, best_opt_id: int, tested_opt_id: int) -> bool:
        """Evaluate the stopping condition for optimizer ``tested_opt_id``.

        ``best_opt_id`` is part of the common stopper interface and is not used by this criterion.
        """
        # Coerce to a float array so that list histories (and ``None`` placeholders, which become NaN) are handled
        # the same way as ndarray histories; a no-op for float ndarrays.
        loss = np.asarray(log.get_history(tested_opt_id, "validation_set"), dtype=float)
        loss = loss[~np.isnan(loss)]  # Take only the evaluated values

        if loss.size < self.ix:
            # If there are insufficient iterations the stopper will return False
            self.last_result = False
            return self.last_result

        window = loss[-self.ix :]  # latest value plus the ``calls`` before it
        # ``bool`` converts the ``np.bool_`` comparison result so the return honours the annotation and serialises.
        self.last_result = bool(loss[-1] > window.mean() * (1 + self.tol))
        return self.last_result

    def __amssettings__(self, s: Settings) -> Settings:
        """Write this stopper's type, ``calls`` and ``tol`` into the AMS input settings ``s`` and return ``s``."""
        s.input.ams.Stopper.Type = "ValidationWorsening"
        s.input.ams.Stopper.ValidationWorsening.NumberOfFunctionCalls = self.calls
        s.input.ams.Stopper.ValidationWorsening.Tolerance = self.tol
        return s