import numpy as np
from .basestopper import BaseStopper
from ..core.optimizerlogger import BaseLogger
from ...plams.core.settings import Settings
__all__ = ("ValidationWorsening",)
class ValidationWorsening(BaseStopper):
    """Considers loss function values of the validation set.

    Triggers if the value of the last evaluation of the validation set cost function is worse than the average of the
    last several calls.

    :Parameters:

    calls
        Number of function evaluations between comparison points. Increase this number to capture long-term increases
        in the loss value change rather than sudden spurious fluctuations.
    tol
        Tolerance fraction applied to the loss values made ``calls`` prior. The loss value (f) is modified by:
        f * (1 + ``tol``). This can also be used to ignore the effects of mild fluctuations.

    :Returns:

    bool
        Returns ``True`` if the loss value of the last validation set evaluation is worse than the average of the
        ``calls`` evaluations before.
    """

    def __init__(self, calls: int = 1, tol: float = 0):
        super().__init__()
        self.calls = calls
        # Minimum number of evaluated points needed before the stopper can fire:
        # the `calls` previous evaluations plus the latest one.
        self.ix = calls + 1
        self.tol = tol

    def __call__(self, log: BaseLogger, best_opt_id: int, tested_opt_id: int) -> bool:
        loss = log.get_history(tested_opt_id, "validation_set")
        loss = loss[~np.isnan(loss)]  # Take only the evaluated values
        n_calls = loss.size

        if n_calls < self.ix:
            # If there are insufficient iterations the stopper will return False
            self.last_result = False
            return self.last_result

        # NOTE(review): the averaging window ``loss[-self.ix:]`` includes the latest value itself,
        # so the comparison is against the mean of the last ``calls`` + 1 evaluations rather than
        # only the ``calls`` evaluations *before* the latest, as the docstring states — confirm
        # which is intended before tightening the slice to ``loss[-self.ix:-1]``.
        self.last_result = loss[-1] > loss[-self.ix :].mean() * (1 + self.tol)
        return self.last_result

    def __amssettings__(self, s: Settings) -> Settings:
        # Serialize this stopper's configuration into the AMS input settings tree.
        s.input.ams.Stopper.Type = "ValidationWorsening"
        s.input.ams.Stopper.ValidationWorsening.NumberOfFunctionCalls = self.calls
        s.input.ams.Stopper.ValidationWorsening.Tolerance = self.tol
        return s