# Source code for scm.glompo.analysis.hsic

import os
import sys

import psutil
import tempfile as tf
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from .kernels import BaseKernel, ConjunctiveGaussianKernel, GaussianKernel

# Public API of this module. HSICResultWithReweight is included because it is returned by
# HSIC.compute_with_reweight and HSICResult.load.
__all__ = ("HSIC", "HSICResult", "HSICResultWithReweight")

class HSICResult:
    """Result object of a single HSIC calculation produced by :meth:`HSIC.compute`.
    Can only be created by :meth:`.HSIC.compute` or :meth:`load`; direct instantiation raises
    :class:`NotImplementedError`.
    """

    @property
    def hsic(self) -> np.ndarray:
        """:math:`d` vector of unnormalised HSIC calculation results (negative estimates are clipped to zero
        before being stored). If :attr:`n_bootstraps` is larger than one, this represents the mean over the
        bootstraps.

        See also :attr:`hsic_std` and :attr:`sensitivities`.
        """
        return self._hsic.mean(0)

    @property
    def hsic_std(self) -> np.ndarray:
        """:math:`d` vector of the standard deviation of the unnormalised HSIC results across all the bootstraps.

        See also :attr:`hsic` and :attr:`sensitivities_std`.
        """
        return self._hsic.std(0)

    @property
    def sensitivities(self) -> np.ndarray:
        r""":math:`d` vector of normalised HSIC calculation results.
        These are the main sensitivity results a user should be concerned with.
        If :attr:`n_bootstraps` is larger than one, this represents the mean over the bootstraps.

        Sensitivities are defined such that:

        .. math::

            \sum S_d = 1

        .. math::

            0 \leq S_d \leq 1

        See also :attr:`hsic` and :attr:`sensitivities_std`.
        """
        return self._s_mean

    @property
    def sensitivities_std(self) -> np.ndarray:
        """:math:`d` vector of the standard deviation of normalised HSIC results across all the bootstraps.

        See also :attr:`hsic_std` and :attr:`sensitivities`.
        """
        return self._s.std(0)

    @property
    def n_bootstraps(self) -> int:
        """Number of times the HSIC calculation was performed with different sub-samples of the data."""
        return self._n_boot

    @property
    def n_samples(self) -> int:
        """Number of items in the sub-samples of the data used in each bootstrap."""
        return self._n_sample

    @property
    def n_factors(self) -> int:
        """Number of factors analyzed."""
        return self._d

    @property
    def sampling_with_replacement(self) -> bool:
        """Whether sampling was done with or without replacement from the available data.
        False if all available data was used, and automatically True if more samples were requested than the
        number available in the data set.
        """
        return self._replace

    @property
    def inputs_kernel(self) -> str:
        """Name of the kernel applied to the input-space data."""
        return self._in_kernel

    @property
    def inputs_kernel_parameters(self) -> Dict[str, Any]:
        """Parameters for the input kernel."""
        return self._in_kernel_params

    @property
    def outputs_kernel(self) -> str:
        """Name of the kernel applied to the output-space data."""
        return self._out_kernel

    @property
    def outputs_kernel_parameters(self) -> Dict[str, Any]:
        """Parameters for the output kernel."""
        return self._out_kernel_params

    @property
    def order_factors(self) -> np.ndarray:
        """Returns factor indices in descending order of their influence on the outputs.
        The *positions* in the array are the rankings, the *contents* of the array are the factor indices. This
        is the inverse of :attr:`ranking`.

        :Returns:

        numpy.ndarray
            :math:`d` vector of factor indices.

        See also :attr:`ranking`.
        """
        return self._s_mean.argsort()[-1::-1]

    @property
    def ranking(self) -> np.ndarray:
        """Returns the ranking of each factor being analyzed.
        The *positions* in the array are the factor indices, the *contents* of the array are rankings such that
        1 is the most influential factor and :math:`d` is the least influential. This is the inverse of
        :attr:`order_factors`.

        :Returns:

        numpy.ndarray
            :math:`d` vector of rankings.

        See also :attr:`order_factors`.
        """
        # argsort of a permutation yields its inverse permutation
        return self.order_factors.argsort(0) + 1

    @classmethod
    def load(cls, path: Union[Path, str]) -> Union["HSICResult", "HSICResultWithReweight"]:
        """Load a calculation result from file.

        :Parameters:

        path
            Path to a file written by :meth:`save`. :func:`numpy.savez` appends '.npz' to the name given to
            :meth:`save` if it was missing, so the full file name must be given here.

        :Returns:

        :class:`HSICResult` or :class:`HSICResultWithReweight`
            The class is chosen from the ``_has_reweight`` flag stored in the file.

        .. warning::

            The file is opened with ``allow_pickle=True`` because several attributes (e.g. the kernel parameter
            dictionaries) are stored as object arrays. Unpickling can execute arbitrary code, so only load files
            from trusted sources.

        See also :meth:`save`.
        """
        path = str(path)

        with np.load(path, allow_pickle=True) as data:
            has_reweight = "_has_reweight" in data and bool(data["_has_reweight"])
            obj = object.__new__(HSICResultWithReweight if has_reweight else HSICResult)
            # NpzFile reads members lazily, so every array must be pulled out while the archive is open
            for key, value in data.items():
                setattr(obj, key, value)

        # Remove ndarray nesting for some items
        # todo consider using pickle instead of numpy to persist this object so this unnesting is not needed
        obj._in_kernel = obj._in_kernel.item()
        obj._out_kernel = obj._out_kernel.item()
        if obj._out_kernel == "SigmoidKernel":
            # Hack to automatically rename deprecated sigmoid kernel to new name
            obj._out_kernel = "ConjunctiveGaussianKernel"
        obj._in_kernel_params = obj._in_kernel_params.item()
        obj._out_kernel_params = obj._out_kernel_params.item()
        obj._replace = obj._replace.item()
        obj._n_sample = obj._n_sample.item()
        obj._n_boot = obj._n_boot.item()
        obj._d = obj._d.item()
        if obj.labels.size == 1 and obj.labels.item() is None:
            obj.labels = None
        if hasattr(obj, "_has_reweight"):
            obj._has_reweight = obj._has_reweight.item()
        else:
            # Files written without this flag predate the reweight calculation; __str__ reads it.
            obj._has_reweight = False

        return obj

    @classmethod
    def _from_compute(
        cls,
        hsic: np.ndarray,
        input_kernel: BaseKernel,
        output_kernel: BaseKernel,
        n_bootstraps: int,
        n_sample: int,
        replace: bool,
        has_reweight: bool,
        targets: Optional[np.ndarray],
        g: Optional[float],
        dg: Optional[np.ndarray],
        weights: Optional[np.ndarray],
        labels: Optional[Sequence[str]],
    ) -> Union["HSICResult", "HSICResultWithReweight"]:
        """Processes new calculation results from :meth:`.HSIC.compute`.

        ``hsic`` is an :math:`n_{bootstraps} \\times d` array; sensitivities are obtained by clipping it at zero
        and normalising each bootstrap row to sum to one.
        """
        obj = object.__new__(HSICResultWithReweight if has_reweight else HSICResult)

        obj._hsic = hsic
        obj._in_kernel = type(input_kernel).__name__
        obj._out_kernel = type(output_kernel).__name__
        obj._n_sample = n_sample
        obj._n_boot = n_bootstraps
        obj._d = obj._hsic.shape[1]
        obj._in_kernel_params = {k: getattr(input_kernel, k) for k in input_kernel.PARAMETERS}
        obj._out_kernel_params = {k: getattr(output_kernel, k) for k in output_kernel.PARAMETERS}
        obj._replace = replace

        obj._has_reweight = has_reweight
        obj._targets = targets
        obj._g = g
        obj._dg = dg
        obj._weights = weights

        obj.labels = np.array(labels) if labels is not None else None

        s = np.clip(hsic, 0, None)
        s /= s.sum(1, keepdims=True)
        obj._s = s
        obj._s_mean = s.mean(0)

        return obj

    def __init__(self, *args, **kwargs):
        raise NotImplementedError(
            "HSICResult objects cannot be created directly. Use HSIC.compute or HSICResult.load instead."
        )

    def __str__(self, top_n: int = 5, width: int = 50) -> str:
        """Returns a summary of the result and its settings.

        :Parameters:

        top_n
            Number of most influential factors to list. Any negative value lists all factors.
        width
            Character width of the summary table.
        """
        rec = f"{'HSIC Result':^{width}}\n"
        rec += "=" * width + "\n"
        fmt = f"{{0:<{width // 2}}}{{1:>{width // 2}}}\n"
        rec += "-" * width + "\n"
        rec += fmt.format("No. Factors", self.n_factors)
        rec += fmt.format("No. Samples", self.n_samples)
        rec += fmt.format("No. Bootstraps", self.n_bootstraps)
        rec += fmt.format("Sample with replacement", str(self.sampling_with_replacement))
        rec += fmt.format("Includes reweight calc.", "Yes" if self._has_reweight else "No")
        if self._has_reweight:
            std = f"\u00b1{self.g_std:.03}" if self.n_bootstraps > 1 else ""
            rec += fmt.format("Sensitivity Imbalance (g)", f"{self.g:.03}" + std)
        rec += "-" * width + "\n"

        rec += fmt.format("Inputs Kernel", self.inputs_kernel)
        for k, v in self.inputs_kernel_parameters.items():
            rec += fmt.format(f"   {k}", v)
        rec += fmt.format("Outputs Kernel", self.outputs_kernel)
        for k, v in self.outputs_kernel_parameters.items():
            rec += fmt.format(f"   {k}", v)

        rec += "-" * width + "\n"
        order = np.argsort(self.sensitivities)[-1::-1]

        if top_n < 0:
            rec += "Factor Rankings:\n"
        else:
            top_n = min(top_n, self.n_factors)
            rec += f"Top {top_n} Factors:\n"
            order = order[:top_n]

        for i, fact in enumerate(order, 1):
            label = self.labels[fact][:18] if self.labels is not None else f"Parameter_{fact:03}"
            if self.n_bootstraps > 1:
                rec += f"   {{0:3d}}. {{1:>{width - 20}}} {{2:>.3f}}\u00b1{{3:>.3f}}\n".format(
                    i, label, self.sensitivities[fact], self.sensitivities_std[fact]
                )
            else:
                rec += f"   {{0:3d}}. {{1:>{width - 14}}} {{2:>.3f}}\n".format(i, label, self.sensitivities[fact])

        rec += "=" * width + "\n"

        return rec

    def save(self, path: Union[Path, str] = "hsiccalc.npz"):
        """Saves the result to file.
        Uses the numpy '.npz' format to save the result attributes (see :func:`numpy.savez`), which appends the
        '.npz' extension if it is missing.

        :Parameters:

        path
            Path to file in which the result will be saved.
        """
        path = str(path)
        np.savez(path, **self.__dict__)

    def plot_sensitivities(
        self, path: Union[None, Path, str] = "hsicresult.png", plot_top_n: Optional[int] = None
    ) -> plt.Figure:
        """Create a detailed graphic of the :attr:`sensitivities` results.
        Each bootstrap result is drawn as a black dot and the mean as a red bar.

        :Parameters:

        path
            Optional file location in which to save the image. If not provided the image is not saved and only
            returned.
        plot_top_n
            The number of factors to include in the plot. Only the plot_top_n most influential factors are
            included. ``None`` (or 0) includes all factors.

        :Returns:

        matplotlib.figure.Figure
            Figure instance allowing the user to further tweak and change the plot as desired.
        """
        path = Path(path) if path is not None else None
        d = min(plot_top_n, self._d) if plot_top_n else self.sensitivities.size

        fig, ax = plt.subplots()
        fig: plt.Figure
        ax: plt.Axes

        # Height grows with the number of factors but stays within [2, 15] inches
        fig.set_size_inches(5, np.clip(15 / 20 * d, 2, 15))

        order = np.argsort(self.sensitivities)[-1::-1][:d]
        labs = self.labels[order] if self.labels is not None else np.arange(self._d)[order]
        y = np.arange(labs.size)
        for i in range(self.n_bootstraps):
            ax.scatter(self._s[i, order], y, marker=".", color="k")
        ax.scatter(self.sensitivities[order], y, marker="|", color="r")
        ax.set_yticks(range(labs.size))
        ax.set_yticklabels(labs)
        ax.margins(y=1 / d)

        ax.set_xlabel(r"$S_d$")
        ax.set_xlim(0, 1)

        fig.tight_layout()

        if path:
            fig.savefig(path)

        return fig

    def plot_sensitivity_trends(
        self, path: Union[None, Path, str] = "trend.png", _seed_ax: Optional[plt.Axes] = None
    ) -> Optional[plt.Figure]:
        """Create a summary graphic of the :attr:`sensitivities` results.
        This is a more abstract version of :meth:`plot_sensitivities` that always includes all parameters.
        It is useful to inspect the sensitivity imbalance between parameters as well as the spread between repeats.

        :Parameters:

        path
            Optional file location in which to save the image. If not provided the image is not saved and only
            returned.
        _seed_ax
            Private: existing axes to draw into. When given, nothing is saved and None is returned.

        :Returns:

        matplotlib.figure.Figure
            Figure instance allowing the user to further tweak and change the plot as desired.
        """
        path = Path(path) if path is not None else None
        d = self.sensitivities.size

        if _seed_ax is not None:
            ax = _seed_ax
        else:
            fig, ax = plt.subplots(figsize=(13.5, 3))
            fig: plt.Figure
            ax: plt.Axes

        order = np.argsort(self.sensitivities)[-1::-1]
        x = np.arange(d)
        for i in range(self.n_bootstraps):
            ax.scatter(x, self._s[i, order], marker=".", color="k", label="Raw Result" if i == 0 else None)
        ax.plot(x, self.sensitivities[order], color="r", label="Mean")
        ax.margins(y=1 / d)

        ax.set_xlabel("parameter (sorted by sensitivity)")
        ax.set_ylabel("sensitivity")
        ax.set_xticklabels([])

        ax.legend()

        if _seed_ax is not None:
            return None

        fig.tight_layout()

        if path:
            fig.savefig(path)

        return fig

    def plot_grouped_sensitivities(
        self,
        path: Union[None, Path, str] = "trend.png",
        squash_threshold: float = 0.0,
        _seed_ax: Optional[plt.Axes] = None,
    ) -> Optional[plt.Figure]:
        """Create a pie chart of the :attr:`sensitivities` result per factor group.
        Assumes labels for the parameters take the format: group:factor_name. Without labels every factor index
        forms its own group.

        :Parameters:

        path
            Optional file location in which to save the image. If not provided the image is not saved and only
            returned. Note that the default ('trend.png') is the same as that of :meth:`plot_sensitivity_trends`.
        squash_threshold
            If a group's sensitivity falls below this value it will be added to the 'Other' wedge of the plot.
        _seed_ax
            Private: existing axes to draw into. When given, nothing is saved and None is returned.

        :Returns:

        matplotlib.figure.Figure
            Figure instance allowing the user to further tweak and change the plot as desired.
        """
        grouped = {"Other": 0}
        labels = self.labels if self.labels is not None else [str(i) for i in range(self.n_factors)]

        for label, sens in zip(labels, self.sensitivities):
            group = label.split(":")[0]
            grouped[group] = grouped.get(group, 0) + sens

        for group, sens in grouped.copy().items():
            if sens < squash_threshold and group != "Other":
                grouped["Other"] += sens
                del grouped[group]

        if _seed_ax is not None:
            ax = _seed_ax
        else:
            fig, ax = plt.subplots(figsize=(13.5, 3))
            fig: plt.Figure
            ax: plt.Axes

        # 'Other' first, then the remaining groups in ascending order of sensitivity
        ordered_keys = ["Other"] + [k for _, k in sorted(zip(grouped.values(), grouped.keys())) if k != "Other"]
        ordered_vals = [grouped[k] for k in ordered_keys]
        # normalize=False keeps the wedges as absolute fractions (sensitivities already sum to one)
        ax.pie(ordered_vals, labels=ordered_keys, autopct="%1.1f%%", startangle=90, counterclock=True, normalize=False)
        if _seed_ax is not None:
            return None

        fig.tight_layout()

        if path:
            fig.savefig(path)

        return fig

    def _plot_sensitivity_summary(
        self,
        path: Union[None, Path, str] = "sensitivity_summary.png",
        squash_threshold: float = 0.0,
        width: float = 13.5,
        height: float = 3,
    ) -> plt.Figure:
        """Combines :meth:`plot_sensitivity_trends` (left) and :meth:`plot_grouped_sensitivities` (right) into a
        single figure, optionally saved to ``path``."""
        fig, ax = plt.subplots(1, 2, figsize=(width, height))
        fig: plt.Figure
        ax: List[plt.Axes]

        self.plot_sensitivity_trends(None, _seed_ax=ax[0])
        self.plot_grouped_sensitivities(None, squash_threshold=squash_threshold, _seed_ax=ax[1])

        fig.tight_layout()
        if path:
            fig.savefig(path)

        return fig

