Finding Drivers of Heterogeneous Effects
Often, we want to know whether a heterogeneous effect is isolated to a particular context. Knowing this tells us which contexts drive the heterogeneity in our data and which predictors in our models depend on which contexts.
This notebook gives a quick demonstration of how to use the `test_each_context` function to get a p-value for each context's effect on heterogeneity in isolation, and how to interpret the results.
```python
import logging

import numpy as np
import pandas as pd
from contextualized.analysis.pvals import test_each_context, get_possible_pvals
from contextualized.easy import ContextualizedRegressor

# Silence PyTorch Lightning's training logs; only errors are shown
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
```
Simple Simulation: Known Heterogeneity
When some features have context-dependent effects and others have context-invariant effects, `test_each_context` can distinguish between them.
```python
%%capture
# Generate training data for a simple example
n_samples = 1000
C = np.random.uniform(0, 1, size=(n_samples, 2))  # contexts C0, C1
X = np.random.uniform(0, 1, size=(n_samples, 2))  # predictors X0, X1
# beta[:, :2] = [1, C0]: X0 has a constant effect, X1's effect equals C0
beta = np.concatenate([np.ones((n_samples, 1)), C], axis=1)
Y = np.sum(beta[:, :2] * X, axis=-1)  # Y = X0 + C0 * X1

C_train_df = pd.DataFrame(C, columns=['C0', 'C1'])
X_train_df = pd.DataFrame(X, columns=['X0', 'X1'])
Y_train_df = pd.DataFrame(Y, columns=['Y'])

# Test each context in isolation; %%capture hides the training output
pvals = test_each_context(
    ContextualizedRegressor, C_train_df, X_train_df, Y_train_df,
    encoder_type="mlp", max_epochs=1, learning_rate=1e-2, n_bootstraps=40,
)
```
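Because the simulated outcome has no noise, its ground-truth structure can be checked without contextualized at all. The sketch below is a plain NumPy illustration, not part of the library: it re-creates the simulation with a fixed seed (an addition for repeatability), then regresses Y on both predictors and every context-by-predictor interaction. Ordinary least squares should recover a coefficient of 1 for X0 and for C0*X1 and 0 for every other term, i.e. Y = X0 + C0*X1.

```python
import numpy as np

# Re-create the simulation above, seeded so the check is repeatable
rng = np.random.default_rng(0)
n_samples = 1000
C = rng.uniform(0, 1, size=(n_samples, 2))
X = rng.uniform(0, 1, size=(n_samples, 2))
beta = np.concatenate([np.ones((n_samples, 1)), C], axis=1)
Y = np.sum(beta[:, :2] * X, axis=-1)  # Y = X0 + C0 * X1

# Design matrix: main effects of X plus every context-by-predictor interaction
names = ["X0", "X1", "C0*X0", "C1*X0", "C0*X1", "C1*X1"]
design = np.column_stack([
    X[:, 0], X[:, 1],
    C[:, 0] * X[:, 0], C[:, 1] * X[:, 0],
    C[:, 0] * X[:, 1], C[:, 1] * X[:, 1],
])

# Noise-free data, so least squares recovers the generating coefficients
coef, *_ = np.linalg.lstsq(design, Y, rcond=None)
for name, value in zip(names, coef):
    print(f"{name:>6}: {value:+.3f}")
```

Only X0 and C0*X1 carry nonzero weight, which is the pattern `test_each_context` should reflect: no context dependence for X0, and dependence on C0 (not C1) for X1.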
Analyzing results
We now have p-values for each of the isolated contextual effects on the predictor variables.
In this setup, X0 is context-invariant, while X1 depends only on context C0.
We should see that the p-value for X0 is high over all contexts, while the p-value for X1 is low only in context C0.
These p-values are based on the consistency of the learned effects across multiple bootstraps.
Because of this, the number of bootstraps sets the smallest p-value the test can return, and with it the power of the test.
You can check the range of p-values attainable for a given number of bootstraps with the `get_possible_pvals` function.
```python
# Get the range of possible p-values for 40 bootstraps
get_possible_pvals(40)
```
[0.024390243902439025, 0.975609756097561]
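The displayed range equals [1/41, 40/41], which is consistent with bounds of 1/(B+1) and B/(B+1) for B bootstraps. The sketch below assumes those bounds (inferred from this output, not taken from the library source) to show how the smallest attainable p-value shrinks as bootstraps are added. `pval_bounds` and `min_bootstraps` are hypothetical helpers, not contextualized functions.

```python
def pval_bounds(n_bootstraps):
    """Assumed [min, max] attainable p-value for a given number of bootstraps."""
    return [1 / (n_bootstraps + 1), n_bootstraps / (n_bootstraps + 1)]


def min_bootstraps(alpha):
    """Smallest bootstrap count whose minimum attainable p-value is <= alpha,
    under the assumed bounds above."""
    n = 1
    while pval_bounds(n)[0] > alpha:
        n += 1
    return n


print(pval_bounds(40))  # should match the get_possible_pvals(40) output above
for alpha in (0.05, 0.01):
    print(alpha, min_bootstraps(alpha))
```

Under this assumption, 40 bootstraps can never produce a p-value below about 0.024, so entries of 0.024390 in the result tables are the strongest evidence the test can report at this setting; reaching 0.01 would take at least 99 bootstraps.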
```python
pvals
```

| | Context | Predictor | Target | Pvals |
|---|---|---|---|---|
| 0 | C0 | X0 | Y | 0.170732 |
| 1 | C0 | X1 | Y | 0.024390 |
| 2 | C1 | X0 | Y | 0.341463 |
| 3 | C1 | X1 | Y | 0.390244 |
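To pull out the context-predictor pairs whose isolated effect is significant, the returned table can be filtered like any pandas DataFrame. The sketch below rebuilds the table above by hand so it runs on its own, and assumes the columns are named `Context`, `Predictor`, `Target`, and `Pvals` as displayed; the 0.05 threshold is an illustrative choice, not a library default.

```python
import pandas as pd

# Hand-copied from the simulation results above; in the notebook, use `pvals` directly
pvals = pd.DataFrame({
    "Context": ["C0", "C0", "C1", "C1"],
    "Predictor": ["X0", "X1", "X0", "X1"],
    "Target": ["Y", "Y", "Y", "Y"],
    "Pvals": [0.170732, 0.024390, 0.341463, 0.390244],
})

alpha = 0.05  # illustrative significance threshold
significant = pvals[pvals["Pvals"] < alpha]
print(significant)
```

Only the C0 and X1 pair passes, matching the simulated ground truth.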
Contexts driving heterogeneity in diabetes progression
Now we apply our method to a real dataset: the scikit-learn diabetes dataset, whose target is a quantitative measure of disease progression one year after baseline. Diabetes is known to be a highly heterogeneous disease, and its pathology can be patient-specific.
We apply the `test_each_context` function to this dataset to see which contexts (patient age, sex, and BMI) drive heterogeneity in disease progression.
```python
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_diabetes

# Load features as a DataFrame and the progression target as an (n, 1) array
X, Y = load_diabetes(return_X_y=True, as_frame=True)
Y = np.expand_dims(Y.values, axis=-1)

# Use age, sex, and BMI as contexts; the remaining features (bp, s1-s6) are predictors
C = X[['age', 'sex', 'bmi']]
X = X.drop(columns=['age', 'sex', 'bmi'])

# Keep a random half of the patients for testing contexts
seed = 1
C, _, X, _, Y, _ = train_test_split(C, X, Y, test_size=0.50, random_state=seed)

# Convert to pandas DataFrames
C_train_df = pd.DataFrame(C)
X_train_df = pd.DataFrame(X)
Y_train_df = pd.DataFrame(Y)
```
```python
%%capture
pvals = test_each_context(
    ContextualizedRegressor, C_train_df, X_train_df, Y_train_df,
    encoder_type="mlp", max_epochs=20, learning_rate=1e-2, n_bootstraps=40,
)
```
```python
get_possible_pvals(40)
```
[0.024390243902439025, 0.975609756097561]
```python
pvals
```

| | Context | Predictor | Target | Pvals |
|---|---|---|---|---|
| 0 | age | bp | 0 | 0.560976 |
| 1 | age | s1 | 0 | 0.341463 |
| 2 | age | s2 | 0 | 0.512195 |
| 3 | age | s3 | 0 | 0.560976 |
| 4 | age | s4 | 0 | 0.560976 |
| 5 | age | s5 | 0 | 0.560976 |
| 6 | age | s6 | 0 | 0.560976 |
| 7 | sex | bp | 0 | 0.170732 |
| 8 | sex | s1 | 0 | 0.414634 |
| 9 | sex | s2 | 0 | 0.609756 |
| 10 | sex | s3 | 0 | 0.170732 |
| 11 | sex | s4 | 0 | 0.170732 |
| 12 | sex | s5 | 0 | 0.170732 |
| 13 | sex | s6 | 0 | 0.170732 |
| 14 | bmi | bp | 0 | 0.024390 |
| 15 | bmi | s1 | 0 | 0.390244 |
| 16 | bmi | s2 | 0 | 0.219512 |
| 17 | bmi | s3 | 0 | 0.024390 |
| 18 | bmi | s4 | 0 | 0.097561 |
| 19 | bmi | s5 | 0 | 0.024390 |
| 20 | bmi | s6 | 0 | 0.024390 |
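To see which contexts drive the heterogeneity overall, the per-pair results can be aggregated by context, for example by counting how many predictors fall below a threshold. The sketch below re-enters the p-values from the table above so it runs on its own, assumes the column names shown, and uses an illustrative 0.05 cutoff; the grouping and cutoff are choices made for this example, not part of the library.

```python
import pandas as pd

predictors = ["bp", "s1", "s2", "s3", "s4", "s5", "s6"]
# Hand-copied from the diabetes results above; in the notebook, use `pvals` directly
pvals_by_context = {
    "age": [0.560976, 0.341463, 0.512195, 0.560976, 0.560976, 0.560976, 0.560976],
    "sex": [0.170732, 0.414634, 0.609756, 0.170732, 0.170732, 0.170732, 0.170732],
    "bmi": [0.024390, 0.390244, 0.219512, 0.024390, 0.097561, 0.024390, 0.024390],
}
pvals = pd.DataFrame([
    {"Context": ctx, "Predictor": pred, "Target": 0, "Pvals": p}
    for ctx, ps in pvals_by_context.items()
    for pred, p in zip(predictors, ps)
])

alpha = 0.05  # illustrative significance threshold
# Per context: how many predictors pass the cutoff, and the smallest p-value
summary = (
    pvals.assign(significant=pvals["Pvals"] < alpha)
    .groupby("Context", sort=False)
    .agg(n_significant=("significant", "sum"), min_pval=("Pvals", "min"))
)
print(summary)
print(pvals.loc[pvals["Pvals"] < alpha, ["Context", "Predictor"]])
```

On these numbers, only bmi has predictors below the cutoff (bp, s3, s5, and s6, all at the minimum attainable p-value for 40 bootstraps), while every age and sex pair stays above it. This suggests that, at this bootstrap count, BMI is the context most clearly associated with heterogeneity in disease progression.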