Contextualized Bayesian Networks


For more details, please see the NOTMAD preprint.

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from contextualized.dags.graph_utils import simulate_linear_sem

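n = 1000
# One scalar context per sample, evenly spaced on [1, 2].
C = np.linspace(1, 2, n).reshape((n, 1))
# Ground-truth edge weights: W[i, j] is the weight of edge i -> j, and each weight varies with C.
W = np.zeros((4, 4, n, 1))
W[0, 1] = C - 2
W[2, 1] = C**2
W[3, 1] = C**3
W[3, 2] = C
# Reshape to one 4 x 4 network per sample: (4, 4, n, 1) -> (n, 4, 4).
W = np.squeeze(W)
W = np.transpose(W, (2, 0, 1))
# Draw one observation from each sample's linear SEM.
X = np.zeros((n, 4))
for i, w in enumerate(W):
    x = simulate_linear_sem(w, 1, "uniform", noise_scale=0.1)[0]
    X[i] = x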
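For intuition about the data-generating process, here is a minimal NumPy sketch of drawing samples from a linear structural equation model (SEM) for a single DAG. It assumes the convention that W[i, j] is the weight of edge i -> j, so each sample satisfies x = x W + z, and it assumes uniform noise on [-noise_scale, noise_scale]. This is an illustrative re-implementation, not the code behind simulate_linear_sem; the helper name sample_linear_sem is hypothetical.

```python
import numpy as np

def sample_linear_sem(W, n_samples, noise_scale=0.1, rng=None):
    """Hypothetical helper: draw samples from x = x @ W + z for a DAG W.

    W[i, j] is the weight of edge i -> j (assumed convention). Because W is
    acyclic, (I - W) is invertible and x = z @ inv(I - W).
    """
    rng = np.random.default_rng(rng)
    p = W.shape[0]
    # Assumed noise model: uniform on [-noise_scale, noise_scale], one row per sample.
    Z = rng.uniform(-noise_scale, noise_scale, size=(n_samples, p))
    # Solve x (I - W) = z for every row at once.
    return Z @ np.linalg.inv(np.eye(p) - W)

# Ground-truth structure at a single context value c, mirroring the cell above.
c = 1.5
W_c = np.zeros((4, 4))
W_c[0, 1] = c - 2
W_c[2, 1] = c**2
W_c[3, 1] = c**3
W_c[3, 2] = c
X_c = sample_linear_sem(W_c, n_samples=5, rng=0)

# The residuals x - x W recover the injected noise, so they stay within the noise range.
print(np.abs(X_c - X_c @ W_c).max() <= 0.1 + 1e-9)  # True
```

Solving the linear system in one step works only because W is acyclic; for a cyclic W, (I - W) could be singular and the SEM would not define a valid sampling order.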
%%capture
from contextualized.easy import ContextualizedBayesianNetworks

cbn = ContextualizedBayesianNetworks(
    encoder_type='mlp', num_archetypes=16,
    n_bootstraps=2, archetype_dag_loss_type="DAGMA", archetype_alpha=0.,
    sample_specific_dag_loss_type="DAGMA", sample_specific_alpha=1e-1,
    learning_rate=1e-3)
cbn.fit(C, X, max_epochs=10)
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type      | Params
----------------------------------------
0 | encoder   | MLP       | 1.8 K 
1 | explainer | Explainer | 256   
----------------------------------------
2.0 K     Trainable params
0         Non-trainable params
2.0 K     Total params
0.008     Total estimated model params size (MB)
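Conceptually, NOTMAD represents each sample-specific network as a context-weighted combination of a small set of archetype networks. The parameter counts in the training log are consistent with this: the explainer's 256 parameters equal 16 archetypes × 4 × 4 edge weights. The sketch below illustrates only that mixing step, with a random linear map standing in for the trained MLP encoder; the softmax link, the random values, and all variable names are assumptions for illustration, not the library's internals.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p, k = 1000, 4, 16  # samples, nodes, archetypes (num_archetypes=16)

C = np.linspace(1, 2, n).reshape((n, 1))            # same context as above
archetypes = rng.normal(scale=0.1, size=(k, p, p))  # stand-in for learned archetypes

# Stand-in encoder: a random linear map from context to archetype logits.
enc_W = rng.normal(size=(1, k))
logits = C @ enc_W                                  # shape (n, k)

# Softmax over archetypes (an assumption; the library's link may differ).
logits -= logits.max(axis=1, keepdims=True)
weights = np.exp(logits)
weights /= weights.sum(axis=1, keepdims=True)

# Each sample's network is a weighted sum of archetypes: (n, k) x (k, p, p) -> (n, p, p).
networks = np.einsum("nk,kij->nij", weights, archetypes)
print(networks.shape)  # (1000, 4, 4)
```

A mixture of acyclic archetypes is not guaranteed to be acyclic, which is one reason the sample-specific graphs get their own DAG loss (sample_specific_dag_loss_type).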
# Mean-squared error (MSE) measures how well the predicted networks explain X (a proxy for the likelihood of X under those networks).
mses = cbn.measure_mses(C, X)
print(np.mean(mses))
1.77282994611765
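measure_mses returns an array of errors, averaged above. One plausible way to compute a per-sample MSE from predicted networks, assuming the linear SEM convention x ≈ x W, is to average the squared residuals of x - x W over nodes. The library may normalize or aggregate differently, so treat the helper network_mse below as a hypothetical illustration rather than the definition used by measure_mses.

```python
import numpy as np

def network_mse(X, networks):
    """Hypothetical helper: per-sample MSE of predicting each node from its parents.

    X has shape (n, p); networks has shape (n, p, p), where networks[s, i, j]
    is the weight of edge i -> j for sample s (assumed convention).
    """
    # Predicted value of every node: X_hat[s, j] = sum_i X[s, i] * W[s, i, j].
    X_hat = np.einsum("si,sij->sj", X, networks)
    return np.mean((X - X_hat) ** 2, axis=1)

# Chain 0 -> 1 with weight 2. Node 1 is predicted exactly, but node 0 has no
# parents, so its residual is its own value.
W_chain = np.array([[0.0, 2.0], [0.0, 0.0]])
X_toy = np.array([[1.0, 2.0], [3.0, 6.0]])
print(network_mse(X_toy, np.stack([W_chain, W_chain])))  # [0.5 4.5]
```

Because root nodes contribute their full variance, this kind of MSE never reaches zero on real data; it is most useful for comparing models on the same X.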
# Predict and visualize network
predicted_networks = cbn.predict_networks(C)

f, axarr = plt.subplots(1, 2)
axarr[0].imshow(predicted_networks[0])
axarr[1].imshow(W[0])
axarr[0].set_title("Predicted Network")
axarr[1].set_title("Ground-Truth Network")
[Figure: heatmaps of the first sample's network, "Predicted Network" (left) and "Ground-Truth Network" (right)]
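The heatmaps compare a single sample by eye. To quantify structure recovery across all samples, one option is to threshold both predicted and ground-truth weights and count mismatched edges, a simple form of structural Hamming distance in which a reversed edge counts as two errors. The threshold of 0.1 and the helper name edge_mismatches are arbitrary choices for illustration.

```python
import numpy as np

def edge_mismatches(pred, true, threshold=0.1):
    """Hypothetical helper: count edges present in one graph but not the other.

    pred and true have shape (n, p, p); edge i -> j counts as present when
    |weight| > threshold. Returns one count per sample.
    """
    pred_edges = np.abs(pred) > threshold
    true_edges = np.abs(true) > threshold
    return np.sum(pred_edges != true_edges, axis=(1, 2))

# Toy example with one sample: the prediction misses edge 3 -> 2 and adds 1 -> 0.
true = np.zeros((1, 4, 4))
true[0, 0, 1], true[0, 2, 1], true[0, 3, 1], true[0, 3, 2] = -1.0, 1.0, 1.0, 1.0
pred = true.copy()
pred[0, 3, 2] = 0.0   # missed edge
pred[0, 1, 0] = 0.5   # spurious edge
print(edge_mismatches(pred, true))  # [2]
```

With the arrays from this notebook, edge_mismatches(predicted_networks, W) would give one count per sample, which can then be plotted against C to see where recovery is weakest.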
# We can embed networks in lower-dimensional spaces to visualize distributions.
from contextualized.analysis.embeddings import plot_embedding_for_all_covars
import umap
import pandas as pd
low_dim_networks = umap.UMAP().fit_transform(predicted_networks.reshape((n, -1)))
plot_embedding_for_all_covars(low_dim_networks[:, :2], pd.DataFrame(C, columns=['C']),
                             xlabel="Network UMAP 1", ylabel="Network UMAP 2")
# In this case, there is only 1 context variable so the embeddings are not very interesting.
[Figure: UMAP embedding of the predicted networks for context variable C (axes: Network UMAP 1, Network UMAP 2)]

Acyclicity Regularizers

Contextualized Bayesian Networks are available with two types of DAG losses:

  • NOTEARS

  • DAGMA

These can be chosen independently for the archetypes and the sample-specific graphs through the archetype_dag_loss_type and sample_specific_dag_loss_type parameters:

archetype_dag_loss_type="NOTEARS" or archetype_dag_loss_type="DAGMA"

and similarly

sample_specific_dag_loss_type="NOTEARS" or sample_specific_dag_loss_type="DAGMA".

NOTEARS has parameters:

  • alpha (float)

  • rho (float)

  • use_dynamic_alpha_rho (Boolean)
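For reference, the NOTEARS acyclicity function is h(W) = tr(exp(W ∘ W)) - d, where ∘ is the elementwise product and d is the number of nodes; h(W) is zero exactly when W is a DAG. In the NOTEARS augmented Lagrangian, alpha and rho weight h and h²/2 respectively; we assume the alpha and rho parameters above play that role, but this sketch does not reproduce how the library schedules them (use_dynamic_alpha_rho).

```python
import numpy as np
from scipy.linalg import expm

def notears_h(W):
    """NOTEARS acyclicity: h(W) = tr(exp(W * W)) - d, zero iff W is a DAG."""
    d = W.shape[0]
    return np.trace(expm(W * W)) - d

def notears_penalty(W, alpha, rho):
    """Augmented-Lagrangian penalty alpha * h + (rho / 2) * h^2 (assumed mapping)."""
    h = notears_h(W)
    return alpha * h + 0.5 * rho * h**2

dag = np.array([[0.0, 1.0], [0.0, 0.0]])    # 0 -> 1
cycle = np.array([[0.0, 1.0], [1.0, 0.0]])  # 0 -> 1 -> 0
print(np.isclose(notears_h(dag), 0.0))                     # True
print(np.isclose(notears_h(cycle), 2 * np.cosh(1.0) - 2))  # True
```

For the two-node cycle, exp(W ∘ W) has diagonal entries cosh(1), which is where the closed form 2 cosh(1) - 2 comes from.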

DAGMA has parameters:

  • alpha (regularization strength, default 1e0)

  • s (max spectral radius, default 1)
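For reference, DAGMA's log-determinant acyclicity function is h_s(W) = -log det(sI - W ∘ W) + d log s. It is defined only while the spectral radius of W ∘ W stays below s, which is the role of the s parameter, and it is zero exactly when W is a DAG. The sketch evaluates this term alone; we assume alpha scales it in the loss, but the library's exact weighting is not reproduced here.

```python
import numpy as np

def dagma_h(W, s=1.0):
    """DAGMA log-det acyclicity: h_s(W) = -log det(s I - W * W) + d log s.

    Requires the spectral radius of W * W to be below s; zero iff W is a DAG.
    """
    d = W.shape[0]
    A = W * W
    if np.max(np.abs(np.linalg.eigvals(A))) >= s:
        raise ValueError("spectral radius of W * W must be below s")
    _, logdet = np.linalg.slogdet(s * np.eye(d) - A)
    return -logdet + d * np.log(s)

dag = np.array([[0.0, 0.8], [0.0, 0.0]])    # 0 -> 1
cycle = np.array([[0.0, 0.5], [0.5, 0.0]])  # 0 -> 1 -> 0
print(np.isclose(dagma_h(dag), 0.0))                    # True
print(np.isclose(dagma_h(cycle), -np.log(1 - 0.5**4)))  # True
```

Raising s enlarges the set of weight matrices on which h_s is defined, at the cost of a flatter penalty near the origin.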

Factor Graphs#

To improve scalability, we can include factor graphs (low-dimensional axes of network variation). This is controlled by the num_factors parameter. The default value of 0 turns off factor graphs and computes the network in full dimensionality.

We will explore this more deeply in the next notebook.
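As a rough illustration of the idea, and not the library's implementation, a factor graph can be pictured as a k × k network among k latent factors (k much smaller than p) that is mapped back to the p observed variables through a p × k assignment matrix P, giving W = P W_factor Pᵀ. Predicting one k × k network per sample instead of a p × p one shrinks the per-sample output from p² to k² entries. The hard assignment of variables to factors, the sizes, and the names below are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p, k = 1000, 20, 3  # samples, observed variables, latent factors (hypothetical sizes)

# Hypothetical hard assignment: each variable belongs to exactly one factor.
assignment = rng.integers(0, k, size=p)
P = np.zeros((p, k))
P[np.arange(p), assignment] = 1.0

# One k x k factor network per sample (strictly upper triangular, hence acyclic).
W_factor = np.triu(rng.normal(size=(n, k, k)), 1)

# Expand to p x p: W_full[s] = P @ W_factor[s] @ P.T for every sample s.
W_full = np.einsum("ia,sab,jb->sij", P, W_factor, P)
print(W_full.shape)                      # (1000, 20, 20)
print(W_factor[0].size, W_full[0].size)  # 9 400
```

With a hard assignment and an acyclic factor network, variables in the same factor share no edges and the expanded graph inherits the factor ordering, so it stays acyclic as well.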