class HSICResultWithReweight(HSICResult):
    """Result of :meth:`HSIC.compute_with_reweight`.
    Extends :class:`HSICResult` with the sensitivity imbalance metric :math:`g` and its gradients with respect to
    the weights of the error function, from which new weights can be suggested.
    """

    @property
    def g(self) -> float:
        """Metric of the sensitivity imbalance (:math:`g`).
        :math:`g \\geq 0` such that :math:`g = 0` means all sensitivities are on target.
        If :attr:`~.HSICResult.n_bootstraps` is larger than one, this represents the mean over the bootstraps.
        """
        return self._g.mean(0)

    @property
    def g_std(self) -> float:
        """Standard deviation of :math:`g` across all the bootstraps."""
        return self._g.std(0)

    @property
    def dg(self) -> np.ndarray:
        """:math:`n_{bootstraps} \\times m` matrix of gradients of :math:`g` with respect to the
        :math:`m` weights in the data set.
        """
        return self._dg

    @property
    def weights(self) -> np.ndarray:
        """Original :math:`m` length vector of weights used in the loss function."""
        return self._weights

    @property
    def targets(self) -> np.ndarray:
        """:math:`d` length vector of target sensitivity values for each parameter."""
        return self._targets

    @property
    def n_residuals(self) -> int:
        """Returns the number of residuals contributing to the loss function."""
        return len(self.weights)

    def _group_ids(self, group_by: Optional[Sequence[int]]) -> Tuple[np.ndarray, np.ndarray]:
        """Return (per-residual group index array, sorted unique group ids); one group per residual by default."""
        group_by = np.asarray(group_by) if group_by is not None else np.arange(self.n_residuals)
        return group_by, np.unique(group_by)

    def plot_reweight(
        self,
        path: Union[None, Path, str] = "reweightresult.png",
        group_by: Optional[Sequence[int]] = None,
        labels: Optional[Sequence[str]] = None,
    ) -> plt.Figure:
        """Create a graphic of the reweight results.
        The left panel shows :math:`dg/dw` per residual with group medians and means; the right panel compares the
        original and suggested weight of each group.

        :Parameters:

        path
            Optional file location in which to save the image. If not provided the image is not saved and only
            returned.
        group_by
            Sequence of integer indices grouping individual loss function contributions by training set item. For
            example, all the forces in a single data set 'Forces' item.
        labels
            Sequence of data set item names, one per unique value of group_by (in ascending order of group id).
            Defaults to the group ids.

        :Returns:

        matplotlib.figure.Figure
            Figure instance allowing the user to further tweak and change the plot as desired.
        """
        group_by, ids = self._group_ids(group_by)
        labels = list(labels) if labels is not None and len(labels) > 0 else ids.tolist()

        fig, ax = plt.subplots(1, 2, figsize=(15, np.clip(15 / 40 * len(labels), 2, 15)), sharey="all")
        fig: plt.Figure
        ax: List[plt.Axes]

        suggested_weights, stats = self.suggest_weights(group_by, True)
        dg_mean = self.dg.mean(0)

        ax[0].scatter(dg_mean, group_by, marker=".", color="k", label="Raw Results")
        ax[0].scatter(stats["medians"], ids, marker="|", color="r", label="Group Median")
        ax[0].scatter(stats["means"], ids, marker="|", color="g", label="Group Mean")

        std = f"\u00B1{self.g_std:.03}" if self.n_bootstraps > 1 else ""
        ax[0].set_title(f"$g(w)=${self.g:.03}{std} (Target: 0)")

        ax[0].set_ylabel("Data Set Item")
        ax[0].set_yticks(ids)
        ax[0].set_yticklabels(labels)

        ax[0].set_xlabel("$dg/dw$\n<--- Increase Weight | Decrease Weight --->")

        # Symmetric x-limits so that the sign of the gradient is visually obvious
        lim = np.abs(dg_mean).max()
        ax[0].set_xlim(-lim, lim)

        ax[0].axvline(0, color="k", zorder=0)

        ax[0].invert_yaxis()
        ax[0].margins(y=1 / len(labels))
        ax[0].legend(fontsize=10)

        # Suggested Weights
        ax[1].scatter(stats["original_weights"], ids, marker="o", color="r", label="Original")
        ax[1].scatter(suggested_weights, ids, marker="x", color="g", label="Suggested")

        ax[1].set_xlabel("Weights")
        ax[1].margins(y=1 / len(labels))
        ax[1].legend(fontsize=10)

        fig.tight_layout()

        if path:
            fig.savefig(path)

        return fig

    def suggest_weights(
        self, group_by: Optional[Sequence[int]] = None, return_stats: bool = False
    ) -> Union[np.ndarray, Tuple[np.ndarray, Dict[str, np.ndarray]]]:
        """Suggest new weights based on the reweight calculation results.
        If :attr:`~.HSICResult.n_bootstraps` is larger than one, uses the mean :attr:`dg` value over the bootstraps.

        Only groups whose median gradient magnitude exceeds 0.005 are changed. For these, the weight is moved by
        :math:`-g / (n_{change} \\cdot median)` and floored at 0.1.

        :Parameters:

        group_by
            Sequence of integer indices grouping individual loss function contributions by training set item. For
            example, all the forces in a single data set 'Forces' item.
        return_stats
            If True, also return a dictionary of per-group statistics.

        :Returns:

        numpy.ndarray
            One suggested weight per unique group id (in ascending order of id).
        dict, optional
            Only if return_stats is True: 'medians', 'means' and 'stds' of :math:`dg/dw` per group and the summed
            'original_weights' per group.
        """
        dg_mean = self.dg.mean(0)
        weights = np.asarray(self.weights)  # may have been stored as a plain list

        # Group results and calculate suggested weights
        group_by, ids = self._group_ids(group_by)

        meds = []
        means = []
        stds = []
        original_weights = []
        for gid in ids:
            mask = group_by == gid
            vals = dg_mean[mask]
            meds.append(np.median(vals))
            means.append(vals.mean())
            stds.append(vals.std())
            original_weights.append(weights[mask].sum())

        change_weight = np.abs(meds) > 0.005  # todo: should this be an option? Is it universally appropriate?
        n_change = change_weight.sum()  # >= 1 whenever the branch below that divides by it is taken
        suggested_weights = []
        for w, med, change in zip(original_weights, meds, change_weight):
            if change:
                new_weight = w - self.g / n_change / med
                new_weight = max(0.1, new_weight)  # todo: should this be an option? Is it universally appropriate?
                suggested_weights.append(new_weight)
            else:
                suggested_weights.append(w)

        if return_stats:
            return np.array(suggested_weights), {
                "medians": np.array(meds),
                "means": np.array(means),
                "stds": np.array(stds),
                "original_weights": np.array(original_weights),
            }

        return np.array(suggested_weights)

    def save_reweight_summary(
        self,
        path: Union[Path, str] = "reweightresult.csv",
        group_by: Optional[Sequence[int]] = None,
        labels: Optional[Sequence[str]] = None,
    ):
        """Save a summary of the reweight calculation to a space-separated text file.
        One line is written per group, preceded by a '#' header line.

        :Parameters:

        path
            File in which the summary is written (overwritten if it exists).
        group_by
            Sequence of integer indices grouping individual loss function contributions by training set item. For
            example, all the forces in a single data set 'Forces' item.
        labels
            Sequence of data set item names, one per unique value of group_by (in ascending order of group id).
            Defaults to the group ids.
        """
        path = Path(path)
        suggested_weights, stats = self.suggest_weights(group_by, True)

        _, ids = self._group_ids(group_by)
        labels = list(labels) if labels is not None and len(labels) > 0 else ids.tolist()

        with path.open("w", encoding="utf-8") as file:
            file.write("#dataset_item dg_median dg_mean dg_std original_weight suggested_weight\n")
            for i, lab in enumerate(labels):
                file.write(
                    f"{lab} "
                    f"{stats['medians'][i]} "
                    f"{stats['means'][i]} "
                    f"{stats['stds'][i]} "
                    f"{stats['original_weights'][i]} "
                    f"{suggested_weights[i]}\n"
                )

