Benchmark cell fate and velocity identifiability#

Notebook for evaluating the identifiability of cellular velocity and cell fate.

Library imports#

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import mplscience
import seaborn as sns

import cellrank as cr
import scanpy as sc
import scvelo as scv
from regvelo import REGVELOVI

from rgv_tools import DATA_DIR, FIG_DIR
from rgv_tools.benchmarking import set_output
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_csv from `anndata` is deprecated. Import anndata.io.read_csv instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_excel from `anndata` is deprecated. Import anndata.io.read_excel instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_hdf from `anndata` is deprecated. Import anndata.io.read_hdf instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_loom from `anndata` is deprecated. Import anndata.io.read_loom instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_mtx from `anndata` is deprecated. Import anndata.io.read_mtx instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_text from `anndata` is deprecated. Import anndata.io.read_text instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_umi_tools from `anndata` is deprecated. Import anndata.io.read_umi_tools instead.
  warnings.warn(msg, FutureWarning)

General settings#

%matplotlib inline
plt.rcParams["svg.fonttype"] = "none"
sns.reset_defaults()
sns.reset_orig()
scv.settings.set_figure_params("scvelo", dpi_save=400, dpi=80, transparent=True, fontsize=14, color_map="viridis")

Constants#

# Dataset identifier; also used as the sub-directory name below DATA_DIR and FIG_DIR.
DATASET = "pancreatic_endocrinogenesis"
# Cluster labels treated as terminal states in the analyses below.
TERMINAL_STATES = ["Alpha", "Beta", "Delta", "Epsilon"]

# Output switches; the matching output directory is created only when enabled.
SAVE_DATA = True
SAVE_FIGURES = True

for enabled, directory in (
    (SAVE_DATA, DATA_DIR / DATASET / "results"),
    (SAVE_FIGURES, FIG_DIR / DATASET),
):
    if enabled:
        directory.mkdir(parents=True, exist_ok=True)

Function definitions#

def cal_cell_fate(adata, terminal_states, cluster_key, n_states=7):
    """Compute CellRank fate probabilities towards the given terminal states.

    A velocity kernel and a connectivity kernel are combined with weights
    0.8 and 0.2, a GPCCA estimator is fitted on the combined kernel,
    macrostates are computed, and fate probabilities towards
    ``terminal_states`` are returned as a data frame.

    Parameters
    ----------
    adata
        Annotated data object handed to the CellRank kernels.
    terminal_states
        Names of the macrostates to set as terminal states.
    cluster_key
        Forwarded to ``compute_macrostates`` as ``cluster_key``.
    n_states
        Number of macrostates to compute.

    Returns
    -------
    pandas.DataFrame
        Fate probabilities, indexed by ``adata.obs.index`` with one column
        per entry of the estimator's fate names.

    Raises
    ------
    ValueError
        If at least one requested terminal state is not among the computed
        macrostates.
    """
    vk = cr.kernels.VelocityKernel(adata).compute_transition_matrix()
    ck = cr.kernels.ConnectivityKernel(adata).compute_transition_matrix()
    combined_kernel = 0.8 * vk + 0.2 * ck
    estimator = cr.estimators.GPCCA(combined_kernel)

    # Compute a generous number of macrostates; only those named in
    # `terminal_states` are used afterwards.
    estimator.compute_macrostates(n_states=n_states, n_cells=30, cluster_key=cluster_key)
    macrostate = estimator.macrostates.cat.categories.tolist()

    # Fail early with a clear message when a requested terminal state was
    # not recovered. The original code built this exception without raising it.
    missing_states = set(terminal_states).difference(macrostate)
    if missing_states:
        raise ValueError(
            "Terminal states are not recoverred in perturbation, please increase the number of the states!"
        )

    estimator.set_terminal_states(terminal_states)
    estimator.compute_fate_probabilities(solver="direct")
    fb = estimator.fate_probabilities
    sampleID = adata.obs.index.tolist()
    fate_name = fb.names.tolist()
    fb = pd.DataFrame(fb, index=sampleID, columns=fate_name)

    return fb

Data loading#

# Load the preprocessed dataset and map each cluster label to its stored colour.
adata = sc.read_h5ad(DATA_DIR / DATASET / "processed" / "adata_preprocessed.h5ad")
palette = dict(
    zip(
        adata.obs["clusters"].cat.categories,
        adata.uns["clusters_colors"],
    )
)

Model loading#

