Driver ranking benchmark#

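Benchmark how well different driver rankings recover known lineage drivers in the zebrafish dataset. RegVelo (PS) scores are compared against CellRank correlations of fate probabilities with gene activity (GA) and with gene expression (GEX), using AUROC as the metric.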

Library import#

import pandas as pd
import scipy
from sklearn.metrics import roc_auc_score

import matplotlib.pyplot as plt
import mplscience
import seaborn as sns

import cellrank as cr
import scanpy as sc
import scvelo as scv

from rgv_tools import DATA_DIR, FIG_DIR
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_csv from `anndata` is deprecated. Import anndata.io.read_csv instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_excel from `anndata` is deprecated. Import anndata.io.read_excel instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_hdf from `anndata` is deprecated. Import anndata.io.read_hdf instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_loom from `anndata` is deprecated. Import anndata.io.read_loom instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_mtx from `anndata` is deprecated. Import anndata.io.read_mtx instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_text from `anndata` is deprecated. Import anndata.io.read_text instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_umi_tools from `anndata` is deprecated. Import anndata.io.read_umi_tools instead.
  warnings.warn(msg, FutureWarning)

General settings#

# Write text in SVG output as text elements instead of converting it to paths.
plt.rcParams["svg.fonttype"] = "none"

# Reset seaborn/matplotlib styling before applying the scVelo figure preset.
sns.reset_defaults()
sns.reset_orig()
scv.settings.set_figure_params(
    "scvelo",
    dpi_save=400,
    dpi=80,
    transparent=True,
    fontsize=14,
    color_map="viridis",
)

Constants#

# Dataset identifier; used as a sub-directory name under DATA_DIR and FIG_DIR.
DATASET = "zebrafish"

# Macrostate names that are passed to the estimator as terminal states.
TERMINAL_STATES = ["mNC_head_mesenchymal", "mNC_arch2", "mNC_hox34", "Pigment"]

# Toggle for writing figures to disk; the output folder is only created when enabled.
SAVE_FIGURES = False
if SAVE_FIGURES:
    (FIG_DIR / DATASET).mkdir(parents=True, exist_ok=True)

Data loading#

# Location of the processed files for the selected dataset.
processed_dir = DATA_DIR / DATASET / "processed"
adata = sc.read_h5ad(processed_dir / "adata_run_regvelo.h5ad")

Driver ranking analysis#

Correlating gene expression with cell fate probabilities#

# Transition matrix from the velocity kernel.
vk = cr.kernels.VelocityKernel(adata).compute_transition_matrix()

# Transition matrix from the connectivity kernel.
ck = cr.kernels.ConnectivityKernel(adata)
ck.compute_transition_matrix()

# Combined kernel: weight 0.8 on the velocity kernel, 0.2 on the connectivity kernel.
kernel = 0.8 * vk + 0.2 * ck
100%|██████████| 697/697 [00:31<00:00, 22.18cell/s]
100%|██████████| 697/697 [00:00<00:00, 2299.24cell/s]
# Build a GPCCA estimator on top of the combined kernel.
estimator = cr.estimators.GPCCA(kernel)
# Evaluate the fate probabilities on the original space:
# compute 8 macrostates (with "cell_type" as the cluster key), declare the
# macrostates listed in TERMINAL_STATES as terminal, then compute the fate
# probabilities towards them.
estimator.compute_macrostates(n_states=8, cluster_key="cell_type")
estimator.set_terminal_states(TERMINAL_STATES)
estimator.compute_fate_probabilities()
WARNING: Unable to import `petsc4py` or `slepc4py`. Using `method='brandts'`
WARNING: For `method='brandts'`, dense matrix is required. Densifying
WARNING: Unable to import petsc4py. For installation, please refer to: https://petsc4py.readthedocs.io/en/stable/install.html.
Defaulting to `'gmres'` solver.
100%|██████████| 4/4 [00:00<00:00, 32.14/s]
# Lineage-driver table from the estimator; its "<lineage>_corr" columns are scored further down.
df = estimator.compute_lineage_drivers(cluster_key="cell_type")

# Copy the forward lineage values into a plain DataFrame with one column per lineage name.
lineages = estimator.adata.obsm["lineages_fwd"]
fate_prob = pd.DataFrame(lineages, columns=lineages.names.tolist())

Correlating gene activity with cell fate probabilities#

# Gene-activity table; the first CSV column is used as the row index.
GeneActivity = pd.read_csv(DATA_DIR / DATASET / "raw" / "GeneActivity.csv", index_col=0)

# Build the TF lookup once, so the mask below does not redo the `adata.var["TF"]`
# selection for every row of the table.
tf_names = set(adata.var_names[adata.var["TF"]])

# Row names with any "_pos"/"_neg" marker removed, used only to match against TF names.
gene_symbol = [i.replace("_pos", "").replace("_neg", "") for i in GeneActivity.index.tolist()]
ind = [symbol in tf_names for symbol in gene_symbol]

# Keep only rows whose stripped name is a TF; the row index itself is left unchanged.
GeneActivity = GeneActivity.iloc[ind, :]
# Stripped names of the retained rows, in the same order as the filtered table.
gene_symbol = [symbol for symbol, keep in zip(gene_symbol, ind) if keep]


# Start from the column layout of `df`; only the "<lineage>_corr" and
# "<lineage>_pval" columns are filled in below.
df_a = pd.DataFrame(columns=list(df.columns))
activity_rows = GeneActivity.index.tolist()
for lineage in fate_prob.columns.tolist():
    # Pearson correlation of each gene-activity row with this lineage's fate probabilities.
    # The two vectors are paired by position — assumes both follow the same cell order.
    stats = [
        scipy.stats.pearsonr(GeneActivity.loc[row, :], fate_prob.loc[:, lineage])
        for row in activity_rows
    ]
    df_a[lineage + "_corr"] = [res[0] for res in stats]
    df_a[lineage + "_pval"] = [res[1] for res in stats]
df_a.index = activity_rows

Predict known drivers#

# Genes treated as positives when scoring the mNC head mesenchymal lineage.
driver_head_mes = [
    "nr2f5",
    "nr2f2",
    "sox9b",
    "twist1a",
    "twist1b",
]
# Genes treated as positives when scoring the Pigment lineage.
driver_pigment = [
    "sox10",
    "mitfa",
    "tfec",
    "bhlhe40",
    "tfap2b",
    "tfap2a",
]

mNC head mesenchymal#

# AUROC values for recovering the known head-mesenchymal drivers.
# Order of entries: three RegVelo result files, gene-activity correlation,
# gene-expression correlation.
score_head_mes = []

for nrun in range(3):
    coef_name = f"coef_{nrun}"
    coef = pd.read_csv(DATA_DIR / DATASET / "results" / coef_name, index_col=0)

    # 1 for rows named after a known driver, 0 otherwise.
    label_rgv = [int(gene in driver_head_mes) for gene in coef.index.tolist()]
    score_head_mes.append(roc_auc_score(label_rgv, coef["mNC_head_mesenchymal"]))

# NOTE(review): df_a is indexed by the raw GeneActivity row names; if those still
# carry "_pos"/"_neg" suffixes they cannot match the plain symbols in
# driver_head_mes — confirm the labels are as intended.
label_ga = [int(gene in driver_head_mes) for gene in df_a.index.tolist()]
label_gex = [int(gene in driver_head_mes) for gene in df.index.tolist()]
score_head_mes.append(roc_auc_score(label_ga, df_a["mNC_head_mesenchymal_corr"]))
score_head_mes.append(roc_auc_score(label_gex, df["mNC_head_mesenchymal_corr"]))

Pigment#

# AUROC values for recovering the known pigment drivers.
# Order of entries: three RegVelo result files, gene-activity correlation,
# gene-expression correlation.
score_pigment = []

for nrun in range(3):
    coef_name = f"coef_{nrun}"
    coef = pd.read_csv(DATA_DIR / DATASET / "results" / coef_name, index_col=0)

    # 1 for rows named after a known driver, 0 otherwise.
    label_rgv = [int(gene in driver_pigment) for gene in coef.index.tolist()]
    score_pigment.append(roc_auc_score(label_rgv, coef["Pigment"]))

# NOTE(review): df_a is indexed by the raw GeneActivity row names; if those still
# carry "_pos"/"_neg" suffixes they cannot match the plain symbols in
# driver_pigment — confirm the labels are as intended.
label_ga = [int(gene in driver_pigment) for gene in df_a.index.tolist()]
label_gex = [int(gene in driver_pigment) for gene in df.index.tolist()]
score_pigment.append(roc_auc_score(label_ga, df_a["Pigment_corr"]))
score_pigment.append(roc_auc_score(label_gex, df["Pigment_corr"]))

Plot results#

# One row per AUROC value. The method labels follow the order in which the scores
# were appended: three RegVelo runs, then gene activity, then gene expression.
dat = pd.DataFrame({"AUROC": score_head_mes, "method": ["RegVelo (PS)"] * 3 + ["CR (GA)"] + ["CR (GEX)"]})
dat2 = pd.DataFrame({"AUROC": score_pigment, "method": ["RegVelo (PS)"] * 3 + ["CR (GA)"] + ["CR (GEX)"]})
dat["Lineage"] = "mNC head mesenchymal"
dat2["Lineage"] = "Pigment"
# NOTE: this rebinds `df`, which held the lineage-driver table until here; the
# scoring cells must be re-run from the driver computation onwards after this cell.
df = pd.concat([dat, dat2], axis=0)
with mplscience.style_context():
    sns.set_style(style="whitegrid")
    fig, ax = plt.subplots(figsize=(4, 5))

    # Bars show the per-method mean without error bars. `errorbar=None` replaces
    # the deprecated `ci=None` spelling.
    sns.barplot(data=df, y="Lineage", x="AUROC", hue="method", ax=ax, errorbar=None)

    # Overlay the individual AUROC values as jittered dots, dodged per method.
    sns.stripplot(data=df, y="Lineage", x="AUROC", hue="method", dodge=True, color="black", ax=ax, jitter=True)

    # Both plot calls register legend entries; keep a single set of three.
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles[3:6], labels[3:6], bbox_to_anchor=(0.5, -0.1), loc="upper center", ncol=2)

    # Axis labels and visible AUROC range.
    ax.set_ylabel("")
    ax.set_xlabel(xlabel="AUROC", fontsize=13)
    ax.set_xlim(0.5, 0.98)

    if SAVE_FIGURES:
        plt.savefig(FIG_DIR / DATASET / "AUROC_ranking.svg", format="svg", transparent=True, bbox_inches="tight")
    plt.show()
../_images/337618a0dacd4eafe87c988a061ab6554e607cd5528ce78bb5e7e4c5fba4b251.png