Analyzing Contextualized Models#

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_diabetes

# Load the diabetes data and carve off three covariates to serve as context.
X, Y = load_diabetes(return_X_y=True, as_frame=True)
Y = np.expand_dims(Y.values, axis=-1)  # make Y 2-D: (n_samples, 1)
context_cols = ['age', 'sex', 'bmi']
C = X[context_cols]
X = X.drop(context_cols, axis=1)

# Hold out 20% of the samples for testing, with a fixed seed for reproducibility.
seed = 1
C_train, C_test, X_train, X_test, Y_train, Y_test = train_test_split(
    C, X, Y, test_size=0.20, random_state=seed)

Save/load the trained model.#

from contextualized.utils import load

# Path the trained model was saved to in an earlier step.
save_path = './easy_demo_model.pt'
# Restore the fitted model from disk.
model = load(save_path)

Inspect the model predictions.#

We can use standard plotting tools to inspect the model predictions.

%%capture
predicted_probs = model.predict(C.values, X.values)[:, 0]
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams.update({'font.size': 18})

plt.scatter(Y[:, 0], predicted_probs)
plt.xlabel("True Value")
plt.ylabel("Predicted Value")
plt.show()
_images/bbbd874c689f1772afc4548ff403b350f459c01d6f9efe00a9494f6478701691.png

Check how the individual bootstrap models learned.#

Since we’re bootstrapping for robustness, we can also access individual bootstrap runs with the individual_preds keyword to get confidence intervals.

%%capture
model_preds = model.predict(C.values, X.values, individual_preds=True)
model_preds.shape  # (n_bootstraps, n_samples, n_outputs)
plt.plot(Y[:, 0], Y[:, 0], color='black', linestyle='--', label='Perfect')
plt.errorbar(Y[0, 0], np.mean(model_preds[:, 0, 0], axis=0),
                yerr=2*np.std(model_preds[:, 0, 0], axis=0), color='blue',
            label='Predicted')
for i in range(model_preds.shape[1]):
    plt.errorbar(Y[i, 0], np.mean(model_preds[:, i, 0], axis=0),
                yerr=2*np.std(model_preds[:, i, 0], axis=0), color='blue')
plt.xlabel("True Value")
plt.ylabel("Predicted Value")
plt.legend(fontsize=12)
plt.show()
_images/f5c4c580df417a5e0ad0e0f131c10a918de77c936ff55f1985408c09f09ecccd.png

Check what effects the models learned.#

# First, to get a sense of how the models are clustered, let's embed 
# the model parameters in a 2-D space and visualize them.
%%capture
betas, mus = model.predict_params(C.values, individual_preds=False)
betas.shape # (n_samples, n_outputs, n_predictors)
mus.shape  # (n_samples, n_outputs)

# Betas are shape:
# (n_samples, n_outputs, n_predictors)
# Let's squeeze out the n_outputs axis since we only have 1 output in this example.
betas = np.squeeze(betas)

# Any embedding method could be used; here we will use UMAP.
from umap import UMAP
um = UMAP(n_neighbors=5)
model_reps = um.fit_transform(betas)
# A simple helper function is provided in the analysis toolkit for plotting embeddings.
from contextualized.analysis.embeddings import plot_embedding_for_all_covars
plot_embedding_for_all_covars(model_reps, C, xlabel='Model UMAP 1', ylabel='Model UMAP 2')
_images/9f060503bab121e0547e2a756f46866f16303dea1afb5655a02541dd7f685bb4.png _images/7eccc273ec542b7a83544d36b3e1397999436db76b598d1316c5d7cc88ae1213.png _images/5e3e68202908bf146f36ef4c7dad0f09c7253454fd6a2b3f635523db3781ddb9.png

There are three types of effects in a contextualized model:

  • Homogeneous Context Effects:

    • \(f(C)\)

  • Homogeneous Predictor Effects:

    • \(\beta X\)

  • Heterogeneous Predictor Effects:

    • \(g(C)X\)

    • These are effects of predictors which are modulated by context

Helper analysis tools are provided to analyze each of these effects by:

  • getting the values of the effects for a domain of context

  • plotting the values of the effect for a domain of context

  • calculating p-values of the effect for a domain of context

from contextualized.analysis.effects import (
    get_homogeneous_context_effects,
    get_homogeneous_predictor_effects,
    get_heterogeneous_predictor_effects,
    plot_homogeneous_context_effects,
    plot_homogeneous_predictor_effects,
    plot_heterogeneous_predictor_effects,
)

from contextualized.analysis.pvals import (
    calc_homogeneous_context_effects_pvals,
    calc_homogeneous_predictor_effects_pvals,
    calc_heterogeneous_predictor_effects_pvals
)

What are the homogeneous predictors of diabetes progression?#

First, we can see the homogeneous effects of context.#

# Visualize the homogeneous context effects f(C): the direct effect of each
# context variable on the outcome.
plot_homogeneous_context_effects(
    model,
    C,
    classification=False,
    verbose=False,
    ylabel="Diabetes Progression",
)
/opt/homebrew/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, predict_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 10 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
_images/bb786ad6b28da3c11e63abc5969dde6a1bb685f336abc62bc3a11c4890e606c4.png _images/10e0ecf2f2725469da09de9e144192cb82998dba59ba0a501423378f0bb2c473.png _images/e69cca3cc0b6b542e7355d0709f70be9821ce540b5be2bd1b92791f0de47c27a.png
%%capture
# We can quantify the uncertainty visualized in the bootstrap confidence intervals
# by measuring p-values of the consistency of the estimated sign of each effect.
homogeneous_context_pvals = calc_homogeneous_context_effects_pvals(model, C)
context_pvals = pd.DataFrame(np.array([
   [C.columns[i], pval[0]] for i, pval in enumerate(homogeneous_context_pvals)
]), columns=["Context Variable", "p-value"])
context_pvals
# Notice that the p-values are calculated from bootstrap resampling,
# so they are limited by the number of bootstraps used.
# Caution: using more bootstraps could artificially inflate confidence.
Context Variable p-value
0 age 0.09090909090909091
1 sex 0.09090909090909091
2 bmi 0.09090909090909091

