# Run RegVelo on the mHSPC dataset
## Library import
from itertools import permutations, product
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import roc_auc_score
import scanpy as sc
import scvi
from regvelo import REGVELOVI
from rgv_tools import DATA_DIR
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_csv from `anndata` is deprecated. Import anndata.io.read_csv instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_excel from `anndata` is deprecated. Import anndata.io.read_excel instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_hdf from `anndata` is deprecated. Import anndata.io.read_hdf instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_loom from `anndata` is deprecated. Import anndata.io.read_loom instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_mtx from `anndata` is deprecated. Import anndata.io.read_mtx instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_text from `anndata` is deprecated. Import anndata.io.read_text instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_umi_tools from `anndata` is deprecated. Import anndata.io.read_umi_tools instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_csv from `anndata` is deprecated. Import anndata.io.read_csv instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_loom from `anndata` is deprecated. Import anndata.io.read_loom instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_text from `anndata` is deprecated. Import anndata.io.read_text instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing CSCDataset from `anndata.experimental` is deprecated. Import anndata.abc.CSCDataset instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing CSRDataset from `anndata.experimental` is deprecated. Import anndata.abc.CSRDataset instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_elem from `anndata.experimental` is deprecated. Import anndata.io.read_elem instead.
warnings.warn(msg, FutureWarning)
## General settings
# Set the global scvi-tools seed to 0 (presumably so the training runs below are reproducible).
scvi.settings.seed = 0
[rank: 0] Seed set to 0
## Constants
# Name of the dataset sub-folder below DATA_DIR that this notebook reads from and writes to.
DATASET = "mHSPC"

# Toggle for persisting results to disk.
SAVE_DATA = True

# Make sure the output folder exists before anything is written to it.
if SAVE_DATA:
    (DATA_DIR / DATASET / "results").mkdir(exist_ok=True, parents=True)
## Define functions
def unsigned(true_edges: pd.DataFrame, pred_edges: pd.DataFrame, type: str = "alledges") -> tuple[float, float, float]:
    """Compare true vs predicted edges (unsigned) and compute early precision/recall metrics.

    Self-edges and duplicated rows are removed from both tables. Only pairs
    whose ``Gene1`` occurs as ``Gene1`` in ``true_edges`` and whose ``Gene2``
    is any node of ``true_edges`` are scored. Predictions are ranked by the
    absolute value of ``EdgeWeight`` (rounded to 6 decimals) and the top ``k``
    are selected, where ``k`` is the number of true edges capped by the number
    of scored predictions. Zero-weight predictions are never selected, and ties
    at the selection threshold are all kept.

    The early precision ratio is printed as ``EPR: <value>`` before returning.
    Neither input frame is modified.

    Parameters
    ----------
    true_edges
        Ground-truth network with columns ``Gene1`` and ``Gene2``.
    pred_edges
        Predicted network with columns ``Gene1``, ``Gene2`` and ``EdgeWeight``.
    type
        Unused; kept so existing callers that pass it keep working.

    Returns
    -------
    tuple: (eprec, erec, eprec_ratio)
        Early precision, early recall, and early precision divided by the
        precision expected from a random predictor. When no prediction can be
        selected (no scored predictions, or none with a non-zero weight), all
        three values are ``1.0``.
    """
    # Filtered copies: the caller's frames are never touched.
    true_df = (
        true_edges.loc[true_edges["Gene1"] != true_edges["Gene2"]]
        .drop_duplicates(keep="first")
        .reset_index(drop=True)
    )
    pred_df = (
        pred_edges.loc[pred_edges["Gene1"] != pred_edges["Gene2"]]
        .drop_duplicates(keep="first")
        .reset_index(drop=True)
    )

    # Candidate edges: every ground-truth regulator paired with every other
    # ground-truth node. Building them directly avoids materialising all
    # node permutations just to discard self-pairs.
    nodes = np.unique(true_df[["Gene1", "Gene2"]].to_numpy())
    regulators = set(true_df["Gene1"])
    candidate_edges = {f"{tf}|{gene}" for tf in regulators for gene in nodes if tf != gene}

    true_keys = true_df["Gene1"] + "|" + true_df["Gene2"]
    true_keys = true_keys[true_keys.isin(candidate_edges)]
    n_edges = len(true_keys)

    pred_df["Edges"] = pred_df["Gene1"] + "|" + pred_df["Gene2"]
    # Explicit copy: this frame is modified below, so it must not be a slice.
    pred_df = pred_df[pred_df["Edges"].isin(candidate_edges)].copy()

    ranked: set[str] = set()
    if not pred_df.empty:
        pred_df["EdgeWeight"] = pred_df["EdgeWeight"].round(6).abs()
        weights = pred_df["EdgeWeight"]
        positive = weights[weights > 0]
        # Without any non-zero weight there is nothing to rank; fall through
        # to the fallback instead of dividing by an empty selection.
        if not positive.empty:
            max_k = min(len(weights), n_edges)
            kth_weight = weights.sort_values(ascending=False).iloc[max_k - 1]
            # Never let the cut-off drop to zero-weight edges.
            threshold = max(positive.min(), kth_weight)
            ranked = set(pred_df.loc[weights >= threshold, "Edges"])

    if ranked:
        hits = ranked.intersection(true_keys)
        eprec = len(hits) / len(ranked)
        erec = len(hits) / len(true_keys)
        random_eprec = n_edges / len(candidate_edges)
        eprec_ratio = eprec / random_eprec
    else:
        eprec = 1.0
        erec = 1.0
        eprec_ratio = 1.0

    print("EPR: " + str(eprec_ratio))
    return eprec, erec, eprec_ratio
