from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from multiprocessing.connection import Connection
from pathlib import Path
from queue import Queue
from threading import Event, Lock
from typing import Callable, Iterable, Optional, Sequence, Tuple, Union
import numpy as np
from .baseoptimizer import BaseOptimizer, MinimizeResult
from ...plams.core.settings import Settings
# Public API of this module: only the concrete random-sampling optimizer is exported.
# _SamplingOptimizer and StopCalledException are internal helpers.
__all__ = ("RandomOptimizer",)
class StopCalledException(Exception):
    """Signal raised from inside a sampling evaluation to abort iteration once a stop was requested.

    It is caught by the sampling loop itself and never meant to reach user code.
    """

    pass
class _SamplingOptimizer(BaseOptimizer):
    """Superclass for Random and Grid Sampling 'optimizer' routines.

    Subclasses implement :meth:`_sample_set`, which yields the points to evaluate. :meth:`minimize` evaluates
    every yielded point (serially, or in a thread pool when more than one worker is requested) and keeps the
    best point found in :attr:`result`.
    """

    @classmethod
    def checkpoint_load(cls, path: Union[Path, str], **kwargs) -> "_SamplingOptimizer":
        """Load an optimizer from a checkpoint via the base class and attach a fresh, set pause event.

        The event must be set, otherwise every evaluation in :meth:`minimize` would block on it.
        """
        opt = super().checkpoint_load(path, **kwargs)
        opt._generator_pause = Event()
        opt._generator_pause.set()
        return opt

    def __init__(
        self,
        _opt_id: int = None,
        _signal_pipe: Connection = None,
        _results_queue: Queue = None,
        _pause_flag: Event = None,
        _is_log_detailed: bool = False,
        _workers: int = 1,
        _backend: str = "threads",
        **kwargs,
    ):
        super().__init__(
            _opt_id, _signal_pipe, _results_queue, _pause_flag, _is_log_detailed, _workers, _backend, **kwargs
        )
        # Number of completed function evaluations; read by subclasses to resume after a checkpoint.
        self.n_used = 0
        # Best point seen so far (x, fx) and final success flag.
        self.result = MinimizeResult()
        # Set by callstop(); checked before every evaluation and after every result.
        self.stop_called = False
        # Cleared while a checkpoint is prepared so that no new evaluation starts in the meantime.
        self._generator_pause = Event()
        self._generator_pause.set()
        self.logger.debug("Setup optimizer")

    def minimize(
        self, function: Callable[[Sequence[float]], float], x0: Sequence[float], bounds: Sequence[Tuple[float, float]]
    ) -> MinimizeResult:
        """Evaluate every point produced by :meth:`_sample_set` and return the best one.

        :param function: Objective function to minimize.
        :param x0: Not used by sampling optimizers.
        :param bounds: One ``(min, max)`` pair per dimension; sample points are drawn within these bounds.
        :returns: :attr:`result`, holding the lowest ``fx`` seen and its ``x``. ``success`` is ``False`` if
            :meth:`callstop` was called.
        :raises NotImplementedError: If more than one worker is requested with a backend other than threads.
        """
        lock = Lock()  # Guards n_used and result, which are updated from worker threads.

        def evaluate(x):
            if self.stop_called:
                raise StopCalledException
            self._generator_pause.wait()  # Blocks while a checkpoint is being prepared.
            fx = function(x)
            with lock:
                self.n_used += 1  # Used for checkpointing and restarts
                if fx < self.result.fx:
                    self.result.fx = fx
                    self.result.x = x
            return x, fx

        mins, maxs = np.transpose(bounds)

        executor = None
        if self.workers > 1:
            if self._backend != "threads":
                # PLAMS only supports threads for now, but if processes were used then a better locking mechanism on
                # incrementing n_used would be needed and stop_called would also need to be a shared variable
                raise NotImplementedError
            executor = ThreadPoolExecutor(self.workers)
            mapper = executor.map
        else:
            mapper = map

        generator = None
        try:
            generator = mapper(evaluate, self._sample_set(mins, maxs))
            for _ in generator:
                if self._results_queue:
                    self.check_messages()
                    self._pause_signal.wait()
                if self.stop_called:
                    break
        except StopCalledException:
            pass
        finally:
            if executor is not None:
                # Closing the executor.map generator (relevant after ``break``) cancels the evaluations that
                # have not started yet; shutdown() then waits only for those already running. Doing this in
                # ``finally`` also releases the pool when ``function`` or ``check_messages`` raises.
                if generator is not None:
                    generator.close()
                executor.shutdown()

        self.result.success = not bool(self.stop_called)
        if self._results_queue:
            self.logger.debug("Messaging manager")
            self.message_manager(0, "Optimizer convergence")
            self.check_messages()
        return self.result

    def callstop(self, reason: str = ""):
        """Request that sampling stops; pending evaluations are skipped. ``reason`` is currently unused."""
        self.stop_called = True

    @abstractmethod
    def _sample_set(self, mins: np.ndarray, maxs: np.ndarray) -> Iterable[np.ndarray]:
        """Yield the points to evaluate, given per-dimension lower (``mins``) and upper (``maxs``) bounds."""
        ...

    def _prepare_checkpoint(self):
        """Hold back new evaluations (multi-worker runs only) while the base class prepares a checkpoint.

        Evaluations that already passed the pause event keep running.
        """
        if self.workers > 1:
            self._generator_pause.clear()
        try:
            super()._prepare_checkpoint()
        finally:
            # Always release the workers: if the checkpoint failed and the event stayed cleared, every pending
            # evaluation would block on ``_generator_pause.wait()`` forever.
            self._generator_pause.set()
