import warnings
from collections import deque
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from multiprocessing.connection import Connection
from pathlib import Path
from threading import Event
from typing import Callable, Optional, Sequence, Set, Tuple, Union
from queue import Queue
import nevergrad as ng
import numpy as np
from .baseoptimizer import BaseOptimizer, MinimizeResult
from ...plams.core.settings import Settings
__all__ = ("Nevergrad",)


def _version_tuple(version: str) -> Tuple[int, ...]:
    """Convert a version string such as ``"0.10.0"`` or ``"1.0.0rc1"`` into a tuple of integers.

    Each dot-separated component contributes its leading decimal digits; any trailing suffix such as ``rc1``
    or ``post2`` is ignored. Parsing stops at the first component that does not start with a digit.
    Comparing the resulting integer tuples orders versions numerically. Comparing the raw string
    components would not: lexicographically ``"10" < "4"``, so ``"0.10"`` would sort before ``"0.4"``.
    """
    numbers = []
    for component in version.split("."):
        digits = ""
        for char in component:
            if char not in "0123456789":
                break
            digits += char
        if not digits:
            break
        numbers.append(int(digits))
    return tuple(numbers)


if _version_tuple(ng.__version__) >= (0, 4):
    warnings.warn(
        "A known bug in Nevergrad >= 0.4 will break the python logging system. "
        "Please downgrade to 0.3.2 to regain control of the logging system.",
        ImportWarning,
    )