Second, we can see the homogeneous effects of the predictors.#

# Visualize the homogeneous predictor effects (beta X): the portion of each
# predictor's influence that does not vary with context.
plot_homogeneous_predictor_effects(
    model,
    C,
    X,
    classification=False,
    ylabel="Diabetes Progression",
)
/opt/homebrew/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, predict_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 10 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
_images/ec1758d96ac1391a4109eed20c33d9b6b42985b49f8514de9236b2cae7deb32a.png _images/58cd8598fc5a7ccdbe8de45a61458236c06b4858b183a495465b8b6baa0132df.png _images/5e369fc00599d9883415ad175713d0aa66efefd7786fdd60d21729e2a8f442b5.png _images/2c0478377a070568614f9c41191a39c3257e334ef717d4203a6695ba5c457e42.png _images/e3376ecef8b64362d4b877eded84063924863807f27dfedff93f0bf26f6b5210.png _images/d63e55afde72ad81e4e14a34159897c2c931ba814d613bad311f6c76e6e625d2.png _images/d5a392821f00f4b68a9a3ee954d5a2442d66105203c5decdb858f83c2d224341.png
%%capture
# We can quantify the uncertainty visualized in the bootstrap confidence intervals
# by measuring p-values of the consistency of the estimated sign of each effect.
homogeneous_predictor_pvals = calc_homogeneous_predictor_effects_pvals(model, C)
predictor_pvals = pd.DataFrame(np.array([
   [X.columns[i], pval[0]] for i, pval in enumerate(homogeneous_predictor_pvals)
]), columns=["Predictor", "p-value"])
predictor_pvals
Predictor p-value
0 bp 0.09090909090909091
1 s1 0.09090909090909091
2 s2 0.09090909090909091
3 s3 0.09090909090909091
4 s4 0.09090909090909091
5 s5 0.09090909090909091
6 s6 0.09090909090909091

Finally, we can see the heterogeneous effects (effects of predictors that change based on context).#

# Visualize the heterogeneous predictor effects g(C)X: predictor effects
# that are modulated by context.
# Since there are a combinatorial number of heterogeneous predictor effects,
# min_effect_size is a useful parameter to restrict the plotting to
# only the strongest effects.
plot_heterogeneous_predictor_effects(
    model, C, X,
    min_effect_size=75,
    ylabel="Influence of",
)
Generating datapoints for visualization by assuming the encoder is
            an additive model and thus doesn't require sampling on a manifold.
            If the encoder has interactions, please supply C_vis so that we
            can visualize these effects on the correct data manifold.
/opt/homebrew/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, predict_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 10 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
_images/881ff67e70d5a474910e0f6bf863b524132680d01036b8b4e253c425b37a48f7.png _images/13a2fcdff335e7f14faf884d74bce3c3704994271504560f98ce3ed6edc9b431.png _images/e338cbb515fa6b91b94aaadbe724ac883d993c75b49e413f12dedb1b7f9a589d.png _images/df1a27a52f5ec9244bbde089d8501253c6762356812b2a8c840d07ec22364541.png _images/28df945a1705ba436b429a681e62c55cee2d294bd50380e1fb42e825af980830.png _images/d26658f966cd2f4b8136d46ccb2c49e911aca789b4ea585b6dd0e1a4f3f23617.png _images/84e824806dc4b5a78788471ac81d60ca3c1e4f628fe1c851de6861f22288c4a6.png
%%capture
# We can quantify the uncertainty visualized in the bootstrap confidence intervals
# by measuring p-values of the consistency of the estimated sign of each effect.
# This is a combinatorial: context x predictor
heterogeneous_predictor_pvals = calc_heterogeneous_predictor_effects_pvals(model, C)
predictor_pvals = pd.DataFrame(np.array([
   [C.columns[i // len(X.columns)], X.columns[i % len(X.columns)], pval] for i, pval in enumerate(heterogeneous_predictor_pvals.flatten())
]), columns=["Context", "Predictor", "p-value"])
predictor_pvals
Context Predictor p-value
0 age bp 0.09090909090909091
1 age s1 0.09090909090909091
2 age s2 0.09090909090909091
3 age s3 0.09090909090909091
4 age s4 0.09090909090909091
5 age s5 0.09090909090909091
6 age s6 0.09090909090909091
7 sex bp 0.09090909090909091
8 sex s1 0.09090909090909091
9 sex s2 0.09090909090909091
10 sex s3 0.09090909090909091
11 sex s4 0.09090909090909091
12 sex s5 0.09090909090909091
13 sex s6 0.09090909090909091
14 bmi bp 0.09090909090909091
15 bmi s1 0.09090909090909091
16 bmi s2 0.09090909090909091
17 bmi s3 0.09090909090909091
18 bmi s4 0.09090909090909091
19 bmi s5 0.09090909090909091
20 bmi s6 0.09090909090909091