Source code for contextualized.analysis.bootstraps
import numpy as np
from contextualized.easy.wrappers import SKLearnWrapper
def select_good_bootstraps(
    sklearn_wrapper: "SKLearnWrapper", train_errs: np.ndarray, tol: float = 2
) -> "SKLearnWrapper":
    """
    Prune divergent or bad bootstraps, keeping only those whose mean training
    error is below tol * min(mean training errors).

    Bootstraps whose mean training error is not finite (NaN or +/-inf) are treated
    as divergent: they are always pruned and are ignored when computing the minimum.
    The bootstrap(s) achieving the minimum finite error are always kept, so a best
    error of zero (or tol <= 1) does not discard every bootstrap.

    Args:
        sklearn_wrapper (contextualized.easy.wrappers.SKLearnWrapper): Wrapper for the sklearn model.
            Its ``models`` and ``n_bootstraps`` attributes are updated in place.
        train_errs (np.ndarray): Training errors for each bootstrap, shaped
            (n_bootstraps, n_samples, n_outcomes) or (n_bootstraps, n_samples).
        tol (float): Only bootstraps with mean train_errs below tol * min(train_errs) are kept.

    Returns:
        contextualized.easy.wrappers.SKLearnWrapper: The input model with only selected bootstraps.
    """
    if train_errs.ndim == 2:
        # Add a singleton outcome axis so the reduction below is uniform.
        train_errs = train_errs[:, :, None]
    train_errs_by_bootstrap = np.mean(train_errs, axis=(1, 2))

    # A NaN/inf mean marks a divergent bootstrap. Letting it reach np.min would
    # poison the minimum (NaN) and make every comparison False, dropping all models.
    finite_mask = np.isfinite(train_errs_by_bootstrap)
    if finite_mask.any():
        train_errs_min = np.min(train_errs_by_bootstrap[finite_mask])
        threshold = train_errs_min * tol
        sklearn_wrapper.models = [
            model
            for train_err, is_finite, model in zip(
                train_errs_by_bootstrap, finite_mask, sklearn_wrapper.models
            )
            # The equality clause retains the best bootstrap even when
            # min * tol <= min (min == 0, negative min, or tol <= 1).
            if is_finite and (train_err < threshold or train_err == train_errs_min)
        ]
    else:
        # No bootstrap produced a finite error: nothing is worth keeping.
        sklearn_wrapper.models = []
    sklearn_wrapper.n_bootstraps = len(sklearn_wrapper.models)
    return sklearn_wrapper