GRN calibration evaluation

Library imports

import copy

from inferelator.postprocessing.model_metrics import RankSummaryPR, RankSummingMetric

import numpy as np
import pandas as pd
import sklearn

from matplotlib import pyplot as plt

import scanpy as sc
import scvi
from regvelo import REGVELOVI

from rgv_tools import DATA_DIR, FIG_DIR
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_csv from `anndata` is deprecated. Import anndata.io.read_csv instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_excel from `anndata` is deprecated. Import anndata.io.read_excel instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_hdf from `anndata` is deprecated. Import anndata.io.read_hdf instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_loom from `anndata` is deprecated. Import anndata.io.read_loom instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_mtx from `anndata` is deprecated. Import anndata.io.read_mtx instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_text from `anndata` is deprecated. Import anndata.io.read_text instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_umi_tools from `anndata` is deprecated. Import anndata.io.read_umi_tools instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/louvain/__init__.py:54: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
  from pkg_resources import get_distribution, DistributionNotFound
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_csv from `anndata` is deprecated. Import anndata.io.read_csv instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_loom from `anndata` is deprecated. Import anndata.io.read_loom instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_text from `anndata` is deprecated. Import anndata.io.read_text instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing CSCDataset from `anndata.experimental` is deprecated. Import anndata.abc.CSCDataset instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing CSRDataset from `anndata.experimental` is deprecated. Import anndata.abc.CSRDataset instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/anndata/utils.py:434: FutureWarning: Importing read_elem from `anndata.experimental` is deprecated. Import anndata.io.read_elem instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/jaxlib/plugin_support.py:71: RuntimeWarning: JAX plugin jax_cuda12_plugin version 0.5.0 is installed, but it is not compatible with the installed jaxlib version 0.6.2, so it will not be used.
  warnings.warn(

General settings

# Fix the global scvi-tools seed so repeated executions of the notebook start
# from the same random state.
scvi.settings.seed = 0
[rank: 0] Seed set to 0

Constants

# Dataset identifier; doubles as the sub-directory name below DATA_DIR and FIG_DIR.
DATASET = "mHSPC"
# Switches deciding which output directories are prepared (SAVE_FIGURE also
# gates writing the figures further down).
SAVE_DATA = True
SAVE_FIGURE = True

# Create every output directory whose switch is turned on.
for _enabled, _out_dir in (
    (SAVE_DATA, DATA_DIR / DATASET / "results"),
    (SAVE_FIGURE, FIG_DIR / DATASET),
):
    if _enabled:
        _out_dir.mkdir(parents=True, exist_ok=True)

Define functions

We follow the GRN calibration evaluation workflow provided by PMF-GRN; see the nyu-dl/pmf-grn repository for details.

# NOTE: `copy.deepcopy` treats classes as atomic and returns the very same class
# object, so this name is an alias of `RankSummingMetric`, not an independent copy.
RankSummingMetricCopy = copy.deepcopy(RankSummingMetric)


def get_calibration_score(to_eval, gold_standard, filter_method="overlap", method="auroc"):
    """Score predicted edge confidences against a gold standard.

    The predictions are wrapped in a single-element list and handed, together
    with the gold standard and ``filter_method``, to ``RankSummingMetricCopy``;
    the score is then computed from the resulting ``filtered_data``.

    Parameters
    ----------
    to_eval : list or DataFrame
        Predictions or scores to evaluate.
    gold_standard : list or DataFrame
        True labels for evaluation.
    filter_method : str, optional
        Method for filtering data before scoring (default is "overlap").
    method : str, optional
        Scoring metric: "auroc" or "auprc" (default is "auroc").

    Returns
    -------
    float
        Area under the ROC or PR curve as the calibration score.

    Raises
    ------
    ValueError
        If ``method`` is neither "auroc" nor "auprc".
    """
    # Validate before doing any work: an unknown method used to fall through
    # both branches and fail with an UnboundLocalError on `return auc`.
    if method not in ("auroc", "auprc"):
        raise ValueError(f"Unknown method {method!r}; expected 'auroc' or 'auprc'.")

    metrics = RankSummingMetricCopy([to_eval], gold_standard, filter_method)
    scored_edges = metrics.filtered_data

    if method == "auprc":
        data = RankSummaryPR.calculate_precision_recall(scored_edges)
        return RankSummaryPR.calculate_aupr(data)

    # Import the submodule explicitly: a bare `import sklearn` is not guaranteed
    # to expose `sklearn.metrics` on every scikit-learn version.
    from sklearn import metrics as sk_metrics

    fpr, tpr, _ = sk_metrics.roc_curve(
        scored_edges["gold_standard"], scored_edges["combined_confidences"]
    )
    return sk_metrics.auc(fpr, tpr)

Data loading

adata = sc.read_h5ad(DATA_DIR / DATASET / "processed" / "mHSC_ExpressionData.h5ad")

# Read the transcription-factor list and bring every symbol into
# "first letter upper-case, remainder lower-case" form before matching.
tf_table = pd.read_csv(DATA_DIR / DATASET / "raw" / "mouse-tfs.csv")
tf_symbols = [symbol[0].upper() + symbol[1:].lower() for symbol in tf_table["TF"]]

# Retain only the transcription factors that are present in `adata.var_names`.
is_measured = [symbol in adata.var_names for symbol in tf_symbols]
TF = np.array(tf_symbols)[is_measured]
TF
array(['Ankrd22', 'Ankrd7', 'Arntl2', 'Batf3', 'Bcl11b', 'Bmp6', 'Btg2',
       'Chd7', 'Ciita', 'Cnot6l', 'Creb5', 'Csrp3', 'Ctr9', 'Ebf1',
       'Egr2', 'Esr1', 'Ets1', 'Etv6', 'Eya1', 'Eya2', 'Eya4', 'Fos',
       'Fosb', 'Gata1', 'Gata2', 'Gata3', 'Gfi1', 'Gfi1b', 'Glis3', 'Hlf',
       'Hoxa9', 'Hspb1', 'Id2', 'Id3', 'Ifi204', 'Ikzf1', 'Ikzf3', 'Il10',
       'Irf4', 'Irf8', 'Isl1', 'Klf1', 'Klf6', 'Kpna2', 'Ldb2', 'Lef1',
       'Lmo4', 'Maf', 'Mapk11', 'Mecom', 'Mef2c', 'Meis1', 'Mllt3',
       'Mmp9', 'Myb', 'Myc', 'Mycn', 'Nfatc2', 'Nfia', 'Nfil3', 'Nfkbiz',
       'Nr1h4', 'Pax5', 'Pgr', 'Pou2af1', 'Prdm1', 'Rad54b', 'Rapgef3',
       'Relb', 'Rora', 'Runx1t1', 'Satb1', 'Setbp1', 'Sla2', 'Smarca4',
       'Spib', 'Stat3', 'Stat4', 'Tox2', 'Trib3', 'Trps1', 'Xbp1',
       'Zbtb16', 'Zbtb20', 'Zbtb38', 'Zfp354a'], dtype='<U12')

Load ground truth (ChIP-seq)

gt = pd.read_csv(DATA_DIR / DATASET / "raw" / "mHSC-ChIP-seq-network.csv")

# Bring both gene columns into the same capitalisation as the TF list
# (first letter upper-case, remainder lower-case).
for col in ("Gene1", "Gene2"):
    gt[col] = [symbol[0].upper() + symbol[1:].lower() for symbol in gt[col]]

# Keep only edges whose regulator (Gene1) is one of the retained TFs and whose
# target (Gene2) is present in `adata.var_names`.
gt = gt[gt["Gene1"].isin(TF) & gt["Gene2"].isin(adata.var_names)]

# Binary ground-truth matrix: rows are target genes, columns are regulators;
# an entry is 1 when the edge is listed in the ChIP-seq network.
p_class = pd.DataFrame(0, index=adata.var_names, columns=TF)

# After the filter above every (target, regulator) pair is already a valid
# row/column label of `p_class`, so each edge can be written directly.
for reg, tar in zip(gt["Gene1"], gt["Gene2"]):
    p_class.loc[tar, reg] = 1

Running GRN inference to compute a calibrated estimate of error

REGVELOVI.setup_anndata(adata, spliced_layer="Ms", unspliced_layer="Mu")

# Number of independently trained models. Kept in one place so that the
# training loop and the loading loop can never disagree.
N_RUNS = 10
# Directory under which each trained model is stored and later reloaded.
MODEL_DIR = DATA_DIR / DATASET / "processed"

# Train and persist one model per run.
for nrun in range(N_RUNS):
    vae = REGVELOVI(adata, regulators=TF)
    vae.train()

    vae.save(MODEL_DIR / f"hsc_model_run_{nrun}")

# Reload every run and collect its weight matrix restricted to the TF columns.
grns = []
for nrun in range(N_RUNS):
    vae = REGVELOVI.load(MODEL_DIR / f"hsc_model_run_{nrun}", adata)

    # Convert to a NumPy array explicitly: handing a torch tensor straight to
    # `pd.DataFrame` relies on implicit conversion and can yield tensor-valued
    # cells instead of plain floats.
    w = vae.module.v_encoder.fc1.weight.detach().cpu().numpy()
    # NOTE(review): presumably rows are target genes and columns regulators,
    # matching the layout of `p_class` - confirm against the model definition.
    w = pd.DataFrame(w, index=adata.var_names, columns=adata.var_names)
    grns.append(w.loc[:, TF])
INFO     File                                                                                                      
         /lustre/groups/ml01/workspace/weixu.wang/regvelo_revision/parameter_tuning/caliaberate_error/hsc_model_run
         _0/model.pt already downloaded                                                                            
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo-py310/l ...
INFO     File                                                                                                      
         /lustre/groups/ml01/workspace/weixu.wang/regvelo_revision/parameter_tuning/caliaberate_error/hsc_model_run
         _1/model.pt already downloaded                                                                            
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo-py310/l ...
INFO     File                                                                                                      
         /lustre/groups/ml01/workspace/weixu.wang/regvelo_revision/parameter_tuning/caliaberate_error/hsc_model_run
         _2/model.pt already downloaded                                                                            
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo-py310/l ...
INFO     File                                                                                                      
         /lustre/groups/ml01/workspace/weixu.wang/regvelo_revision/parameter_tuning/caliaberate_error/hsc_model_run
         _3/model.pt already downloaded                                                                            
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo-py310/l ...
INFO     File                                                                                                      
         /lustre/groups/ml01/workspace/weixu.wang/regvelo_revision/parameter_tuning/caliaberate_error/hsc_model_run
         _4/model.pt already downloaded                                                                            
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo-py310/l ...
INFO     File                                                                                                      
         /lustre/groups/ml01/workspace/weixu.wang/regvelo_revision/parameter_tuning/caliaberate_error/hsc_model_run
         _5/model.pt already downloaded                                                                            
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo-py310/l ...
INFO     File                                                                                                      
         /lustre/groups/ml01/workspace/weixu.wang/regvelo_revision/parameter_tuning/caliaberate_error/hsc_model_run
         _6/model.pt already downloaded                                                                            
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo-py310/l ...
INFO     File                                                                                                      
         /lustre/groups/ml01/workspace/weixu.wang/regvelo_revision/parameter_tuning/caliaberate_error/hsc_model_run
         _7/model.pt already downloaded                                                                            
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo-py310/l ...
INFO     File                                                                                                      
         /lustre/groups/ml01/workspace/weixu.wang/regvelo_revision/parameter_tuning/caliaberate_error/hsc_model_run
         _8/model.pt already downloaded                                                                            
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo-py310/l ...
INFO     File                                                                                                      
         /lustre/groups/ml01/workspace/weixu.wang/regvelo_revision/parameter_tuning/caliaberate_error/hsc_model_run
         _9/model.pt already downloaded                                                                            
/home/icb/weixu.wang/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo-py310/l ...
# Stack the per-run weight matrices along a new leading "run" axis and take
# the element-wise median across runs as the consensus GRN.
stacked = np.stack(grns)
median = np.median(stacked, axis=0)
abs_median = np.abs(median)

# Stabiliser for the ratio below: 10th percentile of the non-zero absolute
# median weights.
epsilon = np.percentile(abs_median[median != 0], 10)

# Per-edge dispersion: median absolute deviation across runs relative to the
# absolute median weight (both offset by epsilon).
abs_deviation = np.median(np.abs(stacked - median), axis=0)
edge_variance = pd.DataFrame(
    (abs_deviation + epsilon) / (abs_median + epsilon),
    index=p_class.index,
    columns=p_class.columns,
)
grn = pd.DataFrame(median, index=p_class.index, columns=p_class.columns)

# 10th, 20th, ..., 100th percentile of the dispersion values.
percentile_values = np.percentile(edge_variance, np.arange(10, 101, 10))

# Score the network once per cutoff from the 30th percentile upwards; edges
# whose dispersion exceeds the cutoff are blanked out (NaN) before scoring.
auprcs = [
    get_calibration_score(
        np.abs(grn).mask(edge_variance > cutoff),
        p_class,
        filter_method="overlap",
        method="auprc",
    )
    for cutoff in percentile_values[2:]
]
2025-10-25 16:06:43 1 GS: 8940 edges; Confidences: 28250 edges
2025-10-25 16:06:44 1 GS: 8940 edges; Confidences: 37667 edges
2025-10-25 16:06:44 1 GS: 8940 edges; Confidences: 47084 edges
2025-10-25 16:06:44 1 GS: 8940 edges; Confidences: 56501 edges
2025-10-25 16:06:44 1 GS: 8940 edges; Confidences: 65918 edges
2025-10-25 16:06:44 1 GS: 8940 edges; Confidences: 75250 edges
2025-10-25 16:06:44 1 GS: 8940 edges; Confidences: 84667 edges
2025-10-25 16:06:44 1 GS: 8940 edges; Confidences: 94084 edges
# Show the AUPRC obtained at each cutoff (notebook cell output).
auprcs
[0.17072510990303702,
 0.16163773408491552,
 0.154657775559322,
 0.14891162480176293,
 0.14441520241636058,
 0.14058914165962355,
 0.13723783365027428,
 0.1336906127397129]
# Keep text as text (not paths) in the exported SVG.
plt.rcParams["svg.fonttype"] = "none"

# The scores were computed at the 30th, 40th, ... percentile of `edge_variance`
# (the scoring loop starts at index 2 of the 10, 20, ..., 100 percentile grid),
# so plot them at those positions rather than at evenly spread values that do
# not match the "Percentile Cutoff" axis label.
cutoff_percentiles = 10 * np.arange(3, 3 + len(auprcs))
plt.plot(cutoff_percentiles, np.array(auprcs))
plt.xlim(0, 100)
plt.grid()
plt.xlabel("Percentile Cutoff", fontsize=18)
plt.ylabel("AUPRC", fontsize=18)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.tight_layout()

if SAVE_FIGURE:
    plt.savefig(FIG_DIR / DATASET / "GRN_calibration_auprc.svg", format="svg", transparent=True, bbox_inches="tight")
../_images/17d66cbb7bc8fa06c5bc3378e0f9a54a6b1f0023ee5b7ca673a31aaf92b14d35.png
# Repeat the cutoff sweep with AUROC: for every dispersion cutoff from the
# third percentile value onwards, blank out (NaN) the edges above the cutoff
# and score the remaining absolute weights against the ground truth.
auroc = [
    get_calibration_score(
        np.abs(grn).mask(edge_variance > cutoff),
        p_class,
        filter_method="overlap",
        method="auroc",
    )
    for cutoff in percentile_values[2:]
]
2025-10-25 16:06:50 1 GS: 8940 edges; Confidences: 28250 edges
2025-10-25 16:06:50 1 GS: 8940 edges; Confidences: 37667 edges
2025-10-25 16:06:50 1 GS: 8940 edges; Confidences: 47084 edges
2025-10-25 16:06:50 1 GS: 8940 edges; Confidences: 56501 edges
2025-10-25 16:06:50 1 GS: 8940 edges; Confidences: 65918 edges
2025-10-25 16:06:51 1 GS: 8940 edges; Confidences: 75250 edges
2025-10-25 16:06:51 1 GS: 8940 edges; Confidences: 84667 edges
2025-10-25 16:06:51 1 GS: 8940 edges; Confidences: 94084 edges
# The scores were computed at the 30th, 40th, ... percentile of `edge_variance`
# (the scoring loop starts at index 2 of the 10, 20, ..., 100 percentile grid),
# so plot them at those positions rather than at evenly spread values that do
# not match the "Percentile Cutoff" axis label.
cutoff_percentiles = 10 * np.arange(3, 3 + len(auroc))
plt.plot(cutoff_percentiles, auroc)
plt.xlim(0, 100)
plt.grid()
plt.xlabel("Percentile Cutoff", fontsize=18)
plt.ylabel("AUROC", fontsize=18)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
plt.tight_layout()

if SAVE_FIGURE:
    plt.savefig(FIG_DIR / DATASET / "GRN_calibration_roc.svg", format="svg", transparent=True, bbox_inches="tight")
../_images/43d97cfee2bef39975ee1e53c97a3a09f29d5a6f1aaa01837fab95353d73585b.png