Source code for contextualized.easy.ContextualGAM

"""
Contextual Generalized Additive Model.
See https://www.sciencedirect.com/science/article/pii/S1532046422001022
for more details.
"""

import warnings

from contextualized.easy import ContextualizedClassifier, ContextualizedRegressor


class ContextualGAMClassifier(ContextualizedClassifier):
    """
    The Contextual GAM Classifier separates and interprets the effect of context
    in context-varying decisions and classifiers, such as heterogeneous disease
    diagnoses. Implemented as a Contextual Generalized Additive Model with a
    classifier on top. Always uses a Neural Additive Model ("ngam") encoder for
    interpretability.

    See `this paper <https://www.sciencedirect.com/science/article/pii/S1532046422001022>`__
    for more details.

    Args:
        n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
        num_archetypes (int, optional): Number of archetypes to use. Defaults to 0,
            which uses the NaiveMetaModel. If > 0, uses archetypes in the
            ContextualizedMetaModel.
        alpha (float, optional): Regularization strength. Defaults to 0.0.
        mu_ratio (float, optional): Float in range [0.0, 1.0], governs how much the
            regularization applies to context-specific parameters or
            context-specific offsets. Defaults to 0.0.
        l1_ratio (float, optional): Float in range [0.0, 1.0], governs how much the
            regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0.
        **kwargs: All keyword arguments are forwarded to ContextualizedClassifier.
            ``encoder_type`` is always forced to "ngam"; if a different value is
            supplied, a UserWarning is issued and that value is ignored.
    """

    def __init__(self, **kwargs):
        requested = kwargs.get("encoder_type", "ngam")
        if requested != "ngam":
            # The GAM's interpretability depends on the additive "ngam" encoder,
            # so we keep forcing it, but tell the caller rather than silently
            # discarding their explicit choice.
            warnings.warn(
                f"{type(self).__name__} always uses the 'ngam' encoder; "
                f"ignoring encoder_type={requested!r}.",
                UserWarning,
                stacklevel=2,
            )
        kwargs["encoder_type"] = "ngam"
        super().__init__(**kwargs)
class ContextualGAMRegressor(ContextualizedRegressor):
    """
    The Contextual GAM Regressor separates and interprets the effect of context
    in context-varying relationships, such as heterogeneous treatment effects.
    Implemented as a Contextual Generalized Additive Model with a linear
    regressor on top. Always uses a Neural Additive Model ("ngam") encoder for
    interpretability.

    See `this paper <https://www.sciencedirect.com/science/article/pii/S1532046422001022>`__
    for more details.

    Args:
        n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
        num_archetypes (int, optional): Number of archetypes to use. Defaults to 0,
            which uses the NaiveMetaModel. If > 0, uses archetypes in the
            ContextualizedMetaModel.
        alpha (float, optional): Regularization strength. Defaults to 0.0.
        mu_ratio (float, optional): Float in range [0.0, 1.0], governs how much the
            regularization applies to context-specific parameters or
            context-specific offsets. Defaults to 0.0.
        l1_ratio (float, optional): Float in range [0.0, 1.0], governs how much the
            regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0.
        **kwargs: All keyword arguments are forwarded to ContextualizedRegressor.
            ``encoder_type`` is always forced to "ngam"; if a different value is
            supplied, a UserWarning is issued and that value is ignored.
    """

    def __init__(self, **kwargs):
        requested = kwargs.get("encoder_type", "ngam")
        if requested != "ngam":
            # The GAM's interpretability depends on the additive "ngam" encoder,
            # so we keep forcing it, but tell the caller rather than silently
            # discarding their explicit choice.
            warnings.warn(
                f"{type(self).__name__} always uses the 'ngam' encoder; "
                f"ignoring encoder_type={requested!r}.",
                UserWarning,
                stacklevel=2,
            )
        kwargs["encoder_type"] = "ngam"
        super().__init__(**kwargs)