def calculate_auroc(inferred_scores_df: pd.DataFrame, ground_truth_df: pd.DataFrame) -> float:
    """Calculate AUROC comparing inferred edge scores against ground truth.

    An inferred edge is labelled positive when its ``(Gene1, Gene2)`` pair
    occurs in ``ground_truth_df`` and negative otherwise; ``EdgeWeight`` is
    used as the score. ``inferred_scores_df`` is not modified.

    Parameters
    ----------
    inferred_scores_df
        Inferred edges with columns ``Gene1``, ``Gene2`` and ``EdgeWeight``.
    ground_truth_df
        Ground-truth edges with columns ``Gene1`` and ``Gene2``.

    Returns
    -------
    float: AUROC score.
    """
    ground_truth_pairs = set(zip(ground_truth_df["Gene1"], ground_truth_df["Gene2"]))
    # Build the labels as a local list instead of writing a "label" column
    # into the caller's frame.
    y_true = [
        int(pair in ground_truth_pairs)
        for pair in zip(inferred_scores_df["Gene1"], inferred_scores_df["Gene2"])
    ]
    y_scores = inferred_scores_df["EdgeWeight"]
    return roc_auc_score(y_true, y_scores)
## Data loading
# Preprocessed expression data for the dataset.
adata = sc.read_h5ad(DATA_DIR / DATASET / "processed" / "mHSC_ExpressionData.h5ad")

# Transcription factor list: rewrite each symbol as first letter upper-case and
# the rest lower-case, then keep only the symbols present in adata.var_names.
TF = pd.read_csv(DATA_DIR / DATASET / "raw" / "mouse-tfs.csv")
TF = [symbol[0].upper() + symbol[1:].lower() for symbol in TF["TF"].tolist()]
TF = np.array(TF)[[symbol in adata.var_names for symbol in TF]]
TF
array(['Ankrd22', 'Ankrd7', 'Arntl2', 'Batf3', 'Bcl11b', 'Bmp6', 'Btg2',
'Chd7', 'Ciita', 'Cnot6l', 'Creb5', 'Csrp3', 'Ctr9', 'Ebf1',
'Egr2', 'Esr1', 'Ets1', 'Etv6', 'Eya1', 'Eya2', 'Eya4', 'Fos',
'Fosb', 'Gata1', 'Gata2', 'Gata3', 'Gfi1', 'Gfi1b', 'Glis3', 'Hlf',
'Hoxa9', 'Hspb1', 'Id2', 'Id3', 'Ifi204', 'Ikzf1', 'Ikzf3', 'Il10',
'Irf4', 'Irf8', 'Isl1', 'Klf1', 'Klf6', 'Kpna2', 'Ldb2', 'Lef1',
'Lmo4', 'Maf', 'Mapk11', 'Mecom', 'Mef2c', 'Meis1', 'Mllt3',
'Mmp9', 'Myb', 'Myc', 'Mycn', 'Nfatc2', 'Nfia', 'Nfil3', 'Nfkbiz',
'Nr1h4', 'Pax5', 'Pgr', 'Pou2af1', 'Prdm1', 'Rad54b', 'Rapgef3',
'Relb', 'Rora', 'Runx1t1', 'Satb1', 'Setbp1', 'Sla2', 'Smarca4',
'Spib', 'Stat3', 'Stat4', 'Tox2', 'Trib3', 'Trps1', 'Xbp1',
'Zbtb16', 'Zbtb20', 'Zbtb38', 'Zfp354a'], dtype='<U12')
## Velocity pipeline
# All-ones gene-by-gene matrix handed to REGVELOVI as `W`
# (presumably the prior regulatory network; all ones = no edge excluded -- confirm against RegVelo docs).
W = torch.ones((adata.n_vars, adata.n_vars))