class RandomOptimizer(_SamplingOptimizer):
    """Evaluates random points within the bounds for a fixed number of iterations.

    :Parameters:

    _opt_id, _signal_pipe, _results_queue, _pause_flag, _is_log_detailed, _workers, _backend
        See :class:`.BaseOptimizer`.

    iters
        Number of function evaluations the optimizer will execute before terminating.

    seed
        Seed for the random number generator. If ``None``, NumPy seeds it non-reproducibly.
    """

    def __init__(
        self,
        _opt_id: int = None,
        _signal_pipe: Connection = None,
        _results_queue: Queue = None,
        _pause_flag: Event = None,
        _is_log_detailed: bool = False,
        _workers: int = 1,
        _backend: str = "threads",
        iters: int = 100,
        seed: Optional[int] = None,
    ):
        super().__init__(
            _opt_id,
            _signal_pipe,
            _results_queue,
            _pause_flag,
            _is_log_detailed,
            _workers,
            _backend,
            iters=iters,
            seed=seed,
        )
        self.max_iters = iters
        self.seed = seed

    def __amssettings__(self, s: Settings) -> Settings:
        """Write the equivalent AMS ``RandomSampling`` optimizer settings into ``s`` and return it."""
        s.input.ams.Optimizer.Type = "RandomSampling"
        s.input.ams.Optimizer.RandomSampling.NumberOfSamples = self.max_iters
        if self.seed is not None:
            s.input.ams.Optimizer.RandomSampling.RandomSeed = self.seed
        return s

    def _sample_set(self, mins: np.ndarray, maxs: np.ndarray) -> Iterable[np.ndarray]:
        """Yield the remaining ``max_iters - n_used`` points drawn uniformly from ``[mins, maxs)``.

        When resuming from a checkpoint (``n_used > 0``), the first ``n_used`` draws are discarded. A seeded run
        therefore continues its random stream instead of re-evaluating the points it had already sampled.
        NOTE(review): with several workers, evaluations complete out of order, so ``n_used`` only approximates
        which samples were already evaluated.
        """
        rand = np.random.RandomState(self.seed)
        start = self.n_used
        for _ in range(start):
            rand.uniform(mins, maxs)  # Advance the stream past samples used before the restart.
        for _ in range(start, self.max_iters):
            yield rand.uniform(mins, maxs)