import operator
from itertools import islice, product
from multiprocessing.connection import Connection
from threading import Event
from typing import Iterable, Optional

import numpy as np

from .random import _SamplingOptimizer
from ..core._backends import ChunkingQueue
from ...plams.core.settings import Settings
# Public API of this module: the only name exported by ``from ... import *``.
__all__ = ("SimpleGridOptimizer",)
class SimpleGridOptimizer(_SamplingOptimizer):
    """A simple example optimizer, that just evaluates the fitness function on a regular grid in parameter space.

    The grid has ``nsteps`` evenly spaced values (both bounds included) along every active parameter's range.
    Every combination of those values, i.e. the Cartesian product, is one grid point.
    """

    def __init__(
        self,
        _opt_id: Optional[int] = None,
        _signal_pipe: Optional[Connection] = None,
        _results_queue: Optional[ChunkingQueue] = None,
        _pause_flag: Optional[Event] = None,
        _is_log_detailed: bool = False,
        _workers: int = 1,
        _backend: str = "threads",
        nsteps: int = 10,
    ):
        """Create a new optimizer for a regular grid of ``nsteps`` points in every active parameter's range.

        The underscore-prefixed arguments are forwarded unchanged to the base class constructor.

        :param nsteps: Number of grid values per parameter. Must be a positive integer.
        :raises TypeError: If ``nsteps`` is not an integer (e.g. a float).
        :raises ValueError: If ``nsteps`` is smaller than 1.
        """
        # Validate eagerly. Otherwise np.linspace only rejects a bad value once sampling starts, after
        # __amssettings__ may already have written it out. nsteps == 0 would even silently produce an
        # empty grid, so the optimizer would evaluate nothing.
        # operator.index is the same integer check np.linspace applies to its ``num`` argument.
        nsteps = operator.index(nsteps)
        if nsteps < 1:
            raise ValueError(f"nsteps must be a positive integer, got {nsteps}.")
        super().__init__(
            _opt_id, _signal_pipe, _results_queue, _pause_flag, _is_log_detailed, _workers, _backend, nsteps=nsteps
        )
        self.nsteps = nsteps

    def __amssettings__(self, s: Settings) -> Settings:
        """Write this optimizer's configuration into the AMS input settings ``s``.

        Selects the ``GridSampling`` optimizer and sets its ``NumberOfDivisions`` to ``nsteps``.
        ``s`` is modified in place and also returned.
        """
        s.input.ams.Optimizer.Type = "GridSampling"
        s.input.ams.Optimizer.GridSampling.NumberOfDivisions = self.nsteps
        return s

    def _sample_set(self, mins: np.ndarray, maxs: np.ndarray) -> Iterable[np.ndarray]:
        """Yield the grid points that have not been handed out yet.

        :param mins: Lower bound of every active parameter (assumed 1-D, one entry per parameter — TODO confirm).
        :param maxs: Upper bound of every active parameter, same shape as ``mins``.
        :return: Iterator over grid points. Each point is a tuple with one coordinate per parameter,
            not an ``np.ndarray``; the annotation is kept for compatibility with the base class.
        """
        # For 1-D bounds, np.linspace returns shape (nsteps, n_params). Transposing gives one row of
        # nsteps values per parameter, which is what itertools.product expects.
        pargrid = np.linspace(mins, maxs, self.nsteps).T
        # product varies the last parameter fastest. Skip the first ``self.n_used`` points.
        # NOTE(review): n_used comes from the base class and presumably counts points already sampled,
        # so a resumed run continues where it stopped — confirm.
        yield from islice(product(*pargrid), self.n_used, None)