# Train three independent models on the same data and keep each of them for evaluation.
vae_list = []
for _nrun in range(3):
    REGVELOVI.setup_anndata(adata, spliced_layer="Ms", unspliced_layer="Mu")
    vae = REGVELOVI(adata, W=W, regulators=TF)
    vae.train()
    vae_list.append(vae)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo_test/li ...
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo_test/li ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo_test/li ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=5` in the `DataLoader` to improve performance.
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=5` in the `DataLoader` to improve performance.
Monitored metric elbo_validation did not improve in the last 45 records. Best score: -2721.169. Signaling Trainer to stop.
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo_test/li ...
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo_test/li ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo_test/li ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=5` in the `DataLoader` to improve performance.
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=5` in the `DataLoader` to improve performance.
Monitored metric elbo_validation did not improve in the last 45 records. Best score: -2762.707. Signaling Trainer to stop.
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo_test/li ...
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo_test/li ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo_test/li ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=5` in the `DataLoader` to improve performance.
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=5` in the `DataLoader` to improve performance.
IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.
Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)
# The ground truth is the same for every run, so load and filter it once:
# symbols are rewritten as first letter upper-case / rest lower-case, regulators
# must be in TF and targets must be in adata.var_names.
gt = pd.read_csv(DATA_DIR / DATASET / "raw" / "mHSC-ChIP-seq-network.csv")
gt["Gene1"] = [i[0].upper() + i[1:].lower() for i in gt["Gene1"].tolist()]
gt["Gene2"] = [i[0].upper() + i[1:].lower() for i in gt["Gene2"].tolist()]
gt = gt.loc[gt["Gene1"].isin(TF) & gt["Gene2"].isin(adata.var_names), :]

EPR_score = []
AUC_score = []
for vae in vae_list:
    # Run on whichever device the trained module lives on instead of assuming "cuda:0".
    device = next(vae.module.parameters()).device
    grn_estimate = vae.module.v_encoder.GRN_Jacobian(torch.tensor(adata.layers["Ms"]).to(device))
    grn_estimate = grn_estimate.detach().cpu().numpy()
    grn_estimate = np.abs(grn_estimate)
    grn_estimate = pd.DataFrame(grn_estimate, index=adata.var_names.tolist(), columns=adata.var_names.tolist())
    # Keep only the TF columns.
    grn_estimate = grn_estimate.loc[:, TF].copy()

    # Long format: stack() yields (row label, column label, value); the row
    # label becomes Gene2 and the (TF) column label becomes Gene1.
    grn = pd.DataFrame(grn_estimate.stack()).reset_index()
    grn.columns = ["Gene2", "Gene1", "EdgeWeight"]
    result = grn[["Gene1", "Gene2", "EdgeWeight"]].sort_values(by="EdgeWeight", ascending=False).reset_index(drop=True)

    _, _, epr = unsigned(gt, result)
    EPR_score.append(epr)
    AUC_score.append(calculate_auroc(result, gt))
EPR: 1.1386958157171931
EPR: 1.145426793837005
EPR: 1.1243054487024222
## Results
# One row per evaluated run; the method label is repeated to match the number
# of collected scores rather than a hard-coded run count.
result_df = pd.DataFrame({"EPR": EPR_score, "AUC": AUC_score, "Method": ["regvelo"] * len(EPR_score)})

if SAVE_DATA:
    result_df.to_csv(DATA_DIR / DATASET / "results" / "GRN_benchmark_rgv.csv")