Evaluate latent time prediction via redefining global latent time

Evaluate latent time prediction via redefining global latent time#

Before run this notebooks, please run 2_regvelo.ipynb,2_velovi.ipynb and 2_velovi_td.ipynb for three times to save the resulted anndata accordingly.

Library imports#

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import mplscience
import seaborn as sns

import anndata as ad
import scanpy as sc

from rgv_tools import DATA_DIR, FIG_DIR
from rgv_tools.benchmarking import get_time_correlation
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_csv from `anndata` is deprecated. Import anndata.io.read_csv instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_excel from `anndata` is deprecated. Import anndata.io.read_excel instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_hdf from `anndata` is deprecated. Import anndata.io.read_hdf instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_loom from `anndata` is deprecated. Import anndata.io.read_loom instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_mtx from `anndata` is deprecated. Import anndata.io.read_mtx instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_text from `anndata` is deprecated. Import anndata.io.read_text instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_umi_tools from `anndata` is deprecated. Import anndata.io.read_umi_tools instead.
  warnings.warn(msg, FutureWarning)

Constants#

DATASET1 = "cell_cycle"
DATASET2 = "cell_cycle_rpe1"
SAVE_FIGURES = True
if SAVE_FIGURES:
    (FIG_DIR / DATASET2).mkdir(parents=True, exist_ok=True)

FIGURE_FORMATE = "svg"
VELO_METHODS = ["regvelo", "velovi", "velovi_td", "scvelo"]
custom_palette = {
    "U2OS": "#0072B2",  # Elegant blue (colorblind-friendly)
    "RPE1": "#E69F00",  # Gold-orange (also friendly and high contrast)
}
METHOD = [
    "rgv",
    "rgv_rep2",
    "rgv_rep3",
    "scv",
    "velovi",
    "velovi_rep2",
    "velovi_rep3",
    "velovi_td",
    "velovi_td_rep2",
    "velovi_td_rep3",
]

Define functions#

def GlobalTimeDPTordering(adata: ad.AnnData):
    """Compute latent time using graph-based diffusion pseudotime (DPT) ordering.

    This function temporarily replaces the expression matrix with the
    fitted values, computes neighbors, diffusion maps, and diffusion
    pseudotime, and then restores the original expression matrix.

    Parameters
    ----------
    adata : ad.AnnData
        Annotated data matrix.

    Returns
    -------
    ad.AnnData
        The AnnData object with `latent_time` stored in `adata.obs`.
    """
    adata.layers["GEX"] = adata.X.copy()
    adata.X = adata.layers["fit_t"].copy()

    sc.pp.neighbors(adata, n_neighbors=30, metric="correlation", use_rep="X")
    sc.tl.diffmap(adata)

    adata.uns["iroot"] = np.argmin(np.mean(adata.X, axis=1))
    sc.tl.dpt(adata)

    adata.obs["latent_time"] = adata.obs["dpt_pseudotime"].copy()

    adata.X = adata.layers["GEX"].copy()
    del adata.layers["GEX"]

    return adata

Use latent time layer to calculate neighbor graph and define global latent time#

corr_time = []
for i in METHOD:
    adata = sc.read_h5ad(DATA_DIR / DATASET1 / "processed" / f"adata_{i}.h5ad")
    adata = GlobalTimeDPTordering(adata)
    corr_time.append(
        get_time_correlation(ground_truth=adata.obs["cell_cycle_position"], estimated=adata.obs["latent_time"])
    )
time_df = pd.DataFrame(
    {"method": ["regvelo"] * 3 + ["scvelo"] + ["velovi"] * 3 + ["velovi_td"] * 3, "spearman_corr": corr_time}
)
time_df_rpe1 = time_df.copy()
time_df_rpe1
method spearman_corr
0 regvelo 0.455059
1 regvelo 0.481358
2 regvelo 0.453335
3 scvelo 0.332056
4 velovi -0.341981
5 velovi 0.645073
6 velovi -0.381067
7 velovi_td -0.398513
8 velovi_td 0.197365
9 velovi_td 0.305876
corr_time = []
for i in METHOD:
    adata = sc.read_h5ad(DATA_DIR / DATASET2 / "processed" / f"adata_{i}.h5ad")
    adata = GlobalTimeDPTordering(adata)
    corr_time.append(get_time_correlation(ground_truth=adata.obs["fucci_time"], estimated=adata.obs["latent_time"]))
time_df = pd.DataFrame(
    {"method": ["regvelo"] * 3 + ["scvelo"] + ["velovi"] * 3 + ["velovi_td"] * 3, "spearman_corr": corr_time}
)
time_df_u2os = time_df.copy()
time_df_u2os
method spearman_corr
0 regvelo 0.325817
1 regvelo 0.338398
2 regvelo 0.334311
3 scvelo 0.153472
4 velovi -0.786511
5 velovi -0.205858
6 velovi -0.222118
7 velovi_td 0.223798
8 velovi_td 0.660562
9 velovi_td 0.288485
time_df_u2os["Dataset"] = "U2OS"
time_df_rpe1["Dataset"] = "RPE1"
time_df = pd.concat([time_df_u2os, time_df_rpe1])
# Plot using the refined Cell-style palette
with mplscience.style_context():
    sns.set_style("whitegrid")
    fig, ax = plt.subplots(figsize=(3, 3), sharey=True)

    sns.barplot(
        y="method",
        x="spearman_corr",
        data=time_df,
        hue="Dataset",
        capsize=0.1,
        order=VELO_METHODS,
        ax=ax,
        palette=custom_palette,
    )

    sns.stripplot(
        y="method",
        x="spearman_corr",
        data=time_df,
        hue="Dataset",
        dodge=True,
        order=VELO_METHODS,
        size=3,
        jitter=True,
        color="black",
        ax=ax,
        linewidth=0,
    )

    # Remove duplicate legends
    handles, labels = ax.get_legend_handles_labels()
    unique = dict(zip(labels, handles))  # remove duplicates
    ax.legend(unique.values(), unique.keys(), title="Dataset", frameon=False)

    ax.set_xlabel("spearman_corr", fontsize=14)
    ax.set_ylabel("")
    plt.show()

    if SAVE_FIGURES:
        fig.savefig(FIG_DIR / DATASET2 / "time_compare.svg", format="svg", transparent=True, bbox_inches="tight")
    plt.show()
../_images/30f8b876ef73c2a23b15f87ede067b18de738c9a6b8837335c8eaed59c0b98cd.png