# Per-run results: a snapshot of `adata` and the matching velocity layer.
adata_m, velo_m = [], []
for nrun in range(5):
    model = DATA_DIR / DATASET / "processed" / "perturb_repeat_runs" / f"rgv_model_{nrun}"
    vae = REGVELOVI.load(model, adata)
    # NOTE(review): `set_output` presumably writes the model output (including
    # the "velocity" layer read below) into `adata` in place - confirm.
    set_output(adata, vae, n_samples=30, batch_size=adata.n_obs)
    adata_m.append(adata.copy())
    velo_m.append(adata.layers["velocity"])
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo_test/li ...
INFO     File                                                                                                      
         /ictstr01/home/icb/weixu.wang/regulatory_velo/data/pancreatic_endocrine/processed/perturb_repeat_runs/rgv_
         model_0/model.pt already downloaded                                                                       
INFO     File                                                                                                      
         /ictstr01/home/icb/weixu.wang/regulatory_velo/data/pancreatic_endocrine/processed/perturb_repeat_runs/rgv_
         model_1/model.pt already downloaded                                                                       
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo_test/li ...
INFO     File                                                                                                      
         /ictstr01/home/icb/weixu.wang/regulatory_velo/data/pancreatic_endocrine/processed/perturb_repeat_runs/rgv_
         model_2/model.pt already downloaded                                                                       
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo_test/li ...
INFO     File                                                                                                      
         /ictstr01/home/icb/weixu.wang/regulatory_velo/data/pancreatic_endocrine/processed/perturb_repeat_runs/rgv_
         model_3/model.pt already downloaded                                                                       
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo_test/li ...
INFO     File                                                                                                      
         /ictstr01/home/icb/weixu.wang/regulatory_velo/data/pancreatic_endocrine/processed/perturb_repeat_runs/rgv_
         model_4/model.pt already downloaded                                                                       
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/icb/weixu.wang/miniconda3/envs/regvelo_test/li ...

Cellular velocity#

# Stack the per-run velocity matrices along a new trailing axis; rows are
# cells, so the result is indexed as [cell, feature, run].
velo_arrays = np.stack(velo_m, axis=2)
n_runs = velo_arrays.shape[2]

# Centre every cell's velocity vector once per run (in float64). The Pearson
# correlation between two runs is then a normalised row-wise dot product,
# which avoids building a full 2N x 2N correlation matrix per run pair just
# to read N of its entries.
velo_centered = velo_arrays - velo_arrays.mean(axis=1, keepdims=True, dtype=np.float64)
velo_norm = np.linalg.norm(velo_centered, axis=1)

# One boolean cell mask per terminal state, computed once.
cluster_masks = [(adata.obs["clusters"] == state).to_numpy() for state in TERMINAL_STATES]

# For every unordered pair of runs, average the per-cell correlation within
# each terminal-state cluster.
correlations = []
for i in range(n_runs):
    for j in range(i + 1, n_runs):
        corr = np.sum(velo_centered[:, :, i] * velo_centered[:, :, j], axis=1) / (
            velo_norm[:, i] * velo_norm[:, j]
        )
        correlations.append([np.mean(corr[mask]) for mask in cluster_masks])
correlations = np.array(correlations)

# Long format: all run pairs of the first terminal state, then the second, ...
# (column-major flattening matches the repeated group labels).
df = pd.DataFrame(
    {
        "Correlation": correlations.ravel(order="F"),
        "Group": np.repeat(TERMINAL_STATES, correlations.shape[0]),
    }
)

# Boxplot of the per-cluster velocity correlations across run pairs.
with mplscience.style_context():
    fig, ax = plt.subplots(figsize=(4, 4))
    sns.boxplot(x="Group", y="Correlation", data=df, palette=palette, ax=ax)
    ax.set(xlabel="Cell Type", ylabel="PCC", ylim=(0, 1))
    fig.tight_layout()
    plt.show()
../_images/851013134bb7c852d0d0f66da9d7e9c8a4bfb01c32e94fcb81f95763c8ff9248.png
# Keep the velocity correlations for the combined figure; `df` is reassigned
# by the cell-fate analysis.
df_rgv_velo = df.copy()

Cell fate#

