{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Finding Drivers of Heterogeneous Effects\n",
"Often, we want to know whether a heterogeneous effect is isolated to a particular context.\n",
"This tells us which contexts are driving the heterogeneity in our data, and which predictors in our models depend on which contexts.\n",
"\n",
"This notebook gives a quick demonstration of how to use the `test_each_context` function to compute a p-value for each context's isolated effect on heterogeneity, and how to interpret the results."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from contextualized.analysis.pvals import test_each_context, get_possible_pvals\n",
"from contextualized.easy import ContextualizedRegressor\n",
"\n",
"import logging\n",
"logging.getLogger(\"pytorch_lightning\").setLevel(logging.ERROR)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Simple Simulation: Known Heterogeneity\n",
"When some features are known to depend on context and others are known to be context-invariant, `test_each_context` should be able to tell them apart.\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
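"# Generate training data for a simple example\n",
"n_samples = 1000\n",
"C = np.random.uniform(0, 1, size=(n_samples, 2))\n",
"X = np.random.uniform(0, 1, size=(n_samples, 2))\n",
"# beta columns are [1, C0, C1]; only the first two are used, so\n",
"# Y = 1 * X0 + C0 * X1: X0 is context-invariant and X1 depends only on C0\n",
"beta = np.concatenate([np.ones((n_samples, 1)), C], axis=1)\n",
"Y = np.sum(beta[:, :2] * X, axis=-1)\n",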
"\n",
"C_train_df = pd.DataFrame(C, columns=['C0', 'C1'])\n",
"X_train_df = pd.DataFrame(X, columns=['X0', 'X1'])\n",
"Y_train_df = pd.DataFrame(Y, columns=['Y'])\n",
"\n",
"pvals = test_each_context(ContextualizedRegressor, C_train_df, X_train_df, Y_train_df, encoder_type=\"mlp\", max_epochs=1, learning_rate=1e-2, n_bootstraps=40)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Analyzing results\n",
"\n",
"We now have a p-value for each isolated contextual effect on each predictor.\n",
"In this setup, `X0` is context-invariant, while the coefficient of `X1` depends only on context `C0`.\n",
"We should therefore see high p-values for `X0` under every context, and a low p-value for `X1` only under context `C0`.\n",
"\n",
"These p-values are based on how consistent the learned effects are across bootstrap resamples.\n",
"As a result, the number of bootstraps limits the resolution of the test: with 40 bootstraps, no p-value can be smaller than 1/41 ≈ 0.024.\n",
"You can check the range of attainable p-values for a given number of bootstraps with the `get_possible_pvals` function."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0.024390243902439025, 0.975609756097561]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# getting the range of possible p-values for 40 bootstraps\n",
"get_possible_pvals(40)"
]
},
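{
"cell_type": "markdown",
"metadata": {},
"source": [
"The exact statistic computed inside `test_each_context` is not shown in this notebook, so the next cell is only a hypothetical sketch of how a bootstrap-consistency p-value can behave. It assumes the test asks whether an estimated context effect keeps the same sign across bootstrap resamples. The helper `sketch_bootstrap_pval` is illustrative and is not part of `contextualized`. Because of the +1 smoothing, its smallest possible output is 1/(n_bootstraps + 1). For 40 bootstraps that is 1/41 ≈ 0.0244, which matches the lower end reported by `get_possible_pvals(40)` above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def sketch_bootstrap_pval(effects):\n",
"    # Hypothetical helper for illustration; NOT the contextualized implementation.\n",
"    # `effects` holds one estimated context effect per bootstrap resample.\n",
"    effects = np.asarray(effects, dtype=float)\n",
"    n = effects.size\n",
"    # Take the direction (sign) that the median bootstrap agrees with.\n",
"    direction = 1.0 if np.median(effects) >= 0 else -1.0\n",
"    # Count bootstraps that contradict that direction (or are exactly zero).\n",
"    n_against = int(np.sum(effects * direction <= 0))\n",
"    # +1 smoothing keeps the p-value strictly above zero: min is 1 / (n + 1).\n",
"    return (1 + n_against) / (1 + n)\n",
"\n",
"rng = np.random.default_rng(0)\n",
"consistent = rng.normal(loc=1.0, scale=0.1, size=40)  # same sign in every bootstrap\n",
"noisy = rng.normal(loc=0.0, scale=1.0, size=40)  # sign flips between bootstraps\n",
"\n",
"print(sketch_bootstrap_pval(consistent))  # 1/41 ≈ 0.0244\n",
"print(sketch_bootstrap_pval(noisy))"
]
},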
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" Context Predictor Target Pvals\n",
"0 C0 X0 Y 0.170732\n",
"1 C0 X1 Y 0.024390\n",
"2 C1 X0 Y 0.341463\n",
"3 C1 X1 Y 0.390244"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pvals"
]
},
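{
"cell_type": "markdown",
"metadata": {},
"source": [
"To list the context/predictor pairs that pass a significance threshold, you can filter the returned DataFrame on its `Pvals` column. The next cell is a minimal sketch. The helper name `significant_pairs` and the default threshold of 0.05 are assumptions. The example table re-enters the four rows printed above so the cell runs on its own; the simulation is unseeded, so a rerun will give somewhat different values. In this notebook you could call `significant_pairs(pvals)` directly. With 40 bootstraps, any threshold below 1/41 ≈ 0.024 would never be met."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"def significant_pairs(pvals_df, alpha=0.05):\n",
"    # Hypothetical helper: keep rows whose bootstrap p-value is at or below alpha,\n",
"    # sorted so the most consistent context effects come first.\n",
"    hits = pvals_df[pvals_df['Pvals'] <= alpha]\n",
"    return hits.sort_values('Pvals').reset_index(drop=True)\n",
"\n",
"# The four rows printed by the simulation above (one unseeded run).\n",
"example = pd.DataFrame({\n",
"    'Context': ['C0', 'C0', 'C1', 'C1'],\n",
"    'Predictor': ['X0', 'X1', 'X0', 'X1'],\n",
"    'Target': ['Y', 'Y', 'Y', 'Y'],\n",
"    'Pvals': [0.170732, 0.024390, 0.341463, 0.390244],\n",
"})\n",
"\n",
"print(significant_pairs(example))"
]
},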
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Contexts driving heterogeneity in diabetes progression\n",
"\n",
"Now we apply our method to a real dataset: the scikit-learn diabetes dataset, whose target is a quantitative measure of disease progression one year after baseline.\n",
"Diabetes is known to be highly heterogeneous, and its pathology can be strongly patient-specific.\n",
"\n",
"We apply `test_each_context` to this dataset to see which contexts (patient age, sex, and BMI) drive heterogeneity in disease progression. The remaining measurements serve as predictors: average blood pressure (`bp`) and six blood serum measurements (`s1` to `s6`)."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"from sklearn.datasets import load_diabetes\n",
"\n",
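"# Load features and target (disease progression one year after baseline)\n",
"X, Y = load_diabetes(return_X_y=True, as_frame=True)\n",
"Y = np.expand_dims(Y.values, axis=-1)\n",
"# Use age, sex, and BMI as contexts; the remaining features are predictors\n",
"C = X[['age', 'sex', 'bmi']]\n",
"X = X.drop(['age', 'sex', 'bmi'], axis=1)\n",
"\n",
"# Keep a random 50% split for fitting; the held-out half is discarded here\n",
"seed = 1\n",
"C, _, X, _, Y, _ = train_test_split(C, X, Y, test_size=0.50, random_state=seed)\n",
"\n",
"# C and X are already DataFrames; wrap Y (a NumPy array) as one too\n",
"C_train_df = pd.DataFrame(C)\n",
"X_train_df = pd.DataFrame(X)\n",
"Y_train_df = pd.DataFrame(Y)"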
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"%%capture\n",
"pvals = test_each_context(ContextualizedRegressor, C_train_df, X_train_df, Y_train_df, encoder_type=\"mlp\", max_epochs=20, learning_rate=1e-2, n_bootstraps=40)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0.024390243902439025, 0.975609756097561]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"get_possible_pvals(40)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" Context Predictor Target Pvals\n",
"0 age bp 0 0.560976\n",
"1 age s1 0 0.341463\n",
"2 age s2 0 0.512195\n",
"3 age s3 0 0.560976\n",
"4 age s4 0 0.560976\n",
"5 age s5 0 0.560976\n",
"6 age s6 0 0.560976\n",
"7 sex bp 0 0.170732\n",
"8 sex s1 0 0.414634\n",
"9 sex s2 0 0.609756\n",
"10 sex s3 0 0.170732\n",
"11 sex s4 0 0.170732\n",
"12 sex s5 0 0.170732\n",
"13 sex s6 0 0.170732\n",
"14 bmi bp 0 0.024390\n",
"15 bmi s1 0 0.390244\n",
"16 bmi s2 0 0.219512\n",
"17 bmi s3 0 0.024390\n",
"18 bmi s4 0 0.097561\n",
"19 bmi s5 0 0.024390\n",
"20 bmi s6 0 0.024390"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"pvals"
]
}
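,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With 21 rows, the diabetes results are easier to scan as a context-by-predictor grid. The next cell is a minimal sketch using `pandas.DataFrame.pivot`. It re-enters the p-values printed above so that it runs on its own; they come from a single run, so reruns will vary. In this notebook you could pass `pvals` directly. The cell also flags the pairs whose p-value sits at 1/41, the smallest value attainable with 40 bootstraps according to `get_possible_pvals(40)`. The small tolerance is an assumption that absorbs the six-digit rounding of the printed values."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# P-values re-entered from the diabetes output above (one run of test_each_context).\n",
"contexts = ['age'] * 7 + ['sex'] * 7 + ['bmi'] * 7\n",
"predictors = ['bp', 's1', 's2', 's3', 's4', 's5', 's6'] * 3\n",
"pval_values = [\n",
"    0.560976, 0.341463, 0.512195, 0.560976, 0.560976, 0.560976, 0.560976,  # age\n",
"    0.170732, 0.414634, 0.609756, 0.170732, 0.170732, 0.170732, 0.170732,  # sex\n",
"    0.024390, 0.390244, 0.219512, 0.024390, 0.097561, 0.024390, 0.024390,  # bmi\n",
"]\n",
"diabetes_pvals = pd.DataFrame({'Context': contexts, 'Predictor': predictors, 'Pvals': pval_values})\n",
"\n",
"# Rows: contexts; columns: predictors; cells: p-values.\n",
"grid = diabetes_pvals.pivot(index='Context', columns='Predictor', values='Pvals')\n",
"print(grid.round(3))\n",
"\n",
"# Flag pairs at the smallest attainable p-value for 40 bootstraps (1/41);\n",
"# the 1e-6 tolerance is an assumption covering the rounding of printed values.\n",
"floor = 1 / (40 + 1)\n",
"at_floor = diabetes_pvals[diabetes_pvals['Pvals'] <= floor + 1e-6]\n",
"print(at_floor[['Context', 'Predictor']].to_string(index=False))"
]
}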
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}