# Source code for contextualized.easy.ContextualGAM
"""
Contextual Generalized Additive Model.
See https://www.sciencedirect.com/science/article/pii/S1532046422001022
for more details.
"""
from contextualized.easy import ContextualizedClassifier, ContextualizedRegressor
class ContextualGAMClassifier(ContextualizedClassifier):
    """
    Contextual Generalized Additive Model with a classifier on top.

    The Contextual GAM Classifier separates and interprets the effect of context in context-varying decisions and classifiers, such as heterogeneous disease diagnoses.
    Always uses a Neural Additive Model ("ngam") encoder for interpretability; any ``encoder_type`` passed by the caller is overwritten with "ngam".
    See `this paper <https://www.sciencedirect.com/science/article/pii/S1532046422001022>`__
    for more details.

    All keyword arguments are forwarded unchanged (apart from ``encoder_type``) to ``ContextualizedClassifier``.

    Args:
        n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
        num_archetypes (int, optional): Number of archetypes to use. Defaults to 0, which uses the NaiveMetaModel. If > 0, uses archetypes in the ContextualizedMetaModel.
        alpha (float, optional): Regularization strength. Defaults to 0.0.
        mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0.
        l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0.
    """

    def __init__(self, **kwargs):
        # Force the additive encoder: this class is only meaningful as a GAM,
        # so a caller-supplied encoder_type is deliberately replaced.
        kwargs["encoder_type"] = "ngam"
        super().__init__(**kwargs)
class ContextualGAMRegressor(ContextualizedRegressor):
    """
    Contextual Generalized Additive Model with a linear regressor on top.

    The Contextual GAM Regressor separates and interprets the effect of context in context-varying relationships, such as heterogeneous treatment effects.
    Always uses a Neural Additive Model ("ngam") encoder for interpretability; any ``encoder_type`` passed by the caller is overwritten with "ngam".
    See `this paper <https://www.sciencedirect.com/science/article/pii/S1532046422001022>`__
    for more details.

    All keyword arguments are forwarded unchanged (apart from ``encoder_type``) to ``ContextualizedRegressor``.

    Args:
        n_bootstraps (int, optional): Number of bootstraps to use. Defaults to 1.
        num_archetypes (int, optional): Number of archetypes to use. Defaults to 0, which uses the NaiveMetaModel. If > 0, uses archetypes in the ContextualizedMetaModel.
        alpha (float, optional): Regularization strength. Defaults to 0.0.
        mu_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization applies to context-specific parameters or context-specific offsets. Defaults to 0.0.
        l1_ratio (float, optional): Float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Defaults to 0.0.
    """

    def __init__(self, **kwargs):
        # Force the additive encoder: this class is only meaningful as a GAM,
        # so a caller-supplied encoder_type is deliberately replaced.
        kwargs["encoder_type"] = "ngam"
        super().__init__(**kwargs)