Source code for scm.glompo.optimizers.nevergrad

import warnings
from collections import deque
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from multiprocessing.connection import Connection
from pathlib import Path
from threading import Event
from typing import Callable, Optional, Sequence, Set, Tuple, Union
from queue import Queue

import nevergrad as ng
import numpy as np

from .baseoptimizer import BaseOptimizer, MinimizeResult
from ...plams.core.settings import Settings

__all__ = ("Nevergrad",)


def _version_tuple(version: str) -> Tuple[int, ...]:
    """Convert a version string such as ``"0.13.0"`` or ``"0.4.3.post2"`` into a tuple of integers.

    Each dot-separated component contributes its leading run of digits (so ``"0rc1"`` gives ``0``). Parsing stops
    at the first component that does not start with a digit (e.g. ``"post2"``).

    Comparing integer tuples avoids the lexicographic pitfall of comparing string tuples, where ``"10" < "4"``
    would make version 0.10 appear older than 0.4.
    """
    import re

    numbers = []
    for part in version.split("."):
        match = re.match(r"\d+", part)
        if match is None:
            break
        numbers.append(int(match.group()))
    return tuple(numbers)


if _version_tuple(ng.__version__) >= (0, 4):
    warnings.warn(
        "A known bug in Nevergrad >= 0.4 will break the python logging system. "
        "Please downgrade to 0.3.2 to regain control of the logging system.",
        ImportWarning,
    )


class _ClippedInt(int):
    """Integer budget whose subtraction never goes below zero.

    GloMPO stops Nevergrad early by making it believe that its evaluation budget has been used up. That trick can
    push the remaining budget computed inside Nevergrad below zero. However, the Nevergrad loop only terminates when
    the remaining budget is exactly 0.

    Passing the ``budget`` keyword as an instance of this class makes ``budget - used`` evaluate to the larger of
    the difference and zero.
    """

    def __new__(cls, budget):
        instance = super().__new__(cls, budget)
        # Keep a genuine ``int`` copy for arithmetic; problems arise in NG code if this is not cast to int.
        instance._value = int(budget)
        return instance

    def __sub__(self, other):
        # Clip at zero so Nevergrad sees "budget exhausted" rather than a negative remainder.
        remainder = self._value - other
        return remainder if remainder > 0 else 0


