Source code for scm.glompo.core.function
""" Defines the minimization task API. """
__all__ = ("BaseFunction",)
from abc import abstractmethod
from pathlib import Path
from typing import Any, Dict, Sequence, Union
import tables as tb
class BaseFunction:
    """Template class for the optimization task GloMPO supports.

    Direct use of this class for the minimization task is *not required*. However, this class does define the
    minimum and optional API, and can be helpful in creating a cost function.

    Only :meth:`__call__` is the minimum requirement. Every method defined here raises
    :exc:`NotImplementedError` until it is overridden by a subclass.
    """

    @abstractmethod
    def __call__(self, x: Sequence[float]) -> float:
        """Minimum cost function requirement.

        Accepts a parameter vector ``x`` and returns the function evaluation's result.

        :Parameters:

        x
            Vector in parameter space at which to evaluate the function.

        :Returns:

        float
            Result of the function evaluation which is trying to be minimized.

        :Raises:

        NotImplementedError
            If the method has not been overridden by a subclass.
        """
        # This class does not use ``ABCMeta``, so ``@abstractmethod`` on its own enforces nothing. Raise explicitly
        # so a missing override fails loudly instead of silently returning ``None`` as the "cost".
        raise NotImplementedError

    def detailed_call(self, x: Sequence[float]) -> Sequence[Any]:
        """Optional function evaluation method.

        When called with a parameter vector (``x``) it returns a sequence of data. The first element of this
        sequence is expected to be the function evaluation result (as returned by :meth:`__call__`). Subsequent
        elements of the sequence may take any form. This function may be used to return information needed by the
        optimizer algorithm or extra information which will be added to the log.

        :Parameters:

        x
            Vector in parameter space at which to evaluate the function.

        :Returns:

        float
            Result of the function evaluation which is trying to be minimized.
        args
            Additional returns of any type and length.

        :Raises:

        NotImplementedError
            If the method has not been overridden by a subclass.
        """
        raise NotImplementedError

    def headers(self) -> Dict[str, "tb.Col"]:
        """Optional implementation.

        If :meth:`detailed_call` is being used, this method returns a dictionary descriptor for each *extra* column
        of the return. Keys represent the name of each *extra* element of the return, values represent the
        corresponding
        `tables.Col <https://www.pytables.org/usersguide/libref/structured_storage.html#tables.Table.col>`_ data type.

        If headers is not defined, GloMPO will attempt to infer types from a function evaluation return. Be warned
        that this is a risky approach as incorrect inferences could be made. Numerical data types are also set to
        the largest possible type (i.e. ``float64``) and strings are limited to 280 characters. This may lead to
        inefficient use of space or data being truncated. If :meth:`detailed_call` is being used, implementation of
        headers is strongly recommended.

        :Returns:

        Dict[str, tables.Col]
            Mapping of heading names to the ``tables.Col`` type which indicates the type of
            data the column of information will store.

        :Raises:

        NotImplementedError
            If the method has not been overridden by a subclass.

        :Examples:

        >>> import tables
        >>> header = {'training_set_residuals': tables.Float64Col(shape=100, pos=0),
        ...           'validation_set_fx': tables.Float64Col(pos=1),
        ...           'errors': tables.StringCol(itemsize=280, dflt=b'None', pos=2)}
        """
        raise NotImplementedError

    def checkpoint_save(self, path: Union[str, Path]):
        """Persists the function into a file or files from which it can be reconstructed.

        This method is used when a checkpoint of the manager is made and the function cannot be persisted directly.
        A checkpoint is a compressed directory of files which persists all aspects of an in-progress optimization.
        These checkpoints can be loaded by :class:`.GloMPOManager` and the optimization resumed.

        Implementing this function is optional and only required if directly pickling the function is not possible.
        In order to load a checkpoint in which :meth:`checkpoint_save` was used, see
        :meth:`.GloMPOManager.load_checkpoint`.

        :Parameters:

        path
            :obj:`str` or :class:`python:pathlib.Path` to a directory into which files will be saved.

        :Raises:

        NotImplementedError
            If the method has not been overridden by a subclass.
        """
        raise NotImplementedError

    @classmethod
    def checkpoint_load(cls, path: Union[str, Path]):
        """Creates an instance of the :class:`BaseFunction` from sources.

        These sources are the products of :meth:`checkpoint_save`. In order to use this method, it should be sent
        to the ``task_loader`` argument of :meth:`.GloMPOManager.load_checkpoint`.

        :Parameters:

        path
            :obj:`str` or :class:`~python:pathlib.Path` to a directory which contains the files which will be loaded.

        :Raises:

        NotImplementedError
            If the method has not been overridden by a subclass.
        """
        raise NotImplementedError