Source code for scm.glompo.generators.peterbation

from typing import Sequence

import numpy as np
from scipy.stats import truncnorm

from .basegenerator import BaseGenerator
from ...plams.core.settings import Settings

# Public API of this module: only the generator class is exported by ``from ... import *``.
__all__ = ("PerturbationGenerator",)


class PerturbationGenerator(BaseGenerator):
    """Randomly generates parameter vectors in the neighbourhood of a given point.

    Every call draws one value per parameter from a normal distribution centred on
    the matching entry of ``x0``. Each distribution is truncated so that the sample
    respects the optimisation bounds supplied by the manager. Parameters are drawn
    independently, i.e. this is a truncated multivariate normal with a diagonal
    covariance.

    This suits parametrisation work where a good starting candidate is already
    known. Be aware, however, that it can severely restrict how broadly GloMPO
    explores the search space.

    :Parameters:

    x0
        Centre point (mean) of each parameter.
    scale
        Standard deviation of each parameter. Controls how far from ``x0`` the
        generator tends to explore. Entries are expected to be strictly positive;
        a zero entry leads to a division by zero in :meth:`generate`.
    """

    def __init__(self, x0: Sequence[float], scale: Sequence[float]):
        super().__init__()
        # np.array copies its input, so later changes to the caller's sequences
        # do not leak into the generator.
        self.loc = np.array(x0)
        self.scale = np.array(scale)

    def __amssettings__(self, s: Settings) -> Settings:
        """Record this generator's configuration in the AMS input ``s`` and return ``s``.

        Only the generator type and the standard deviations are written; ``loc`` is
        not transferred.
        """
        generator = s.input.ams.Generator
        generator.Type = "Perturbation"
        generator.Perturbation.StandardDeviation = self.scale
        return s

    def generate(self, manager: "GloMPOManager") -> np.ndarray:
        """Draw one candidate vector that lies inside ``manager.bounds``.

        :Parameters:

        manager
            Object exposing ``bounds``: a 2-column sequence of ``(lower, upper)``
            rows, one row per parameter.

        :Returns:

        numpy.ndarray
            One sample per parameter, taken from the normal distribution with mean
            ``loc`` and standard deviation ``scale``, truncated to the bounds.
        """
        bounds = np.array(manager.bounds)
        lower, upper = bounds[:, 0], bounds[:, 1]

        # scipy's truncnorm expects its clip points in standardised units
        # (distance from ``loc`` measured in multiples of ``scale``), not in raw
        # parameter units, so convert the bounds before sampling.
        a_std = (lower - self.loc) / self.scale
        b_std = (upper - self.loc) / self.scale

        return truncnorm.rvs(a_std, b_std, loc=self.loc, scale=self.scale)