class _ClippedInt(int):
"""Special class used to hack Nevergrad.
In order to get Nevergrad to cleanly exit we will trick it into thinking that the budget has been all used up.
However, our hack can lead to the remaining budget ending up negative within the Nevergrad code.
Unfortunately, the NG loop only looks for a remaining budget of exactly 0 rather than any number below zero.
Thus, this class (used for the NG budget keyword) will return the difference of numbers or zero, whichever is
larger.
"""
def __new__(cls, budget):
obj = int.__new__(cls, budget)
obj._value = int(budget) # Must be cast to int otherwise problems arise in NG code
return obj
def __sub__(self, other):
return max(0, self._value - other)
class Nevergrad(BaseOptimizer):
    """Provides access to the optimizers available through the
    `nevergrad <https://facebookresearch.github.io/nevergrad/>`_ package.
    Tested with v.0.3.2.

    :Parameters:

    _opt_id _signal_pipe _results_queue _pause_flag _is_log_detailed _workers _backend
        See :class:`.BaseOptimizer`.

    optimizer
        String key to the desired optimizer. See nevergrad documentation for a list of available algorithms.

    zero
        Will stop the optimization when a cost function value less than or equal to this is reached.
        Evaluations returning a value of ``1e30`` or more also stop the optimization.

    warn
        If ``False`` (default), suppresses all nevergrad warnings.

    ``**kwargs``
        Extra initialisation arguments passed to the optimizer. If ``budget`` is not given (or is ``None``),
        a budget of ``1e10`` is used.

    :Notes:

    .. important::

       It is crucial that you check for possible search space settings of the requested optimizer and adjust them
       through the ``kwargs`` argument accordingly. By default, ParAMS will return an infinite loss function value
       if any of the parameters is out of bounds. Certain algorithms (especially if the dimensionality is high)
       might not be able to adjust to the bounded search space, resulting in all evaluations producing
       ``float(inf)``, and a failed optimization.
    """

    _scaler = "std"

    def __init__(
        self,
        _opt_id: Optional[int] = None,
        _signal_pipe: Optional[Connection] = None,
        _results_queue: Optional[Queue] = None,
        _pause_flag: Optional[Event] = None,
        _is_log_detailed: bool = False,
        _workers: int = 1,
        _backend: str = "threads",
        optimizer: str = "TBPSA",
        zero: float = 0,
        warn: bool = False,
        **kwargs,
    ):
        super().__init__(
            _opt_id,
            _signal_pipe,
            _results_queue,
            _pause_flag,
            _is_log_detailed,
            _workers,
            _backend,
            optimizer=optimizer,
            zero=zero,
            warn=warn,
            **kwargs,
        )

        self.opt_algo = ng.optimizers.registry[optimizer]
        # Working back from opt_algo to its registry key is awkward, so keep the name that was requested.
        self.opt_algo_name = optimizer
        self.optimizer = None
        if self.opt_algo.no_parallelization is True:
            warnings.warn(
                "The selected algorithm does not support parallel execution, workers overwritten and set to one.",
                RuntimeWarning,
            )
            self.workers = 1
        self.zero = zero
        self.stop = False
        self._ng_callbacks = None

        # The early-stop hack in _glompo_control_callback computes ``opt.budget - 1``, so a numeric budget is
        # required. A missing or None budget falls back to a very large default, which _ClippedInt(None) would
        # otherwise reject. _ClippedInt keeps Nevergrad's remaining budget from going negative.
        if kwargs.get("budget") is None:
            kwargs["budget"] = 1e10
        kwargs["budget"] = _ClippedInt(kwargs["budget"])
        self.opt_init_kwargs = kwargs

        if warn is False:
            # Process-wide side effect: ignores warnings attributed to any module whose name starts with
            # "nevergrad".
            warnings.filterwarnings("ignore", module="nevergrad")

    def __amssettings__(self, s: Settings) -> Settings:
        """Write this optimizer's configuration into the AMS input section of ``s`` and return ``s``.

        Sets ``Optimizer.Type``, the Nevergrad ``Algorithm`` and ``Zero`` values, and copies every optimizer
        initialisation kwarg (including the budget) into ``Optimizer.Nevergrad.Settings``. ``s`` is modified in
        place.
        """
        s.input.ams.Optimizer.Type = "Nevergrad"
        s.input.ams.Optimizer.Nevergrad.Algorithm = self.opt_algo_name
        s.input.ams.Optimizer.Nevergrad.Zero = self.zero
        for k, v in self.opt_init_kwargs.items():
            s.input.ams.Optimizer.Nevergrad.Settings[k] = v
        return s

    def minimize(
        self, function: Callable[[Sequence[float]], float], x0: Sequence[float], bounds: Sequence[Tuple[float, float]]
    ) -> MinimizeResult:
        """Run the configured Nevergrad optimizer on ``function``.

        :param function: Cost function mapping a parameter vector to a float.
        :param x0: Starting parameter vector.
        :param bounds: One ``(lower, upper)`` pair per parameter; applied as bounds on the Nevergrad
            parametrization.
        :returns: A :class:`MinimizeResult` with these fields:

            * ``x``: the value of Nevergrad's returned recommendation.
            * ``fx``: ``function(x)``, which costs one additional evaluation.
            * ``success``: set to ``True`` only when ``fx`` is below ``+inf``.

        On a restart, the optimizer restored from the checkpoint is reused instead of building a new one.
        Execution uses a process or thread pool when more than one worker is configured; otherwise Nevergrad
        runs serially.
        """
        lower, upper = np.transpose(bounds)
        parametrization = ng.p.Array(init=x0)
        parametrization.set_bounds(lower, upper)

        if not self.is_restart:
            self.optimizer = self.opt_algo(
                parametrization=parametrization, num_workers=self.workers, **self.opt_init_kwargs
            )
        self.stop = False
        # Callbacks are stripped from the optimizer before it is checkpointed (see checkpoint_save), so the
        # GloMPO control callback must be (re)registered on every call, including restarts.
        self.optimizer.register_callback("tell", self._glompo_control_callback)

        if self.workers > 1 and self._backend == "processes":
            executor = ProcessPoolExecutor(max_workers=self.workers)
        elif self.workers > 1 and self._backend == "threads":
            executor = ThreadPoolExecutor(max_workers=self.workers)
        else:
            executor = None

        try:
            opt_vec = self.optimizer.minimize(function, executor=executor, batch_mode=True, verbosity=1)
        finally:
            # Always release the pool, even if Nevergrad raises, so worker threads/processes are not leaked.
            if executor is not None:
                executor.shutdown()

        results = MinimizeResult()
        results.x = opt_vec.value
        results.fx = function(results.x)
        if results.fx < float("inf"):
            results.success = True
        return results

    def callstop(self, *args):
        """Flag the optimizer to stop; the flag is acted upon in :meth:`_glompo_control_callback`."""
        self.stop = True

    def checkpoint_save(
        self, path: Union[Path, str], force: Optional[Set[str]] = None, block: Optional[Set[str]] = None
    ):
        """Save the optimizer state to ``path`` via :meth:`BaseOptimizer.checkpoint_save`.

        The Nevergrad optimizer object is always forced into the dump, in addition to any names in ``force``.
        ``block`` is passed through unchanged. While the dump is written, the optimizer's running jobs,
        finished jobs and callbacks are temporarily emptied. They are restored afterwards, even if saving fails.

        :raises NotImplementedError: If the selected algorithm is known not to support checkpointing.
        """
        # TODO: test/fix all possible algorithms.
        # If passed, add to the list below. If failed, try to fix or add to not_supported.
        # Tested & Passed:
        #   CMA
        #   MEDA
        #   TBPSA
        not_supported = {"RSQP"}
        if self.opt_algo_name in not_supported:
            raise NotImplementedError(f"Checkpointing cannot be supported for '{self.opt_algo_name}' algorithm.")

        # Job containers and bound-method callbacks are stripped for the dump. Presumably this is because they
        # cannot be serialised; confirm against the base class dump mechanism.
        tmp_running_jobs = self.optimizer._running_jobs
        tmp_finished_jobs = self.optimizer._finished_jobs
        tmp_callbacks = self.optimizer._callbacks
        self.optimizer._running_jobs = []
        self.optimizer._finished_jobs = deque()
        self.optimizer._callbacks = {}
        try:
            super().checkpoint_save(path, force={"optimizer"} | set(force or ()), block=block)
        finally:
            self.optimizer._running_jobs = tmp_running_jobs
            self.optimizer._finished_jobs = tmp_finished_jobs
            self.optimizer._callbacks = tmp_callbacks

    def _glompo_control_callback(self, opt: ng.optimizers.base.Optimizer, x: ng.p.Array, fx: float):
        """Wraps all the components needed by GloMPO to be called after each iteration into a single object which
        can be registered as a nevergrad callback.

        Registered on the ``"tell"`` event in :meth:`minimize`. Nothing happens once ``self.stop`` is already set.
        Otherwise a stop is triggered when any of the following holds:

        * ``fx >= 1e30``;
        * ``fx <= self.zero``;
        * when a results queue is attached, ``self.stop`` becomes set while GloMPO messages are processed.

        On stopping, ``opt._num_ask`` is moved to ``opt.budget - 1`` so Nevergrad believes its budget is spent,
        and the stop reason is forwarded through :meth:`message_manager` when a results queue is attached.
        ``x`` is unused but required by the Nevergrad callback signature.
        """
        if not self.stop:
            stop_cond = None

            # Normal termination condition
            if fx >= 1e30 or fx <= self.zero:
                stop_cond = (
                    f"Nevergrad termination conditions:\n"
                    f"(fx >= 1e30) = {fx >= 1e30}\n"
                    f"(fx <= {self.zero}) = {fx <= self.zero}"
                )
            self.logger.debug("Stop = %s at convergence condition", bool(stop_cond))

            # GloMPO specific callbacks
            if self._results_queue:
                self._pause_signal.wait()
                self.check_messages()
                if not stop_cond and self.stop:
                    stop_cond = "GloMPO termination signal."

            if stop_cond:
                self.logger.debug("Stop is True so shutting down optimizer.")
                self.stop = True
                # Hack: make Nevergrad think the budget is (nearly) exhausted so its minimize loop exits.
                # _ClippedInt keeps the resulting remaining budget from going negative.
                opt._num_ask = opt.budget - 1
                if self._results_queue:
                    self.message_manager(0, stop_cond)