Contextualized Correlation Networks

Contextualized Correlation Networks

Correlation networks summarize symmetric relationships between variables. With contextualized.ml, we can estimate context-specific correlation networks: each sample gets its own correlation matrix, predicted from that sample's context.

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
# Synthetic demo data drawn from a seeded Generator so the notebook is
# reproducible run-to-run (the global np.random state is never seeded here).
# X: 1000 samples x 10 variables, i.i.d. standard normal.
# C: 1000 samples x 5 context features, uniform on [-1, 1).
# X is drawn independently of C, so there is no true context dependence in
# this data -- it only demonstrates the API.
rng = np.random.default_rng(0)
X = rng.normal(0, 1, size=(1000, 10))
C = rng.uniform(-1, 1, size=(1000, 5))
%%capture
from contextualized.easy import ContextualizedCorrelationNetworks

# Build the model. n_bootstraps=3 yields three bootstrap models; the
# individual_preds=True call further down returns one prediction per model,
# which gives the leading axis of size 3.
# NOTE(review): 'ngam' and num_archetypes are passed straight through to
# contextualized; their exact meaning is defined by that library -- confirm there.
ccn = ContextualizedCorrelationNetworks(
    encoder_type='ngam',
    num_archetypes=16,
    n_bootstraps=3,
)
# Fit on contexts C and data X; max_epochs presumably caps training length
# and is kept small so the demo runs quickly.
ccn.fit(C, X, max_epochs=5)

# Per-sample correlation matrices rho (signed), not split by bootstrap model.
# Expected shape (1000, 10, 10): one 10x10 matrix per context row in C.
rho = ccn.predict_correlation(C, squared=False, individual_preds=False)

# Same predictions, but squared (rho^2).
rho_squared = ccn.predict_correlation(C, squared=True, individual_preds=False)

# For confidence intervals, ask for one prediction per bootstrap model instead.
# This adds a leading bootstrap axis: (n_bootstraps, n_samples, n_vars, n_vars).
# NOTE: this rebinds rho_squared, replacing the non-bootstrap version above.
rho_squared = ccn.predict_correlation(C, squared=True, individual_preds=True)
print(rho_squared.shape)
(3, 1000, 10, 10)