# Source code for scm.glompo.opt_selectors.chain

from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

from .baseselector import BaseSelector
from .spawncontrol import BaseController
from ..core.optimizerlogger import BaseLogger
from ..optimizers.baseoptimizer import BaseOptimizer
from ...plams.core.settings import Settings

__all__ = ("ChainSelector",)


class ChainSelector(BaseSelector):
    """Selects the type of optimizer to start based on the number of function evaluations already used.

    Designed to start different types of optimizers at different stages of the optimization. Selects sequentially from
    the list of available optimizers based on the number of function evaluations used.

    :Parameters:

    ``*avail_opts``
        See :class:`.BaseSelector`.

    fcall_thresholds
        A list of length ``n-1``, where ``n`` is the length of ``avail_opts``. Each element indicates the function
        evaluation point at which the selector switches to the next type of optimizer in ``avail_opts``. Thresholds
        are consumed sequentially, so they should be given in ascending order.

    allow_spawn
        See :class:`.BaseSelector`.

    :Examples:

    >>> ChainSelector(OptimizerA, OptimizerB, fcall_thresholds=[1000])

    In this case ``OptimizerA`` instances will be started during the first 1000 function evaluations and
    ``OptimizerB`` instances will be started thereafter.
    """

    def __init__(
        self,
        *avail_opts: Union[BaseOptimizer, Type[BaseOptimizer], Tuple[Type[BaseOptimizer], Optional[Dict[str, Any]]]],
        fcall_thresholds: List[float],
        allow_spawn: Optional[List[BaseController]] = None,
    ):
        super().__init__(*avail_opts, allow_spawn=allow_spawn)
        # Keep a private copy so that later mutation of the caller's list cannot silently alter the switching schedule.
        self.fcall_thresholds = list(fcall_thresholds)
        n = len(avail_opts)
        assert len(self.fcall_thresholds) == n - 1, "Must be one threshold less than available optimizers"
        # Index into ``self.avail_opts`` of the optimizer type currently being started.
        self.toggle = 0

    def __amssettings__(self, s: Settings) -> Settings:
        """Return AMS input settings with this selector's configuration added.

        The spawner settings are applied first via ``_spawnersettings``. Then the selector type is set to ``"Chain"``
        and the thresholds are written as a space-separated string.
        """
        s = self._spawnersettings(s)
        s.input.ams.OptimizerSelector.Type = "Chain"
        s.input.ams.OptimizerSelector.Chain.Thresholds = " ".join(str(t) for t in self.fcall_thresholds)
        return s

    def select_optimizer(
        self, manager: "GloMPOManager", slots_available: int
    ) -> Union[Tuple[Type[BaseOptimizer], Dict[str, Any], Dict[str, Any]], None, bool]:
        """Choose the optimizer to start next, based on ``manager.f_counter``.

        :Parameters:

        manager
            The running manager. ``manager.f_counter`` is compared against the thresholds, and the manager is passed
            to every spawn controller in ``allow_spawn``.

        slots_available
            Number of free worker slots.

        :Returns:

        ``False`` if any spawn controller disallows spawning. ``None`` if the selected optimizer needs more workers
        than ``slots_available``. Otherwise the selected entry of ``avail_opts``.
        """
        if not all(spawner(manager) for spawner in self.allow_spawn):
            return False

        # Advance past *every* threshold already reached, not just one. The evaluation counter can jump across
        # several thresholds between two calls; a single-step advance would then return an out-of-date optimizer.
        while self.toggle < len(self.fcall_thresholds) and manager.f_counter >= self.fcall_thresholds[self.toggle]:
            self.toggle += 1

        selected = self.avail_opts[self.toggle]
        # NOTE(review): assumes BaseSelector normalises each entry to (optimizer class, init kwargs, call kwargs),
        # with the init kwargs holding the required worker count under "_workers". Confirm against BaseSelector.
        if selected[1]["_workers"] > slots_available:
            return None

        return selected