class HSIC:
    """Implementation of the Hilbert-Schmidt Independence Criterion for global sensitivity analysis.
    Estimates the HSIC via the unbiased estimator of Song et al (2012).

    :Parameters:

    x
        :math:`n \\times d` matrix of :math:`d`-dimensional vectors in the input space. These are samples of the
        variables you would like to know the sensitivity for.
    y
        :math:`n` length vector of outputs corresponding to the input samples in x. These are the responses of some
        function against which the sensitivity is measured. Must be strictly positive (a logarithm is taken).
    labels
        Optional list of factor names.
    x_bounds
        :math:`d \\times 2` matrix of min-max pairs for every factor in x to scale the values between 0 and 1.
        Defaults to 'auto' which takes the limits as the min and max values of the sample data.
        To use no scaling set this to None.

    :Notes:

    y is automatically scaled by taking the logarithm of the values and then scaling by the minimum and maximum to a
    range between zero and one. This makes selecting kernel parameters easier, and deals with order-of-magnitude
    problems which often arise during reparameterization.

    :References:

    Song, L., Smola, A., Gretton, A., Bedo, J., & Borgwardt, K. (2012). Feature Selection via Dependence Maximization.
    Journal of Machine Learning Research, 13, 1393–1434. https://doi.org/10.5555/2188385.2343691

    Gretton, A., Bousquet, O., Smola, A., & Schölkopf, B. (2005). Measuring Statistical Dependence with Hilbert-Schmidt
    Norms. In Lecture Notes in Computer Science (including subseries Lecture Notes in Artificial Intelligence and
    Lecture Notes in Bioinformatics): Vol. 3734 LNAI (Issue 140, pp. 63–77). Springer, Berlin, Heidelberg.
    https://doi.org/10.1007/11564089_7

    Gretton, A., Borgwardt, K. M., Rasch, M. J., Smola, A., Schölkopf, B., Smola GRETTON, A., & Smola, A. (2012).
    A kernel two-sample test. Journal of Machine Learning Research, 13(25), 723–773.
    http://jmlr.org/papers/v13/gretton12a.html

    Spagnol, A., Riche, R. Le, & Veiga, S. Da. (2019). Global Sensitivity Analysis for Optimization with Variable
    Selection. SIAM/ASA Journal on Uncertainty Quantification, 7(2), 417–443. https://doi.org/10.1137/18M1167978

    Da Veiga, S. (2015). Global sensitivity analysis with dependence measures. Journal of Statistical Computation and
    Simulation, 85(7), 1283–1305. https://doi.org/10.1080/00949655.2014.945932
    """

    def __init__(
        self,
        x: np.ndarray,
        y: np.ndarray,
        labels: Optional[Sequence[str]] = None,
        x_bounds: Union[str, np.ndarray, None] = "auto",
    ):
        self._x_raw = x.copy()
        # isinstance check first: comparing an ndarray to "auto" would be elementwise, not a single bool
        if isinstance(x_bounds, str) and x_bounds == "auto":
            x_min, x_max = x.min(0), x.max(0)
            self.x = (x - x_min) / (x_max - x_min)
        elif x_bounds is not None:
            bounds = np.array(x_bounds)
            self.x = (x - bounds[:, 0]) / (bounds[:, 1] - bounds[:, 0])
        else:
            self.x = x.copy()

        # Log-scale then min-max scale y to [0, 1]; the limits are kept to scale reweight losses identically
        self._y_raw = y.copy()
        self.y = np.log(y)
        self._y_min = self.y.min()
        self._y_max = self.y.max()
        self.y = (self.y - self._y_min) / (self._y_max - self._y_min)

        self.n, self.d = x.shape
        self.labels = labels

    def compute(
        self,
        inputs_kernel: Optional[BaseKernel] = None,
        outputs_kernel: Optional[BaseKernel] = None,
        n_bootstraps: int = 1,
        n_sample: int = -1,
        replace: bool = False,
    ) -> HSICResult:
        """Calculates the HSIC for each input factor.
        The larger the HSIC value for a parameter, the more sensitive y is to changes in the corresponding parameter
        in x.

        :Parameters:

        inputs_kernel
            Instance of :class:`.BaseKernel` which will be applied to x. Defaults to :class:`.GaussianKernel`.
        outputs_kernel
            Instance of :class:`.BaseKernel` which will be applied to y. Defaults to
            :class:`.ConjunctiveGaussianKernel`.
        n_bootstraps
            Number of repeats of the calculation with different sub-samples from the data set. A small spread from a
            large number of bootstraps provides confidence on the estimation of the sensitivity.
        n_sample
            Number of vectors in x to use in the calculation. Defaults to -1 which uses all available points.
            At least 4 samples are required by the estimator.
        replace
            If True, samples from x will be done with replacement and vice versa. This only has an effect
            if n_sample is less than n otherwise replace is True by necessity.

        :Returns:

        :class:`.HSICResult`
            Object containing the results of the HSIC calculation and the calculation settings.

        :Raises:

        ValueError
            If n_bootstraps is smaller than 1 or fewer than 4 samples would be used.
        """
        return self._compute(
            residuals=None,
            targets=None,
            error_weights=None,
            error_sigma=None,
            inputs_kernel=inputs_kernel,
            outputs_kernel=outputs_kernel,
            n_bootstraps=n_bootstraps,
            n_sample=n_sample,
            replace=replace,
            max_cache_size=None,
            run_reweight=False,
        )

    def compute_with_reweight(
        self,
        residuals: np.ndarray,
        targets: np.ndarray,
        error_weights: np.ndarray,
        error_sigma: np.ndarray,
        inputs_kernel: Optional[BaseKernel] = None,
        outputs_kernel: Optional[BaseKernel] = None,
        n_bootstraps: int = 1,
        n_sample: int = -1,
        replace: bool = False,
        max_cache_size: Optional[int] = None,
    ) -> HSICResultWithReweight:
        """Computes the sensitivity of the HSIC to the weights applied to the construction of an error function.
        This calculation is only applicable in very particular conditions:

        #. :math:`X \\in \\mathbb{R}^{n \\times d}` represents the inputs to a function which produces :math:`m`
           outputs for each :math:`d` length input vector of :math:`X` resampled :math:`n` times.
        #. These outputs are the :math:`n \\times m` predictions matrix (:math:`P`).
        #. The predictions matrix can be condensed to the :math:`n` length vector y by an 'error'
           function:

           .. math::

               \\mathbf{y} = \\sum^m_i w_i \\left(\\frac{P_i - r_i}{\\sigma_i}\\right)^2

        #. :math:`\\mathbf{r}`, :math:`\\mathbf{w}` and :math:`\\mathbf{\\sigma}` are :math:`m` length vectors of
           reference, error_weights, and error_sigma values respectively.
        #. This function returns :math:`\\frac{dg}{d\\mathbf{w}}` which is a measure of how much changes in
           :math:`\\mathbf{w}` (the 'weights' in the error function) will affect :math:`g` which is a measure of how
           close HSIC values are to our target HSIC sensitivities.
        #. :math:`g` is defined such that :math:`g \\geq 0`. :math:`g = 0` implies the sensitivities are perfect.

        :Parameters:

        residuals
            :math:`n \\times m` matrix of error values between predictions and references.
        targets
            :math:`d` length boolean vector. If an element is True, one would like the corresponding parameter to
            show sensitivity. Can also send a vector of real values such that sum of elements is 1. This allows for
            custom sensitivities to be targeted for each parameter.
        error_weights
            :math:`m` length vector of error function 'weights' for which the sensitivity will be measured.
        error_sigma
            :math:`m` length vector of error function 'standard error' values.
        inputs_kernel
            See :meth:`compute`. But only one value is allowed in this function.
        outputs_kernel
            See :meth:`compute`.
        n_bootstraps
            Number of repeats of the calculation with different sub-samples from the data set. A small spread from a
            large number of bootstraps provides confidence on the estimation of the sensitivity.
        n_sample
            Number of vectors in x to use in the calculation. Defaults to -1 which uses all available points.
        replace
            If True, samples from x will be done with replacement and vice versa. This only has an effect
            if n_sample is less than n otherwise replace is True by necessity.
        max_cache_size
            Maximum amount of disk space (in bytes) the program may use to store matrices and speed-up the calculation.
            Defaults to the maximum size of the temporary directory on the system.

        :Returns:

        :class:`.HSICResultWithReweight`
            Object containing the results of the HSIC calculation and the calculation settings.

        .. warning::

            Efforts have been made to reduce the memory footprint of the calculation, but it can become very large,
            very quickly.

            This calculation is also significantly slower than the normal sensitivity calculation.
        """
        return self._compute(
            residuals=residuals,
            targets=targets,
            error_weights=error_weights,
            error_sigma=error_sigma,
            inputs_kernel=inputs_kernel,
            outputs_kernel=outputs_kernel,
            n_bootstraps=n_bootstraps,
            n_sample=n_sample,
            replace=replace,
            max_cache_size=max_cache_size,
            run_reweight=True,
        )

    @staticmethod
    def _unbiased_hsic(K: np.ndarray, K_sum: float, L: np.ndarray, L_sum: float, n: int) -> float:
        """Unbiased HSIC estimator of Song et al. (2012) for two n x n kernel matrices and their element sums.

        The estimator is linear in L, so passing the derivative of L yields the derivative of the HSIC.
        """
        coeff_pre = 1 / (n * (n - 3))
        coeff_t2 = 1 / ((n - 1) * (n - 2))
        coeff_t3 = -2 / (n - 2)

        KL = K @ L
        t1 = np.trace(KL)  # Term 1: tr(KL)
        t2 = K_sum * L_sum  # Term 2: 1K11L1
        t3 = np.sum(KL)  # Term 3: 1KL1
        return coeff_pre * (t1 + coeff_t2 * t2 + coeff_t3 * t3)

    @staticmethod
    def _symmetric_from_triu(triu: np.ndarray, triu_ids: Tuple[np.ndarray, np.ndarray], n: int) -> np.ndarray:
        """Rebuild a symmetric n x n matrix with a zero diagonal from its cached strict upper triangle."""
        K = np.zeros((n, n))
        K[triu_ids] = triu
        K = K.T  # transposed view: writing the upper triangle again fills the original lower triangle
        K[triu_ids] = triu
        return K

    def _compute(
        self,
        residuals: Optional[np.ndarray],
        targets: Optional[np.ndarray],
        error_weights: Optional[np.ndarray],
        error_sigma: Optional[np.ndarray],
        inputs_kernel: Optional[BaseKernel],
        outputs_kernel: Optional[BaseKernel],
        n_bootstraps: int,
        n_sample: int,
        replace: bool,
        max_cache_size: Optional[int],
        run_reweight: bool,
    ) -> Union[HSICResult, HSICResultWithReweight]:
        """Shared implementation of :meth:`compute` and :meth:`compute_with_reweight`.

        When run_reweight is False the reweight-specific arguments are ignored and the loop over error-function
        items runs once (m = 1).
        """
        # Process inputs
        if n_bootstraps < 1:
            raise ValueError(f"n_bootstraps must be at least 1, got {n_bootstraps}.")
        if n_sample > self.n:
            replace = True  # More points than available can only be drawn with replacement
        n_sample = n_sample if n_sample > 0 else self.n
        if n_sample < 4:
            # The estimator coefficients divide by (n - 1), (n - 2) and n(n - 3)
            raise ValueError(f"The unbiased HSIC estimator requires at least 4 samples, got {n_sample}.")

        if inputs_kernel is None:
            inputs_kernel = GaussianKernel()
        if outputs_kernel is None:
            outputs_kernel = ConjunctiveGaussianKernel()

        # Row indices of the sub-sample used in each bootstrap
        all_x = np.arange(self.n)
        boot_ids = np.array([np.random.choice(all_x, n_sample, replace=replace) for _ in range(n_bootstraps)])

        d = self.d
        triu_ids = np.triu_indices(n_sample, 1)  # strict upper triangle; the diagonal is not cached
        n_triu = n_sample * (n_sample - 1) // 2

        # Pre-allocation
        hsic = np.zeros((n_bootstraps, d))

        t = None
        w = None
        if run_reweight:
            residuals = np.array(residuals)
            targets = np.array(targets)
            w = np.array(error_weights)
            sigma = np.array(error_sigma)
            m = residuals.shape[1]

            # Setup targets
            if targets.dtype == bool:
                t = np.zeros(d)
                t[targets] = 1 / targets.sum()
            else:
                t = targets

            dhsic = np.zeros((n_bootstraps, d, m))

            # Error Function
            df = (residuals / sigma[None]) ** 2  # n x m; also df/dw_j of the loss
            f = (w * df).sum(1)  # n

            if not np.allclose(f, self._y_raw):
                warnings.warn(
                    "The total loss calculated from 'residuals' does not match the total loss given. "
                    "Please check the data and verify that a SSE loss function was used."
                )

            # Scale f to f_bar
            # (This is essential for the conjunctive-Gaussian kernel which relies on scaled data)
            f_bar = np.log(f)
            f_bar = (f_bar - self._y_min) / (self._y_max - self._y_min)

            # d f_bar / d w_j = df_j / (f * (y_max - y_min))
            df_bar = 1 / (self._y_max - self._y_min) * (1 / f[:, None]) * df
        else:
            f_bar = self.y
            m = 1

        # Setup Progress Bar
        pbar = tqdm(
            file=sys.stdout,
            bar_format="[ {percentage:3.0f}%|{bar:8}] {n_fmt}/{total_fmt}, {elapsed}<{remaining}, {rate_fmt}{postfix}",
            total=n_bootstraps * (m * d + int(run_reweight)),  # Only cache if reweight calculation is run
            unit_scale=True,
        )

        try:
            for b in range(n_bootstraps):
                x = self.x[boot_ids[b]]
                y = f_bar[boot_ids[b]]

                poststr = f"Boot {b + 1:03} / {n_bootstraps:03} "

                # Output kernel
                L = outputs_kernel(y)
                L_sum = L.sum()

                cache = "none"
                K_cache = None
                tmpdir = None
                if run_reweight:
                    # Caching only needed for reweight calculation, where every input kernel is reused m times.
                    # d << m (usually there are far fewer dimensions than training set items)
                    # Try to hold the kernels in memory, else temporarily on disk, else recalculate them each time.
                    est_cache_space = d * n_triu * 8
                    buffer_space = 10 * n_sample * n_sample * 8  # Reserve some memory for later calculation
                    mem_avail = psutil.virtual_memory().available

                    if mem_avail > (est_cache_space + buffer_space):
                        K_cache = np.empty((d, n_triu), dtype=float)
                        pbar.set_postfix_str(poststr + "Caching to memory", refresh=True)
                        for i in range(d):
                            K_cache[i] = inputs_kernel(x[:, i])[triu_ids]  # Upper tri only: matrices are symmetric
                        cache = "memory"
                    elif max_cache_size is not None and est_cache_space > max_cache_size:
                        pbar.set_postfix_str(poststr + "Unable to cache results")
                    else:
                        tmpdir = tf.TemporaryDirectory()
                        try:
                            pbar.set_postfix_str(poststr + "Caching to disk", refresh=True)
                            for i in range(d):
                                np.save(Path(tmpdir.name) / f"{i}.npy", inputs_kernel(x[:, i])[triu_ids])
                            cache = "disk"
                        except OSError:
                            # e.g. the temporary directory is full
                            tmpdir.cleanup()
                            tmpdir = None
                            pbar.set_postfix_str(poststr + "Unable to cache results")
                    pbar.update()

                try:
                    pbar.set_postfix_str(poststr + "HSIC Calc", refresh=True)

                    # Loop over training set items
                    for j in range(m):  # m == 1 if reweight is not run
                        if run_reweight:
                            # Placed outside the factor loop to avoid unneeded recalculations
                            dL = outputs_kernel(y, df_bar[boot_ids[b], j])
                            dL_sum = dL.sum()

                        for i in range(d):
                            if cache == "none":
                                K = inputs_kernel(x[:, i])
                            else:
                                # NOTE(review): cached matrices get a zero diagonal; this assumes the kernels
                                # already return zero-diagonal matrices (as the unbiased estimator requires),
                                # otherwise cached and uncached results would differ — TODO confirm.
                                if cache == "memory":
                                    K_triu = K_cache[i]
                                else:
                                    K_triu = np.load(Path(tmpdir.name) / f"{i}.npy")
                                K = self._symmetric_from_triu(K_triu, triu_ids, n_sample)
                            K_sum = K.sum()

                            # The HSIC does not depend on j; it is done here so K is only extracted once
                            if j == 0:
                                hsic[b, i] = self._unbiased_hsic(K, K_sum, L, L_sum, n_sample)

                            if run_reweight:
                                dhsic[b, i, j] = self._unbiased_hsic(K, K_sum, dL, dL_sum, n_sample)

                            del K
                            pbar.update()
                finally:
                    # Close cache if used
                    if tmpdir is not None:
                        tmpdir.cleanup()
        finally:
            # Close Progress Bar Nicely
            pbar.refresh()
            pbar.close()

        # Calculate S
        hsic_raw = hsic
        hsic = np.clip(hsic_raw, 0, None)
        hsic_tot = hsic.sum(1)
        s = hsic / hsic_tot[:, None]

        g = None
        dg = None
        if run_reweight:
            # Clipped (non-positive) estimates do not change with the weights
            dhsic = np.where(hsic_raw[:, :, None] > 0, dhsic, 0.0)
            # Quotient rule for s_i = h_i / sum_k h_k: ds_i = (H dh_i - h_i sum_k dh_k) / H^2
            dhsic_tot = dhsic.sum(1, keepdims=True)
            ds = (hsic_tot[:, None, None] * dhsic - hsic[:, :, None] * dhsic_tot) / (hsic_tot[:, None, None] ** 2)
            core = s - t
            g = np.sum(core**2, 1)
            dg = np.sum(2 * ds * core[:, :, None], 1)

        return HSICResult._from_compute(
            hsic=hsic,
            input_kernel=inputs_kernel,
            output_kernel=outputs_kernel,
            n_bootstraps=n_bootstraps,
            n_sample=n_sample,
            replace=replace,
            has_reweight=run_reweight,
            targets=t,
            g=g,
            dg=dg,
            weights=w,
            labels=self.labels,
        )