Contextualized Regression

Contextual models identify heterogeneous effects, that is, effects that change between contexts. The contextualized package provides easy access to these heterogeneous models through the contextualized.easy module interfaces. This notebook shows a simple example of regression with contextualized.easy.

Note: The contextualized.easy interfaces prioritize ease of use over computational efficiency and don't follow best PyTorch practices, so for any heavy lifting we recommend using the contextualized models directly (see this demo).

1. Generate Simulation Data

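In this simulated regression, we'll study the simple varying-coefficients case where \(Y = X\beta(C) + \epsilon\), with \(\beta(C) \sim \text{N}(C\phi, \sigma^2 \text{I})\). The simulation below uses \(\sigma = 0.01\) and, for simplicity, sets \(\epsilon = 0\), so all of the randomness in \(Y\) given \(C\) and \(X\) comes from \(\beta(C)\).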
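Written entrywise for this simulation (one context feature and one observed feature, so \(c_i\) and \(x_i\) are scalars), sample \(i\) and outcome \(k\) follow the equations below. In the code, the coefficient noise \(\eta_{ik}\) has standard deviation \(0.01\), \(\epsilon_{ik}\) is set to zero, and there is no intercept, so the true offset is zero.

```latex
\begin{aligned}
y_{ik} &= x_i \, \beta_k(c_i) + \epsilon_{ik}, \\
\beta_k(c_i) &= c_i \, \phi_k + \eta_{ik}, \qquad \eta_{ik} \sim \text{N}(0, 0.01^2).
\end{aligned}
```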

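import numpy as np

# Problem sizes
n_samples = 1000
n_outcomes = 3
n_context = 1
n_observed = 1

# Context C and observed features X, one row per sample
C = np.random.uniform(-1, 1, size=(n_samples, n_context))
X = np.random.uniform(0, 0.5, size=(n_samples, n_observed))

# beta has shape (n_samples, n_observed, n_outcomes): C times phi over the context axis, plus noise with sd 0.01
phi = np.random.uniform(-1, 1, size=(n_context, n_observed, n_outcomes))
beta = np.tensordot(C, phi, axes=1) + np.random.normal(0, 0.01, size=(n_samples, n_observed, n_outcomes))

# Y[i, k] = sum_j X[i, j] * beta[i, j, k]; no extra noise term is added (epsilon = 0)
Y = np.array([np.tensordot(X[i], beta[i], axes=1) for i in range(n_samples)])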
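The per-sample loop that builds Y contracts the observed-feature axis of X with the matching axis of beta. As a sketch, the same contraction can be written in one vectorized `np.einsum` call; the helper names `outcomes_loop` and `outcomes_vectorized` are hypothetical, and the check below uses fresh random arrays (with two observed features so the sum is non-trivial) rather than the notebook's variables.

```python
import numpy as np


def outcomes_loop(X, beta):
    # Loop version: Y[i, k] = sum_j X[i, j] * beta[i, j, k]
    return np.array([np.tensordot(X[i], beta[i], axes=1) for i in range(X.shape[0])])


def outcomes_vectorized(X, beta):
    # Same contraction over the observed-feature axis j, for all samples at once
    return np.einsum("ij,ijk->ik", X, beta)


# X is (n_samples, n_observed), beta is (n_samples, n_observed, n_outcomes)
rng = np.random.default_rng(0)
X_check = rng.uniform(0, 0.5, size=(1000, 2))
beta_check = rng.normal(size=(1000, 2, 3))
Y_check = outcomes_vectorized(X_check, beta_check)
print(Y_check.shape)  # (1000, 3)
print(np.allclose(outcomes_loop(X_check, beta_check), Y_check))  # True
```

In the notebook, `outcomes_vectorized(X, beta)` should reproduce `Y` up to floating-point rounding, without a Python-level loop.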

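2. Build and fit the model

ContextualizedRegressor uses an sklearn-style wrapper interface: construct the model, then call fit with the context C, the observed features X, and the outcomes Y.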

from contextualized.easy import ContextualizedRegressor

# Fit 3 bootstrap models, each for at most 20 epochs
model = ContextualizedRegressor()
model.fit(C, X, Y, max_epochs=20, learning_rate=1e-3, n_bootstraps=3)

  | Name      | Type           | Params
---------------------------------------------
0 | metamodel | NaiveMetamodel | 856   
---------------------------------------------
856       Trainable params
0         Non-trainable params
856       Total params
0.003     Total estimated model params size (MB)
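ContextualizedRegressor follows the sklearn pattern of fit(C, X, Y) followed by predict(C, X). As a hedged illustration of what such an estimator computes, the numpy sketch below implements a hypothetical `ToyVaryingCoefRegressor`, with method names mirroring the ones used in this notebook. It is not contextualized's implementation: it assumes \(\beta(C)\) and \(\mu(C)\) are linear in C, in which case \(Y = \mu(C) + X\beta(C)\) reduces to ordinary least squares on the features [1, C, X, X·C].

```python
import numpy as np


class ToyVaryingCoefRegressor:
    """Hypothetical sklearn-style sketch of Y = mu(C) + X beta(C), with mu and beta linear in C."""

    @staticmethod
    def _design(C, X):
        n = C.shape[0]
        # Interaction features x_j * c_l, flattened to shape (n, n_observed * n_context)
        XC = (X[:, :, None] * C[:, None, :]).reshape(n, -1)
        # Columns: [intercept | C | X | X*C]
        return np.hstack([np.ones((n, 1)), C, X, XC])

    def fit(self, C, X, Y):
        self.n_context_, self.n_observed_ = C.shape[1], X.shape[1]
        self.coef_, *_ = np.linalg.lstsq(self._design(C, X), Y, rcond=None)
        return self  # sklearn convention: fit returns the estimator

    def predict(self, C, X):
        return self._design(C, X) @ self.coef_

    def predict_params(self, C):
        p, q = self.n_context_, self.n_observed_
        W = self.coef_  # shape (1 + p + q + q * p, n_outcomes)
        mu = W[0] + C @ W[1:1 + p]                             # (n, n_outcomes)
        beta0 = W[1 + p:1 + p + q]                             # (q, n_outcomes)
        Phi = W[1 + p + q:].reshape(q, p, -1)                  # (q, p, n_outcomes)
        beta = beta0[None] + np.einsum("nl,jlk->njk", C, Phi)  # (n, q, n_outcomes)
        return beta, mu


# Noise-free toy data in the same layout as the simulation above
rng = np.random.default_rng(0)
C_toy = rng.uniform(-1, 1, size=(500, 1))
X_toy = rng.uniform(0, 0.5, size=(500, 1))
phi_toy = rng.uniform(-1, 1, size=(1, 1, 3))
beta_toy = np.tensordot(C_toy, phi_toy, axes=1)
Y_toy = np.einsum("ij,ijk->ik", X_toy, beta_toy)

toy = ToyVaryingCoefRegressor().fit(C_toy, X_toy, Y_toy)
beta_hat, mu_hat = toy.predict_params(C_toy)
print(np.abs(beta_hat - beta_toy).max(), np.abs(mu_hat).max())  # both close to zero
```

On noise-free data the least-squares fit recovers \(\beta(C)\) and a zero offset. contextualized instead learns these maps with a context encoder (encoder_type 'mlp' or 'ngam') trained by gradient descent, which allows \(\beta(C)\) and \(\mu(C)\) to be nonlinear in C.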

Common constructor keywords include:

- n_bootstraps: integer, the number of bootstrap runs to fit (the example above passes it to fit, which also accepts it).
- encoder_type: 'mlp' or 'ngam', the type of model to use as the context encoder.
- loss_fn: a function that computes the loss (see contextualized.regression.LOSSES).
- alpha: non-negative float, the regularization strength.
- mu_ratio: float in the range (0.0, 1.0), how the regularization is split between the context-specific parameters and the context-specific offsets.
- l1_ratio: float in the range (0.0, 1.0), how the regularization is split between L1 and L2 penalties on the parameters.
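As a rough mental model of how alpha, mu_ratio, and l1_ratio interact, here is a hypothetical elastic-net-style penalty on context-specific coefficients β and offsets μ. The exact formula in contextualized may differ; in particular, which of β or μ receives the mu_ratio weight is an assumption here, as is the function name `toy_elastic_net_penalty`.

```python
import numpy as np


def toy_elastic_net_penalty(beta, mu, alpha, mu_ratio, l1_ratio):
    """Hypothetical penalty: alpha * [mu_ratio * R(mu) + (1 - mu_ratio) * R(beta)],
    where R(w) = l1_ratio * ||w||_1 + (1 - l1_ratio) * ||w||_2^2.
    The assignment of mu_ratio to the offsets mu is an assumption."""
    def mixed_norm(w):
        return l1_ratio * np.abs(w).sum() + (1.0 - l1_ratio) * np.square(w).sum()

    return alpha * (mu_ratio * mixed_norm(mu) + (1.0 - mu_ratio) * mixed_norm(beta))


beta_ex = np.array([[0.5, -1.0]])
mu_ex = np.array([2.0])
# mu_ratio=0 and l1_ratio=1: pure L1 on beta only, 0.1 * (0.5 + 1.0)
print(toy_elastic_net_penalty(beta_ex, mu_ex, alpha=0.1, mu_ratio=0.0, l1_ratio=1.0))
# mu_ratio=1 and l1_ratio=0: pure squared L2 on mu only, 0.1 * 2.0**2
print(toy_elastic_net_penalty(beta_ex, mu_ex, alpha=0.1, mu_ratio=1.0, l1_ratio=0.0))
```

Setting alpha to zero turns the penalty off regardless of the two ratios, which is why the ratios only matter once alpha is positive.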

Common fitting keywords include:

- max_epochs: positive integer, the maximum number of epochs to fit. Early stopping is enabled by default, so training may end sooner.
- learning_rate: positive float, default 1e-3.
- val_split: float in the range (0.0, 1.0), the fraction of the data held out for validation (used for early stopping).

3. Inspect the model predictions

import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams.update({'font.size': 18})

# Compare true and predicted values for the first outcome
plt.scatter(Y[:, 0], model.predict(C, X)[:, 0])
plt.xlabel("True Value")
plt.ylabel("Predicted Value")
plt.show()
[Figure: true vs. predicted values for the first outcome]
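A scatter plot is easier to judge alongside a number. Below is a small sketch of per-outcome mean squared error and R² in plain numpy; the helper name `regression_report` is hypothetical, and in this notebook it could be called as `regression_report(Y, model.predict(C, X))`, assuming predict returns an array of shape (n_samples, n_outcomes) as the indexing above suggests.

```python
import numpy as np


def regression_report(y_true, y_pred):
    """Per-outcome MSE and R^2 for arrays of shape (n_samples, n_outcomes)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2, axis=0)
    mse = ss_res / y_true.shape[0]
    # R^2 = 1 - SS_res / SS_tot, computed separately for each outcome column
    ss_tot = np.sum((y_true - y_true.mean(axis=0)) ** 2, axis=0)
    r2 = 1.0 - ss_res / ss_tot
    return mse, r2


y_true_ex = np.array([[1.0, 0.0], [2.0, 1.0], [3.0, 2.0]])
y_pred_ex = np.array([[1.0, 0.5], [2.0, 1.0], [3.0, 1.5]])
mse_ex, r2_ex = regression_report(y_true_ex, y_pred_ex)
print(mse_ex, r2_ex)  # R^2 is 1.0 for the first column and 0.75 for the second
```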

4. Check what the individual bootstrap models learned

# individual_preds=True returns one prediction array per bootstrap model
model_preds = model.predict(C, X, individual_preds=True)
for i, pred in enumerate(model_preds):
    plt.scatter(Y[:, 0], pred[:, 0], label='Bootstrap {}'.format(i))
plt.xlabel("True Value")
plt.ylabel("Predicted Value")
plt.legend()
plt.show()
[Figure: true vs. predicted values for the first outcome, one series per bootstrap model]
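The spread between bootstrap models is a rough measure of uncertainty. The sketch below illustrates the idea with a swapped-in, much simpler model than contextualized's: each bootstrap resamples rows with replacement, fits least squares on the features [1, C, X, X·C] (a varying-coefficient model linear in C), and the ensemble reports the per-sample mean and standard deviation of the predictions. The function names are hypothetical; the output has the same (bootstrap, sample, outcome) layout that the loop above iterates over.

```python
import numpy as np


def toy_design(C, X):
    """Features [1, C, X, X*C] for a varying-coefficient model that is linear in C."""
    n = C.shape[0]
    XC = (X[:, :, None] * C[:, None, :]).reshape(n, -1)
    return np.hstack([np.ones((n, 1)), C, X, XC])


def toy_bootstrap_predictions(C, X, Y, n_bootstraps=3, seed=0):
    """Return predictions of shape (n_bootstraps, n_samples, n_outcomes)."""
    rng = np.random.default_rng(seed)
    F = toy_design(C, X)
    preds = []
    for _ in range(n_bootstraps):
        rows = rng.integers(0, len(F), size=len(F))  # resample rows with replacement
        W, *_ = np.linalg.lstsq(F[rows], Y[rows], rcond=None)
        preds.append(F @ W)  # each bootstrap model predicts on every sample
    return np.stack(preds)


rng = np.random.default_rng(2)
C_bs = rng.uniform(-1, 1, size=(300, 1))
X_bs = rng.uniform(0, 0.5, size=(300, 1))
Y_bs = X_bs * (1.5 * C_bs) + rng.normal(0, 0.05, size=(300, 1))

preds_bs = toy_bootstrap_predictions(C_bs, X_bs, Y_bs, n_bootstraps=5)
mean_pred = preds_bs.mean(axis=0)  # ensemble prediction per sample
std_pred = preds_bs.std(axis=0)    # bootstrap spread, a rough uncertainty estimate
print(preds_bs.shape)  # (5, 300, 1)
```

With the notebook's model, the analogous summary would be `np.mean(model_preds, axis=0)` and `np.std(model_preds, axis=0)`, assuming `model_preds` stacks into a numeric array.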

5. Check what parameters the models learned

# individual_preds=False returns one set of parameters instead of one per bootstrap model
beta_preds, mu_preds = model.predict_params(C, individual_preds=False)
plt.scatter(C, beta_preds[:, 0], label='Beta Predicted')
plt.scatter(C, beta[:, :, 0], label='Beta Actual')
plt.scatter(C, mu_preds[:, 0], label='Mu Predicted')
plt.scatter(C, np.zeros_like(C), label='Mu Actual')
plt.legend(loc='center right', bbox_to_anchor=(1.6, 0.5), fontsize=16)
plt.xlabel("C value")
plt.ylabel("Parameter value")
plt.show()
[Figure: predicted and actual beta and mu as functions of C]
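Because the simulation knows \(\phi\), the plot can be complemented by a number: the root-mean-squared error between learned coefficients and the noise-free target \(C\phi\). The sketch below uses a hypothetical helper `coefficient_rmse` and assumes the predicted coefficients are arranged in the simulation's layout (n_samples, n_observed, n_outcomes). The axis order returned by `model.predict_params` may differ, so check its shape (and transpose if needed) before comparing.

```python
import numpy as np


def coefficient_rmse(C, phi, beta_pred):
    """Per-outcome RMSE between beta_pred and the noise-free coefficients C @ phi.
    C: (n, n_context); phi: (n_context, n_observed, n_outcomes);
    beta_pred: (n, n_observed, n_outcomes)."""
    beta_target = np.tensordot(C, phi, axes=1)  # (n, n_observed, n_outcomes)
    return np.sqrt(np.mean((beta_pred - beta_target) ** 2, axis=(0, 1)))


rng = np.random.default_rng(3)
C_chk = rng.uniform(-1, 1, size=(100, 1))
phi_chk = rng.uniform(-1, 1, size=(1, 1, 3))
perfect = np.tensordot(C_chk, phi_chk, axes=1)
print(coefficient_rmse(C_chk, phi_chk, perfect))        # zero for every outcome
print(coefficient_rmse(C_chk, phi_chk, perfect + 0.1))  # approximately 0.1 for every outcome
```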

6. Save and load the trained model

from contextualized.utils import save, load

# Write the fitted model to disk, then read it back
save_path = './easy_demo_model.pt'
save(model, path=save_path)
model = load(save_path)