class Nevergrad(BaseOptimizer):
    """Provides access to the optimizers available through the
    `nevergrad <https://facebookresearch.github.io/nevergrad/>`_ package.

    Tested with v.0.3.2.

    :Parameters:

    _opt_id _signal_pipe _results_queue _pause_flag _is_log_detailed _workers _backend
        See :class:`.BaseOptimizer`.
    optimizer
        String key to the desired optimizer. See nevergrad documentation for a list of available algorithms.
    zero
        Will stop the optimization when this cost function value is reached.
    warn
        If ``False`` (the default), all nevergrad warnings are suppressed. Set to ``True`` to keep them.
    ``**kwargs``
        Extra initialisation arguments passed to the optimizer. If ``budget`` is not given, it defaults to
        ``1e10``.

    :Notes:

    .. important::

       It is crucial that you check for possible search space settings of the requested optimizer and adjust
       them through the ``kwargs`` argument accordingly. By default, ParAMS will return an infinite loss function
       value if any of the parameters is out of bounds. Certain algorithms (especially if the dimensionality is
       high) might not be able to adjust to the bounded search space, resulting in all evaluations producing
       ``float(inf)``, and a failed optimization.
    """

    _scaler = "std"

    def __init__(
        self,
        _opt_id: Optional[int] = None,
        _signal_pipe: Optional[Connection] = None,
        _results_queue: Optional[Queue] = None,
        _pause_flag: Optional[Event] = None,
        _is_log_detailed: bool = False,
        _workers: int = 1,
        _backend: str = "threads",
        optimizer: str = "TBPSA",
        zero: float = 0,
        warn: bool = False,
        **kwargs,
    ):
        super().__init__(
            _opt_id,
            _signal_pipe,
            _results_queue,
            _pause_flag,
            _is_log_detailed,
            _workers,
            _backend,
            optimizer=optimizer,
            zero=zero,
            warn=warn,
            **kwargs,
        )

        self.opt_algo = ng.optimizers.registry[optimizer]
        # Very difficult to work back to the name from opt_algo; easier to just save the name.
        self.opt_algo_name = optimizer
        self.optimizer = None

        if self.opt_algo.no_parallelization is True:
            warnings.warn(
                "The selected algorithm does not support parallel execution, workers overwritten and set to"
                " one.",
                RuntimeWarning,
            )
            self.workers = 1

        self.zero = zero
        self.stop = False
        self._ng_callbacks = None

        if "budget" not in kwargs:
            kwargs["budget"] = 1e10
        # Wrapped so that Nevergrad's "remaining budget" arithmetic clips at 0 after the early-stop hack.
        kwargs["budget"] = _ClippedInt(kwargs["budget"])
        self.opt_init_kwargs = kwargs

        if warn is False:
            warnings.filterwarnings("ignore", module="nevergrad")

    def __amssettings__(self, s: Settings) -> Settings:
        """Write the AMS input settings describing this optimizer into *s* and return it."""
        s.input.ams.Optimizer.Type = "Nevergrad"
        s.input.ams.Optimizer.Nevergrad.Algorithm = self.opt_algo_name
        s.input.ams.Optimizer.Nevergrad.Zero = self.zero
        for k, v in self.opt_init_kwargs.items():
            s.input.ams.Optimizer.Nevergrad.Settings[k] = v
        return s

    def minimize(
        self, function: Callable[[Sequence[float]], float], x0: Sequence[float], bounds: Sequence[Tuple[float, float]]
    ) -> MinimizeResult:
        """Minimize *function* with the configured nevergrad algorithm.

        A new nevergrad optimizer is built unless ``self.is_restart`` is set. In that case the existing
        ``self.optimizer`` is reused. The GloMPO control callback is (re-)registered on ``tell`` in both cases.
        When more than one worker is configured, evaluations are dispatched to a process or thread pool, depending
        on ``self._backend``. The pool is always shut down afterwards, even if nevergrad raises.

        :Parameters:

        function
            Cost function mapping a parameter vector to a float.
        x0
            Initial value of the nevergrad ``Array`` parametrization.
        bounds
            Sequence of ``(lower, upper)`` pairs, one per parameter.

        :Returns:

        MinimizeResult
            ``x`` is the ``value`` of the candidate returned by nevergrad's ``minimize``. ``fx`` is *function*
            re-evaluated at ``x``. ``success`` is set to ``True`` when ``fx < float("inf")``.
        """
        lower, upper = np.transpose(bounds)
        parametrization = ng.p.Array(init=x0)
        parametrization.set_bounds(lower, upper)

        if not self.is_restart:
            self.optimizer = self.opt_algo(
                parametrization=parametrization, num_workers=self.workers, **self.opt_init_kwargs
            )
        self.stop = False
        # Callbacks are stripped before checkpointing, so they must be registered again on every call.
        self.optimizer.register_callback("tell", self._glompo_control_callback)

        if self.workers > 1 and self._backend == "processes":
            executor = ProcessPoolExecutor(max_workers=self.workers)
        elif self.workers > 1 and self._backend == "threads":
            executor = ThreadPoolExecutor(max_workers=self.workers)
        else:
            executor = None

        try:
            opt_vec = self.optimizer.minimize(function, executor=executor, batch_mode=True, verbosity=1)
        finally:
            # Release the worker pool even if nevergrad raises part-way through.
            if executor:
                executor.shutdown()

        results = MinimizeResult()
        results.x = opt_vec.value
        results.fx = function(results.x)
        if results.fx < float("inf"):
            results.success = True
        return results

    def callstop(self, *args):
        """Flag the optimizer to stop; the flag is acted upon by the next ``tell`` callback."""
        self.stop = True

    def checkpoint_save(
        self, path: Union[Path, str], force: Optional[Set[str]] = None, block: Optional[Set[str]] = None
    ):
        """Save the optimizer state to *path* via the parent :meth:`.BaseOptimizer.checkpoint_save`.

        The nevergrad optimizer's running jobs, finished jobs and registered callbacks are temporarily detached for
        the duration of the save (presumably so that they are not serialised). They are reattached afterwards,
        even if saving fails.

        :Parameters:

        path
            Destination of the checkpoint.
        force
            Attribute names which must be saved. ``"optimizer"`` is always added to this set.
        block
            Attribute names which must not be saved. This is forwarded unchanged to the parent implementation.

        :Raises:

        NotImplementedError
            If the selected algorithm is listed as not supporting checkpointing.
        """
        # todo test/fix all possible algorithms.
        # If passed, add to comments below. If failed, try fix or add to not_supported below
        # Tested & Passed:
        # CMA
        # MEDA
        # TBPSA
        not_supported = {"RSQP"}
        if self.opt_algo_name in not_supported:
            raise NotImplementedError(f"Checkpointing cannot be supported for '{self.opt_algo_name}' algorithm.")

        # The nevergrad optimizer must always be saved; honour anything extra the caller asked for.
        force = {"optimizer"} | set(force or ())

        tmp_running_jobs = self.optimizer._running_jobs
        tmp_finished_jobs = self.optimizer._finished_jobs
        tmp_callbacks = self.optimizer._callbacks

        self.optimizer._running_jobs = []
        self.optimizer._finished_jobs = deque()
        # The 'tell' callback is a bound method of this wrapper; keep it out of the saved optimizer.
        self.optimizer._callbacks = {}
        try:
            super().checkpoint_save(path, force=force, block=block)
        finally:
            self.optimizer._running_jobs = tmp_running_jobs
            self.optimizer._finished_jobs = tmp_finished_jobs
            self.optimizer._callbacks = tmp_callbacks

    def _glompo_control_callback(self, opt: ng.optimizers.base.Optimizer, x: ng.p.Array, fx: float):
        """Wraps all the components needed by GloMPO to be called after each iteration into a single object
        which can be registered as a nevergrad callback.

        Stops the optimizer when ``fx >= 1e30`` or ``fx <= self.zero``, or when ``self.stop`` has been set. The
        latter is only checked after processing GloMPO messages, and only when a results queue exists. Once
        stopped, further calls do nothing.
        """
        if not self.stop:
            stop_cond = None

            # Normal termination condition
            if fx >= 1e30 or fx <= self.zero:
                stop_cond = (
                    f"Nevergrad termination conditions:\n"
                    f"(fx >= 1e30) = {fx >= 1e30}\n"
                    f"(fx <= {self.zero}) = {fx <= self.zero}"
                )
            self.logger.debug("Stop = %s at convergence condition", bool(stop_cond))

            # GloMPO specific callbacks
            if self._results_queue:
                self._pause_signal.wait()
                self.check_messages()
                if not stop_cond and self.stop:
                    stop_cond = "GloMPO termination signal."

            if stop_cond:
                self.logger.debug("Stop is True so shutting down optimizer.")
                self.stop = True
                opt._num_ask = opt.budget - 1  # This is the hack which tricks Nevergrad into early convergence
                if self._results_queue:
                    self.message_manager(0, stop_cond)