Fitting and Analyzing Contextualized Models

Let’s walk through an example of a Contextualized analysis, using the scikit-learn diabetes dataset.

Download and Prepare Data

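First, we load the data into a standard pandas DataFrame (a NumPy array also works) and create a train/test split. Only one preprocessing step is required: deciding which variables serve as context. Here we use age, sex, and bmi as the context C, the remaining features as the predictors X, and disease progression as the outcome Y.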

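import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_diabetes

# Load the features X (DataFrame) and the disease-progression target Y (Series).
X, Y = load_diabetes(return_X_y=True, as_frame=True)
# Reshape the target into a column vector of shape (n_samples, 1).
Y = np.expand_dims(Y.values, axis=-1)

# Use age, sex and bmi as context; the remaining features become predictors.
context_cols = ['age', 'sex', 'bmi']
C = X[context_cols]
X = X.drop(columns=context_cols)

# Split C, X and Y with the same shuffled row order.
seed = 1
C_train, C_test, X_train, X_test, Y_train, Y_test = train_test_split(
    C, X, Y, test_size=0.20, random_state=seed)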
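The split keeps the rows of C, X, and Y aligned, so row i of each array describes the same patient. As a minimal sketch of that layout, assuming a toy NumPy feature matrix in place of the diabetes data, the context/predictor partition and a seeded 80/20 split look like this. The helpers split_context and aligned_split are hypothetical, and aligned_split does not reproduce the exact row order that train_test_split produces.

```python
import numpy as np


def split_context(features, context_idx):
    """Partition a feature matrix into context columns C and predictor columns X."""
    context_idx = list(context_idx)
    predictor_idx = [j for j in range(features.shape[1]) if j not in context_idx]
    return features[:, context_idx], features[:, predictor_idx]


def aligned_split(C, X, Y, test_size=0.2, seed=1):
    """Shuffle rows once and apply the same permutation to C, X and Y."""
    n = C.shape[0]
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)
    n_test = int(np.ceil(test_size * n))
    test, train = perm[:n_test], perm[n_test:]
    return C[train], C[test], X[train], X[test], Y[train], Y[test]


# Toy data (not the diabetes data): 10 samples, 5 features, 1 outcome.
rng = np.random.default_rng(0)
features = rng.normal(size=(10, 5))
Y = rng.normal(size=(10, 1))

# Treat the first two columns as context, the rest as predictors.
C, X = split_context(features, context_idx=[0, 1])
C_train, C_test, X_train, X_test, Y_train, Y_test = aligned_split(C, X, Y)
print(C_train.shape, X_train.shape, Y_test.shape)  # (8, 2) (8, 3) (2, 1)
```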

Train a Contextualized Model

Contextualized models follow an sklearn-like interface to make training easy.

from contextualized.easy import ContextualizedRegressor

# Train 3 bootstrapped models; encoder_type="mlp" uses an MLP to map context to model parameters.
model = ContextualizedRegressor(n_bootstraps=3)
model.fit(C_train.values, X_train.values, Y_train,
          encoder_type="mlp", max_epochs=20,
          learning_rate=1e-3)
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/opt/homebrew/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:1789: UserWarning: MPS available but not used. Set `accelerator` and `devices` using `Trainer(accelerator='mps', devices=1)`.
  rank_zero_warn(
/opt/homebrew/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:616: UserWarning: Checkpoint directory /Users/blengerich/Dropbox/Professional/Research/Libraries/Contextualized/docs/lightning_logs/boot_0_checkpoints exists and is not empty.
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name      | Type           | Params
---------------------------------------------
0 | metamodel | NaiveMetamodel | 958   
---------------------------------------------
958       Trainable params
0         Non-trainable params
958       Total params
0.004     Total estimated model params size (MB)
/opt/homebrew/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:236: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 10 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
/opt/homebrew/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:236: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 10 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=20` reached.
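Setting n_bootstraps=3 trains three models. The classic bootstrap fits each model on rows resampled with replacement. The sketch below shows that resampling in plain NumPy. It illustrates the general idea only and is an assumption, not a description of how ContextualizedRegressor draws its samples internally; bootstrap_indices is a hypothetical helper.

```python
import numpy as np


def bootstrap_indices(n_samples, n_bootstraps, seed=0):
    """Draw one array of row indices per bootstrap, sampling with replacement."""
    rng = np.random.default_rng(seed)
    return [rng.integers(0, n_samples, size=n_samples) for _ in range(n_bootstraps)]


n_samples = 100  # toy sample count
for b, idx in enumerate(bootstrap_indices(n_samples, n_bootstraps=3)):
    # Each resample has n_samples rows, but some rows appear more than once.
    n_unique = len(np.unique(idx))
    print(f"bootstrap {b}: {n_unique} unique rows out of {n_samples}")
```

On average about 63% of the rows (1 - 1/e) appear in each resample, so the bootstrap models see somewhat different data, and their disagreement reflects estimation uncertainty.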

Save/load the trained model.

from contextualized.utils import save, load

# Write the fitted model to disk, then reload it from the same file.
save_path = './easy_demo_model.pt'
save(model, path=save_path)
model = load(save_path)

Inspect the model predictions.

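We can use standard plotting tools such as matplotlib to inspect the model predictions. Note that this plot covers the full dataset, so it mixes training and test samples.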

import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams.update({'font.size': 18})

# Predicted vs. true outcome for every sample (train and test combined).
plt.scatter(Y[:, 0], model.predict(C.values, X.values)[:, 0])
plt.xlabel("True Value")
plt.ylabel("Predicted Value")
plt.show()
/opt/homebrew/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/prediction_epoch_loop.py:174: UserWarning: Lightning couldn't infer the indices fetched for your dataloader.
  warning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")
[Figure: scatter plot of predictions, x-axis "True Value", y-axis "Predicted Value"]
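The scatter plot gives a visual impression, but it includes the rows the model was trained on. To put a number on held-out performance, one option is to score only the test split. Below is a minimal NumPy sketch of mean squared error and R², assuming model.predict returns an array of shape (n_samples, n_outputs), as the plotting code indexes it. The mse and r2 helpers are hypothetical; for a single output, sklearn.metrics.mean_squared_error and sklearn.metrics.r2_score compute the same quantities.

```python
import numpy as np


def mse(y_true, y_pred):
    """Mean squared error over all samples and outputs."""
    y_true, y_pred = np.asarray(y_true, dtype=float), np.asarray(y_pred, dtype=float)
    return float(np.mean((y_true - y_pred) ** 2))


def r2(y_true, y_pred):
    """Coefficient of determination, 1 - SS_res / SS_tot (pooled over outputs)."""
    y_true, y_pred = np.asarray(y_true, dtype=float), np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)


# With the fitted model from this tutorial (not executed here):
# Y_test_pred = model.predict(C_test.values, X_test.values)
# print(mse(Y_test, Y_test_pred), r2(Y_test, Y_test_pred))

# Toy check on hand-made values.
y_true = np.array([[1.0], [2.0], [3.0], [4.0]])
y_pred = np.array([[1.5], [2.0], [2.5], [4.0]])
print(mse(y_true, y_pred), r2(y_true, y_pred))  # 0.125 0.9
```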

Check what the individual bootstrap models learned.

Since we trained with bootstrapping for robustness, we can also access the individual bootstrap runs with the individual_preds keyword, for example to estimate confidence intervals.

# individual_preds=True returns one set of predictions per bootstrap model.
model_preds = model.predict(C.values, X.values, individual_preds=True)
for i, pred in enumerate(model_preds):
    plt.scatter(Y[:, 0], pred[:, 0], label='Bootstrap {}'.format(i))
plt.xlabel("True Value")
plt.ylabel("Predicted Value")
plt.legend()
plt.show()
[Figure: scatter plot of each bootstrap's predictions, x-axis "True Value", y-axis "Predicted Value", legend Bootstrap 0 to 2]
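One way to turn the individual bootstrap predictions into an uncertainty band is to take per-sample percentiles across the bootstrap axis. This sketch assumes model_preds has shape (n_bootstraps, n_samples, n_outputs), consistent with how the plotting loop iterates over it. With only three bootstraps such intervals are very coarse, so more bootstraps would be needed for meaningful coverage. The bootstrap_interval helper and the toy array are hypothetical.

```python
import numpy as np


def bootstrap_interval(preds, alpha=0.05):
    """Per-sample mean and (alpha/2, 1 - alpha/2) percentile band across bootstraps.

    preds: array of shape (n_bootstraps, n_samples, n_outputs).
    """
    preds = np.asarray(preds, dtype=float)
    mean = preds.mean(axis=0)
    lower = np.percentile(preds, 100 * alpha / 2, axis=0)
    upper = np.percentile(preds, 100 * (1 - alpha / 2), axis=0)
    return mean, lower, upper


# Toy stand-in for model.predict(C.values, X.values, individual_preds=True):
# 3 bootstraps, 2 samples, 1 output.
toy_preds = np.array([
    [[1.0], [2.0]],
    [[2.0], [2.0]],
    [[3.0], [5.0]],
])
mean, lower, upper = bootstrap_interval(toy_preds)
print(mean[:, 0], lower[:, 0], upper[:, 0])
```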

Check what contextualized parameters the models learned.

betas, mus = model.predict_params(C.values, individual_preds=False)

# To get a sense of how the models are clustered, let's embed them in a 2-D space and visualize them.

# Betas have shape (n_samples, n_outputs, n_predictors).
# Squeeze out the n_outputs axis, since this example has only 1 output.
betas = np.squeeze(betas)

from umap import UMAP  # provided by the umap-learn package
um = UMAP(n_neighbors=5)
model_reps = um.fit_transform(betas)

from contextualized.analysis.embeddings import plot_embedding_for_all_covars
plot_embedding_for_all_covars(model_reps, C, xlabel='Model UMAP 1', ylabel='Model UMAP 2')
[Figures: 2-D UMAP embeddings of the per-sample models, axes "Model UMAP 1" and "Model UMAP 2", one plot per context covariate]
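Each row of betas is a sample-specific coefficient vector, so a contextualized linear regression predicts y_i = beta(c_i) · x_i + mu(c_i). The sketch below spells out that relationship on toy arrays. It assumes betas has shape (n_samples, n_outputs, n_predictors), as the comment in the code states, that mus has shape (n_samples, n_outputs), and that regression uses an identity link. Under those assumptions, predict_from_params(betas, mus, X.values) applied to the unsqueezed betas should be close to model.predict(C.values, X.values). The predict_from_params helper is hypothetical.

```python
import numpy as np


def predict_from_params(betas, mus, X):
    """Apply per-sample linear models: y[i, k] = sum_j betas[i, k, j] * X[i, j] + mus[i, k].

    betas: (n_samples, n_outputs, n_predictors)
    mus:   (n_samples, n_outputs)
    X:     (n_samples, n_predictors)
    """
    return np.einsum("ikj,ij->ik", betas, X) + mus


# Two samples with different contexts, hence different coefficients.
betas = np.array([[[1.0, 0.0]],    # sample 0: y = 1 * x0 + mu
                  [[0.0, 2.0]]])   # sample 1: y = 2 * x1 + mu
mus = np.array([[0.5], [-1.0]])
X = np.array([[3.0, 4.0],
              [3.0, 4.0]])         # identical predictors for both samples
print(predict_from_params(betas, mus, X))
```

Both samples have the same predictors yet receive different predictions (3.5 and 7.0), because their coefficients differ by context. This per-sample variation in betas is what the UMAP embedding visualizes.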

What are the homogeneous predictors?

from contextualized.analysis.effects import (
    plot_homogeneous_context_effects,
    plot_homogeneous_predictor_effects,
    plot_heterogeneous_predictor_effects,
)

plot_homogeneous_context_effects(
    model, C,
    classification=False, verbose=False,
    ylabel="Diabetes Progression")

plot_homogeneous_predictor_effects(
    model, C, X,
    ylabel="Diabetes Progression",
    classification=False)
[Figures: homogeneous context effects and homogeneous predictor effects on "Diabetes Progression" (10 plots)]

What are the heterogeneous predictors (effects change based on context)?

plot_heterogeneous_predictor_effects(model, C, X, min_effect_size=10,
                                     ylabel="Influence of")
[Figures: heterogeneous predictor effects, y-axis labels beginning "Influence of" (7 plots)]
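As a rough numerical complement to these plots, the spread of each predictor's coefficient across samples indicates how strongly its effect depends on context. The following is a simple heuristic, not the method plot_heterogeneous_predictor_effects uses. It assumes betas of shape (n_samples, n_predictors), i.e. after squeezing out the output axis as in the UMAP step, and ranks predictors by the standard deviation of their coefficients. The rank_heterogeneity helper and the toy values are hypothetical.

```python
import numpy as np


def rank_heterogeneity(betas, names):
    """Return (name, mean coefficient, std across samples), most context-dependent first."""
    betas = np.asarray(betas, dtype=float)
    means = betas.mean(axis=0)   # average effect across all contexts
    stds = betas.std(axis=0)     # how much the effect varies with context
    order = np.argsort(-stds)
    return [(names[j], float(means[j]), float(stds[j])) for j in order]


# Toy coefficients for 4 samples and 3 predictors.
toy_betas = np.array([
    [1.0, 0.0, 2.0],
    [1.0, 1.0, -2.0],
    [1.0, 2.0, 2.0],
    [1.0, 3.0, -2.0],
])
for name, mean, std in rank_heterogeneity(toy_betas, ["a", "b", "c"]):
    print(f"{name}: mean={mean:+.2f}, std={std:.2f}")
```

In the toy output, predictor "a" has the same effect in every context (std 0), while "c" averages to zero but swings strongly with context. With the tutorial's model, one could call rank_heterogeneity(betas, list(X.columns)) on the squeezed betas. Coefficient spreads are only comparable across predictors measured on similar scales.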