Correlation benchmark on dyngen data

Correlation benchmark on dyngen data#

This notebook benchmarks gene regulatory network (GRN) inference based on gene-gene correlation on dyngen-generated data.

Library imports#

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import roc_auc_score

import anndata as ad
import scvi

from rgv_tools import DATA_DIR
/home/icb/yifan.chen/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_csv from `anndata` is deprecated. Import anndata.io.read_csv instead.
  warnings.warn(msg, FutureWarning)
/home/icb/yifan.chen/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_loom from `anndata` is deprecated. Import anndata.io.read_loom instead.
  warnings.warn(msg, FutureWarning)
/home/icb/yifan.chen/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_text from `anndata` is deprecated. Import anndata.io.read_text instead.
  warnings.warn(msg, FutureWarning)
/home/icb/yifan.chen/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing CSCDataset from `anndata.experimental` is deprecated. Import anndata.abc.CSCDataset instead.
  warnings.warn(msg, FutureWarning)
/home/icb/yifan.chen/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing CSRDataset from `anndata.experimental` is deprecated. Import anndata.abc.CSRDataset instead.
  warnings.warn(msg, FutureWarning)
/home/icb/yifan.chen/miniconda3/envs/regvelo-py310/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_elem from `anndata.experimental` is deprecated. Import anndata.io.read_elem instead.
  warnings.warn(msg, FutureWarning)

General settings#

# Set the scvi-tools global seed to 0.
scvi.settings.seed = 0
[rank: 0] Seed set to 0

Constants#

# Dataset name and complexity level; together they select the sub-directory
# of DATA_DIR that this notebook reads from and writes to.
DATASET = "dyngen"
COMPLEXITY = "complexity_1"

# When enabled, results are written to disk, so the output directory is
# created up front (including any missing parents).
SAVE_DATA = True
if SAVE_DATA:
    DATA_DIR.joinpath(DATASET, COMPLEXITY, "results").mkdir(parents=True, exist_ok=True)

GRN inference pipeline#

# One entry per processed simulation: the mean, over cells, of the AUROC obtained
# when absolute gene-gene correlation is used to score the cell-specific GRN.
grn_correlation = []

cnt = 0
# NOTE: iterdir() yields entries in arbitrary order, so the entries of
# `grn_correlation` follow that order and are not sorted by simulation id.
for filename in (DATA_DIR / DATASET / COMPLEXITY / "processed").iterdir():
    torch.cuda.empty_cache()
    if filename.suffix != ".zarr":
        continue

    simulation_id = int(filename.stem.removeprefix("simulation_"))
    print(f"Run {cnt}, dataset {simulation_id}.")

    adata = ad.io.read_zarr(filename)
    grn_true = adata.uns["true_skeleton"]
    grn_sc_true = adata.uns["true_sc_grn"]

    # Pairwise correlation between the columns of the "Ms" layer.
    grn_estimate = adata.to_df(layer="Ms").corr().values

    # The skeleton mask, the scores and the top-k selection depend only on the
    # dataset, not on the cell, so they are computed once per simulation
    # instead of once per cell.
    skeleton_mask = np.array(grn_true.T) == 1
    estimated = np.abs(grn_estimate[skeleton_mask])
    # Cap the number of scored edges at 10000 (keeps the highest-scoring ones).
    number = min(10000, len(estimated))
    estimated, index = torch.topk(torch.tensor(estimated), number)

    grn_auroc = []
    for cell_id in range(adata.n_obs):
        ground_truth = grn_sc_true[:, :, cell_id]

        # Cells whose ground-truth slice sums to zero are skipped.
        if ground_truth.sum() > 0:
            # Boolean indexing returns a copy, so binarising it below does not
            # modify `grn_sc_true`.
            ground_truth = ground_truth.T[skeleton_mask]
            ground_truth[ground_truth != 0] = 1

            grn_auroc.append(roc_auc_score(ground_truth[index], estimated))

    grn_correlation.append(np.mean(grn_auroc))
    cnt += 1
Run 0, dataset 29.
Run 1, dataset 14.
Run 2, dataset 24.
Run 3, dataset 28.
Run 4, dataset 6.
Run 5, dataset 21.
Run 6, dataset 15.
Run 7, dataset 9.
Run 8, dataset 12.
Run 9, dataset 19.
Run 10, dataset 4.
Run 11, dataset 13.
Run 12, dataset 2.
Run 13, dataset 16.
Run 14, dataset 1.
Run 15, dataset 18.
Run 16, dataset 5.
Run 17, dataset 10.
Run 18, dataset 8.
Run 19, dataset 11.
Run 20, dataset 27.
Run 21, dataset 23.
Run 22, dataset 17.
Run 23, dataset 30.
Run 24, dataset 22.
Run 25, dataset 25.
Run 26, dataset 20.
Run 27, dataset 7.
Run 28, dataset 3.
Run 29, dataset 26.

Data saving#

if SAVE_DATA:
    # Persist the per-simulation scores as a single-column ("grn") parquet table.
    results_path = DATA_DIR / DATASET / COMPLEXITY / "results" / "correlation_correlation.parquet"
    results_frame = pd.DataFrame({"grn": grn_correlation})
    results_frame.to_parquet(path=results_path)