# Fate probabilities for every RegVelo run; columns are tagged with the run index.
fate_list = [
    cal_cell_fate(ad, terminal_states=TERMINAL_STATES[:3], cluster_key="clusters", n_states=8).add_suffix(f"_{n}")
    for n, ad in enumerate(adata_m)
]
[0]PETSC ERROR: ------------------------------------------------------------------------
[0]PETSC ERROR: Caught signal number 13 Broken Pipe: Likely while reading or writing to a socket
[0]PETSC ERROR: Try option -start_in_debugger or -on_error_attach_debugger
[0]PETSC ERROR: or see https://petsc.org/release/faq/#valgrind and https://petsc.org/release/faq/
[0]PETSC ERROR: configure using --with-debugging=yes, recompile, link, and run 
[0]PETSC ERROR: to get more information on the crash.
Abort(59) on node 0 (rank 0 in comm 0): application called MPI_Abort(MPI_COMM_WORLD, 59) - process 0
[0]PETSC ERROR: ------------------------------------------------------------------------
[0]PETSC ERROR: Caught signal number 13 Broken Pipe: Likely while reading or writing to a socket
[0]PETSC ERROR: Try option -start_in_debugger or -on_error_attach_debugger
[0]PETSC ERROR: or see https://petsc.org/release/faq/#valgrind and https://petsc.org/release/faq/
[0]PETSC ERROR: configure using --with-debugging=yes, recompile, link, and run 
[0]PETSC ERROR: to get more information on the crash.
Abort(59) on node 0 (rank 0 in comm 0): application called MPI_Abort(MPI_COMM_WORLD, 59) - process 0
[0]PETSC ERROR: ------------------------------------------------------------------------
[0]PETSC ERROR: Caught signal number 13 Broken Pipe: Likely while reading or writing to a socket
[0]PETSC ERROR: Try option -start_in_debugger or -on_error_attach_debugger
[0]PETSC ERROR: or see https://petsc.org/release/faq/#valgrind and https://petsc.org/release/faq/
[0]PETSC ERROR: configure using --with-debugging=yes, recompile, link, and run 
[0]PETSC ERROR: to get more information on the crash.
Abort(59) on node 0 (rank 0 in comm 0): application called MPI_Abort(MPI_COMM_WORLD, 59) - process 0
[0]PETSC ERROR: ------------------------------------------------------------------------
[0]PETSC ERROR: Caught signal number 13 Broken Pipe: Likely while reading or writing to a socket
[0]PETSC ERROR: Try option -start_in_debugger or -on_error_attach_debugger
[0]PETSC ERROR: or see https://petsc.org/release/faq/#valgrind and https://petsc.org/release/faq/
[0]PETSC ERROR: configure using --with-debugging=yes, recompile, link, and run 
[0]PETSC ERROR: to get more information on the crash.
Abort(59) on node 0 (rank 0 in comm 0): application called MPI_Abort(MPI_COMM_WORLD, 59) - process 0
[0]PETSC ERROR: ------------------------------------------------------------------------
[0]PETSC ERROR: Caught signal number 13 Broken Pipe: Likely while reading or writing to a socket
[0]PETSC ERROR: Try option -start_in_debugger or -on_error_attach_debugger
[0]PETSC ERROR: or see https://petsc.org/release/faq/#valgrind and https://petsc.org/release/faq/
[0]PETSC ERROR: configure using --with-debugging=yes, recompile, link, and run 
[0]PETSC ERROR: to get more information on the crash.
Abort(59) on node 0 (rank 0 in comm 0): application called MPI_Abort(MPI_COMM_WORLD, 59) - process 0
[0]PETSC ERROR: ------------------------------------------------------------------------
[0]PETSC ERROR: Caught signal number 13 Broken Pipe: Likely while reading or writing to a socket
[0]PETSC ERROR: Try option -start_in_debugger or -on_error_attach_debugger
[0]PETSC ERROR: or see https://petsc.org/release/faq/#valgrind and https://petsc.org/release/faq/
[0]PETSC ERROR: configure using --with-debugging=yes, recompile, link, and run 
[0]PETSC ERROR: to get more information on the crash.
Abort(59) on node 0 (rank 0 in comm 0): application called MPI_Abort(MPI_COMM_WORLD, 59) - process 0
[0]PETSC ERROR: ------------------------------------------------------------------------
[0]PETSC ERROR: Caught signal number 13 Broken Pipe: Likely while reading or writing to a socket
[0]PETSC ERROR: Try option -start_in_debugger or -on_error_attach_debugger
[0]PETSC ERROR: or see https://petsc.org/release/faq/#valgrind and https://petsc.org/release/faq/
[0]PETSC ERROR: configure using --with-debugging=yes, recompile, link, and run 
[0]PETSC ERROR: to get more information on the crash.
Abort(59) on node 0 (rank 0 in comm 0): application called MPI_Abort(MPI_COMM_WORLD, 59) - process 0
[0]PETSC ERROR: ------------------------------------------------------------------------
[0]PETSC ERROR: Caught signal number 13 Broken Pipe: Likely while reading or writing to a socket
[0]PETSC ERROR: Try option -start_in_debugger or -on_error_attach_debugger
[0]PETSC ERROR: or see https://petsc.org/release/faq/#valgrind and https://petsc.org/release/faq/
[0]PETSC ERROR: configure using --with-debugging=yes, recompile, link, and run 
[0]PETSC ERROR: to get more information on the crash.
Abort(59) on node 0 (rank 0 in comm 0): application called MPI_Abort(MPI_COMM_WORLD, 59) - process 0
# Join the per-run fate probabilities column-wise (columns are "<fate>_<run>").
fate_all = pd.concat(fate_list, axis=1)

