Fitting Contextualized Models#

Here, we walk through an example of fitting a contextualized linear regression model to diabetes data.

Download and Prepare Data#

First, we will load the data into a standard pandas dataframe or a numpy array and create a train / test split. There's only one preprocessing step required: deciding which variables to use as context.

Note

Deciding context variables is an experiment-driven question.

Since contextualized models are typically interpreted in terms of their predictors, it's often helpful to choose interpretable variables as the predictors.

In this example, we will use age, sex, and BMI as the contexts to look for context-specific predictors of diabetes progression.

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_diabetes

# Load data
X, Y = load_diabetes(return_X_y=True, as_frame=True)
Y = np.expand_dims(Y.values, axis=-1)

# Decide context variables, or use other paired data as context
C = X[['age', 'sex', 'bmi']]
X.drop(['age', 'sex', 'bmi'], axis=1, inplace=True)

# Create a hold-out test set
seed = 1
C_train, C_test, X_train, X_test, Y_train, Y_test = train_test_split(C, X, Y, test_size=0.20, random_state=seed)

# Normalize the data
# Standardize each split using training-set statistics to avoid test-set leakage
def normalize(train, test):
    mean = train.mean()
    std = train.std()
    train = (train - mean) / std
    test = (test - mean) / std
    return train, test
X_train, X_test = normalize(X_train, X_test)
C_train, C_test = normalize(C_train, C_test)
Y_train, Y_test = normalize(Y_train, Y_test)
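
As a quick check, the standardized training features should now have near-zero mean and unit standard deviation (the test split will differ slightly, since it reuses the training-set moments):

# Training features should be ~0 mean and ~1 std after standardization
print(X_train.mean().round(3))
print(X_train.std().round(3))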

Train a Contextualized Model#

Contextualized models follow a scikit-learn-style interface to make fitting, predicting, and testing simple.

Note

Common constructor keywords for most models include:

  • n_bootstraps: The integer number of bootstrap resampling runs to fit. Bootstraps are useful for estimating uncertainty and reducing overfitting, and many contextualized.analysis tests use them to determine significance, but more bootstraps take longer to fit. Default is 1, which means no bootstrapping.

  • encoder_type: Which type of context encoder to use. Can be ‘mlp’ (multi-layer perceptron, i.e. vanilla neural network) or ‘ngam’ (a neural generalized additive model, i.e. a neural network with feature additivity constraints for interpretability). Default is ‘mlp’.

  • num_archetypes: Degrees of freedom for the context-specific model parameters, defined by a set of learnable model archetypes. Useful for reducing overfitting and learning key axes of context-dependent variation. Default is 0, which allows full degrees of freedom.

  • alpha: non-negative float, regularization strength. Default is 0.0, i.e. no regularization.

  • l1_ratio: float in range (0.0, 1.0), governs how much the regularization penalizes l1 vs l2 parameter norms. Default is 0.0, which means pure l2 regularization with alpha strength.

  • mu_ratio: float in range (0.0, 1.0), governs how much the regularization applies to context-specific model parameters versus context-specific offsets. Default is 0.0, which regularizes only the parameters.

Common fitting keywords include:

  • val_split: float in range (0.0, 1.0), the fraction of the training data to hold out as a validation set for gauging generalization. Default is 0.2.

  • max_epochs: positive integer, the maximum number of epochs (complete passes over the training data) to fit. Default is 1.

  • es_patience: positive integer, the number of epochs without validation improvement to tolerate before stopping early. Default is 1.

  • learning_rate: positive float, default is 1e-3.

Common predict keywords include:

  • individual_preds: Whether to return individual predictions for each bootstrap. Defaults to False, averaging across bootstraps.

Common predict_params keywords include:

  • individual_preds: Whether to return individual parameter predictions for each bootstrap. Defaults to False, averaging across bootstraps.

  • model_includes_mus: Whether the predicted context-specific model includes context-specific offsets (mu). Defaults to True.

Please see the API documentation for the specific model you are using for more details.
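
To make these concrete, here is a minimal sketch of how the keywords above compose; the specific values are illustrative assumptions, not tuned recommendations:

# Illustrative values only -- tune for your own data
from contextualized.easy import ContextualizedRegressor
sketch_model = ContextualizedRegressor(
    n_bootstraps=10,      # 10 bootstrap resampling runs for uncertainty estimates
    encoder_type='ngam',  # additive context encoder for interpretability
    num_archetypes=4,     # constrain context-specific models to 4 learnable archetypes
    alpha=1e-3,           # regularization strength
    l1_ratio=0.5,         # equal mix of l1 and l2 penalties
)
sketch_model.fit(
    C_train.values, X_train.values, Y_train,
    val_split=0.2,        # hold out 20% of training data for validation
    max_epochs=20,
    es_patience=5,        # stop after 5 epochs without validation improvement
    learning_rate=1e-3,
)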

%%capture
from contextualized.easy import ContextualizedRegressor
model = ContextualizedRegressor(n_bootstraps=20)  # Many bootstraps for later analysis
model.fit(C_train.values, X_train.values, Y_train,
          encoder_type="mlp", max_epochs=10,
          learning_rate=1e-2)

# Get predicted context-specific regression model parameters
contextualized_coeffs, contextualized_offsets = model.predict_params(C_test.values)

# Get the predicted outcomes using the context-specific regression models
Y_pred = model.predict(C_test.values, X_test.values)
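
The returned arrays contain one context-specific linear model per test sample. As a quick sanity check (the shape convention here is an assumption based on the single-outcome setup above; verify against the API docs):

# Expect roughly (n_samples, n_outcomes, n_predictors) coefficients
# and (n_samples, n_outcomes) offsets for this single-outcome regression
print(contextualized_coeffs.shape)
print(contextualized_offsets.shape)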

Test the Model#

from sklearn.metrics import mean_squared_error, r2_score
print(f'Mean-squared Error: {mean_squared_error(Y_test, Y_pred)}')
print(f'R^2: {r2_score(Y_test, Y_pred)}')
Mean-squared Error: 0.5375812218168962
R^2: 0.3872898219663774
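
Because the model was fit with 20 bootstraps, we can also gauge how stable this score is across bootstrap replicates. A minimal sketch, assuming individual_preds=True returns predictions with a leading bootstrap axis:

# Per-bootstrap test predictions; leading axis indexes the 20 bootstraps
boot_preds = model.predict(C_test.values, X_test.values, individual_preds=True)
boot_r2 = [r2_score(Y_test, p) for p in boot_preds]
print(f'R^2 across bootstraps: {np.mean(boot_r2):.3f} +/- {np.std(boot_r2):.3f}')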

Save the Trained Model for Follow-Up Analysis#

from contextualized.utils import save, load

save(model, path='my_contextualized_model.pt')
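
The trained model can then be restored with the load utility imported alongside save:

# Reload the trained model from disk
model = load('my_contextualized_model.pt')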

In the next step, we will analyze what this model has learned.