# Strict upper triangle of the run-by-run matrix: each unordered run pair once.
n_runs = len(fate_list)
pair_mask = np.triu(np.ones((n_runs, n_runs)), k=1).astype(bool)

correlation_results = {}
corr_all = []
groups = []
for fate in TERMINAL_STATES[:3]:
    fate_cols = [f"{fate}_{i}" for i in range(n_runs)]
    fate_df = fate_all[fate_cols]
    correlation_matrix = fate_df.corr()
    correlation_results[f"{fate}"] = correlation_matrix

    corr_pairwise = correlation_matrix.where(pair_mask)
    pair_values = corr_pairwise.stack().tolist()
    corr_all = corr_all + pair_values
    # Label the values of this fate right here, so that labels and values stay
    # aligned even if the number of stacked entries differs between fates.
    groups = groups + [fate] * len(pair_values)

df = pd.DataFrame({"Correlation": corr_all, "Group": groups})
# Boxplot of the pairwise fate-probability correlations across runs.
with mplscience.style_context():
    fig, ax = plt.subplots(figsize=(4, 4))
    sns.boxplot(x="Group", y="Correlation", data=df, palette=palette, ax=ax)
    ax.set(xlabel="Cell Type", ylabel="PCC", ylim=(0.75, 1))
    fig.tight_layout()
    plt.show()
../_images/e84865cc4610e707c21497bf1ca36e88c3b7b83f71d782f3fb4d99aaed2d1f36.png
# Keep the RegVelo cell-fate correlations for the combined figure; `df` is
# reassigned by the veloVI baseline.
df_rgv_fb = df.copy()

Using veloVI as cell fate analysis baseline#

# Fate probabilities for every stored veloVI run; columns are tagged with the run index.
fate_list = []
for n in range(5):
    ad = sc.read_h5ad(DATA_DIR / DATASET / "processed" / f"adata_velovi_run{n}")
    fate = cal_cell_fate(
        ad,
        terminal_states=TERMINAL_STATES[:3],
        cluster_key="clusters",
        n_states=10,
    )
    fate = fate.add_suffix(f"_{n}")
    fate_list.append(fate)
[0]PETSC ERROR: ------------------------------------------------------------------------
[0]PETSC ERROR: Caught signal number 13 Broken Pipe: Likely while reading or writing to a socket
[0]PETSC ERROR: Try option -start_in_debugger or -on_error_attach_debugger
[0]PETSC ERROR: or see https://petsc.org/release/faq/#valgrind and https://petsc.org/release/faq/
[0]PETSC ERROR: configure using --with-debugging=yes, recompile, link, and run 
WARNING: Using `10` components would split a block of complex conjugate eigenvalues. Using `n_components=11`
[0]PETSC ERROR: 
WARNING: Unable to compute macrostates with `n_states=10` because it will split complex conjugate eigenvalues. Using `n_states=11`
WARNING: Using `10` components would split a block of complex conjugate eigenvalues. Using `n_components=11`
WARNING: Unable to compute macrostates with `n_states=10` because it will split complex conjugate eigenvalues. Using `n_states=11`
WARNING: Using `10` components would split a block of complex conjugate eigenvalues. Using `n_components=11`
[0]PETSC ERROR: 
WARNING: Unable to compute macrostates with `n_states=10` because it will split complex conjugate eigenvalues. Using `n_states=11`
[0]PETSC ERROR: 
[0]PETSC ERROR: 
WARNING: The following terminal states have different number of cells than requested (30): {'Ductal_2': 26}
# Join the per-run fate probabilities column-wise (columns are "<fate>_<run>").
fate_all = pd.concat(fate_list, axis=1)

# Strict upper triangle of the run-by-run matrix: each unordered run pair once.
n_runs = len(fate_list)
pair_mask = np.triu(np.ones((n_runs, n_runs)), k=1).astype(bool)

correlation_results = {}
corr_all = []
groups = []
for fate in TERMINAL_STATES[:3]:
    fate_cols = [f"{fate}_{i}" for i in range(n_runs)]
    fate_df = fate_all[fate_cols]
    correlation_matrix = fate_df.corr()
    correlation_results[f"{fate}"] = correlation_matrix

    corr_pairwise = correlation_matrix.where(pair_mask)
    pair_values = corr_pairwise.stack().tolist()
    corr_all = corr_all + pair_values
    # Label the values of this fate right here, so that labels and values stay
    # aligned even if the number of stacked entries differs between fates.
    groups = groups + [fate] * len(pair_values)

df = pd.DataFrame({"Correlation": corr_all, "Group": groups})
# Boxplot of the pairwise fate-probability correlations for the veloVI baseline.
with mplscience.style_context():
    fig, ax = plt.subplots(figsize=(4, 4))
    sns.boxplot(x="Group", y="Correlation", data=df, palette=palette, ax=ax)
    ax.set(xlabel="Cell Type", ylabel="PCC", ylim=(0.75, 1))
    fig.tight_layout()
    plt.show()
../_images/2c30b065d6cbb463050e826f3a160ebdff38d3a92083f3d87ee90598cfcde847.png
# Keep the veloVI cell-fate correlations for the combined figure.
df_vi_fb = df.copy()
# Combined figure: three boxplot panels in one row with a shared y-axis.
with mplscience.style_context():
    sns.set_style("whitegrid")
    fig, axes = plt.subplots(1, 3, figsize=(9, 4), sharey=True)

    # (panel title, data, y-axis label); only the left-most panel is labelled.
    panels = (
        ("Velocity", df_rgv_velo, "PCC"),
        ("Cell fate", df_rgv_fb, ""),
        ("Cell fate (veloVI)", df_vi_fb, ""),
    )
    for ax, (title, data, ylabel) in zip(axes, panels):
        sns.boxplot(y="Correlation", x="Group", data=data, ax=ax, palette=palette)
        ax.set_title(title)
        ax.set_xlabel("")
        ax.set_ylabel(ylabel)

    # With sharey=True the limits apply to all three panels.
    plt.ylim(0.6, 1.05)
    plt.tight_layout()

    # NOTE(review): a later cell of this notebook saves to the same file name,
    # so this export gets overwritten - confirm whether that is intended.
    if SAVE_FIGURES:
        save_path = FIG_DIR / DATASET / "Identifiability_Benchmark.svg"
        fig.savefig(save_path, format="svg", transparent=True, bbox_inches="tight")
    plt.show()
../_images/4d966c72269361e8c823c8f62a1f8edf6d75e69eaa925b92e598595cee12ac6d.png
# Combined figure with a tighter y-range: three boxplot panels in one row
# with a shared y-axis.
with mplscience.style_context():
    sns.set_style("whitegrid")
    fig, axes = plt.subplots(1, 3, figsize=(9, 4), sharey=True)

    # (panel title, data, y-axis label); only the left-most panel is labelled.
    panels = (
        ("Velocity", df_rgv_velo, "PCC"),
        ("Cell fate", df_rgv_fb, ""),
        ("Cell fate (veloVI)", df_vi_fb, ""),
    )
    for ax, (title, data, ylabel) in zip(axes, panels):
        sns.boxplot(y="Correlation", x="Group", data=data, ax=ax, palette=palette)
        ax.set_title(title)
        ax.set_xlabel("")
        ax.set_ylabel(ylabel)

    # With sharey=True the limits apply to all three panels.
    plt.ylim(0.8, 1.01)
    plt.tight_layout()

    # NOTE(review): an earlier cell of this notebook saves to the same file
    # name; this export replaces it - confirm whether that is intended.
    if SAVE_FIGURES:
        save_path = FIG_DIR / DATASET / "Identifiability_Benchmark.svg"
        fig.savefig(save_path, format="svg", transparent=True, bbox_inches="tight")
    plt.show()
../_images/b6dda7f5326aad81c0bba0bd1828ea2394ecfdd8e950c4f81cccebb54ee1cfee.png