Performance comparison of inference on dyngen data

This notebook compares metrics for velocity, latent time and GRN inference across different methods applied to dyngen-generated data.

import numpy as np
import pandas as pd
from scipy.stats import ttest_rel

import matplotlib.pyplot as plt
import mplscience
import seaborn as sns

from rgv_tools import DATA_DIR, FIG_DIR
from rgv_tools.plotting._significance import get_significance

Constants

# Dataset identifier; selects input subdirectories under DATA_DIR and output
# subdirectories under FIG_DIR.
DATASET = "dyngen"
SAVE_FIGURES = True
if SAVE_FIGURES:
    # Ensure the figure output directory exists before any savefig call.
    (FIG_DIR / DATASET).mkdir(parents=True, exist_ok=True)
# Methods included in each benchmark. Result files are expected at
# DATA_DIR / DATASET / f"complexity_{scale}" / "results" / f"{method}_correlation.parquet".
VELOCITY_METHODS = ["regvelo", "velovi", "scvelo", "velovae_fullvb", "velovae_vae", "tfvelo", "cell2fate", "unitvelo"]
VELOCITY_METHODS_TF = ["regvelo", "tfvelo"]
TIME_METHODS = ["regvelo", "velovi", "scvelo"]
GRN_METHODS = ["regvelo", "tfvelo", "grnboost2", "celloracle", "splicejac", "correlation"]
# Color palettes: regvelo (and the main baselines) get distinct colors; the
# remaining competitors are greyed out so the comparison of interest stands out.
GRN_METHOD_PALETTE = {
    "regvelo": "#0173b2",
    "correlation": "#D3D3D3",
    "grnboost2": "#D3D3D3",  # 708090
    "celloracle": "#D3D3D3",  # A9A9A9
    "tfvelo": "#696969",  # 696969
    "splicejac": "#D3D3D3",
}

VELO_METHOD_PALETTE = {
    "regvelo": "#0173b2",
    "velovi": "#de8f05",
    "scvelo": "#029e73",
    "tfvelo": "#D3D3D3",
    "unitvelo": "#D3D3D3",
    "velovae_vae": "#D3D3D3",
    "velovae_fullvb": "#D3D3D3",
    "cell2fate": "#D3D3D3",
}

TIME_METHOD_PALETTE = {
    "regvelo": "#0173b2",
    "velovi": "#de8f05",
    "scvelo": "#029e73",
}

Data loading

Velocity benchmark

# Build a long-format DataFrame of per-run velocity correlations for every
# method across the four dyngen complexity scales.
scale_df = []
for scale in range(1, 5):
    # tfvelo stores gene-level correlations per run (one column per run);
    # average each run, ignoring genes tfvelo failed to fit (NaNs).
    df_tfvelo = pd.read_parquet(
        DATA_DIR / DATASET / f"complexity_{scale}" / "results" / "tfvelo_correlation_all.parquet"
    )
    velocity_correlation_tf = [df_tfvelo[col].dropna().mean() for col in df_tfvelo]

    # Per-method run-level results; columns prefixed with the method name so
    # they can be concatenated side by side.
    correlation_df = []
    for method in VELOCITY_METHODS:
        df = pd.read_parquet(DATA_DIR / DATASET / f"complexity_{scale}" / "results" / f"{method}_correlation.parquet")
        df.columns = f"{method}_" + df.columns
        correlation_df.append(df)

    correlation_df = pd.concat(correlation_df, axis=1)

    # tfvelo may have fewer completed runs than the other methods; pad with
    # NaN so the column aligns, then replace the placeholder tfvelo column
    # with the run-averaged values computed above.
    if len(velocity_correlation_tf) < correlation_df.shape[0]:
        velocity_correlation_tf += [np.nan] * (correlation_df.shape[0] - len(velocity_correlation_tf))
    correlation_df["tfvelo_velocity"] = velocity_correlation_tf

    # Long format: one row per (method, run); rescale correlation from
    # [-1, 1] to [0, 1] for plotting.
    df = correlation_df.loc[:, correlation_df.columns.str.contains("velocity")]
    df.columns = df.columns.str.removesuffix("_velocity")
    df = pd.melt(df, var_name="method", value_name="correlation")
    df["correlation"] = (df["correlation"] + 1) / 2
    df["scale"] = str(scale)

    scale_df.append(df)
df = pd.concat(scale_df, axis=0)
df = df.reset_index(drop=True)

Velocity

# Violin plot of per-run velocity correlation per method, grouped by scale.
with mplscience.style_context():
    sns.set_style(style="whitegrid")
    fig, ax = plt.subplots(figsize=(14, 2.5))
    sns.violinplot(
        data=df,
        y="correlation",
        x="scale",
        hue="method",
        hue_order=VELOCITY_METHODS,
        palette=VELO_METHOD_PALETTE,
        ax=ax,
    )

    # Scale is plotted on the x-axis and the rescaled Pearson correlation on
    # the y-axis (the original labels were swapped).
    ax.set(
        xlabel="Scale",
        ylabel="Pearson correlation",
        yticks=ax.get_yticks(),
    )

    if SAVE_FIGURES:
        fig.savefig(
            FIG_DIR / DATASET / "velocity_benchmark.svg",
            format="svg",
            transparent=True,
            bbox_inches="tight",
        )

    plt.show()
../_images/53afdbd7771d8fe345439d798c2d6c8170c7b490e183cf1ff2b31af8cc91fb74.png

GRN benchmark

# Build a long-format DataFrame of per-run GRN inference scores (AUROC) for
# every method across the four complexity scales.
scale_df = []
for scale in range(1, 5):
    grn_df = []
    for method in GRN_METHODS:
        # Read from DATA_DIR like every other section; the original used a
        # stale relative "results/..." path inconsistent with the rest of
        # the notebook (the same files are read via DATA_DIR above).
        df = pd.read_parquet(DATA_DIR / DATASET / f"complexity_{scale}" / "results" / f"{method}_correlation.parquet")
        df.columns = f"{method}_" + df.columns
        grn_df.append(df)

    grn_df = pd.concat(grn_df, axis=1)

    # Keep only runs where every method produced a score.
    grn_df_sub = grn_df.dropna()

    # Long format: one row per (method, run).
    df = grn_df_sub.loc[:, grn_df_sub.columns.str.contains("grn")]
    df.columns = df.columns.str.removesuffix("_grn")
    df = pd.melt(df, var_name="method", value_name="correlation")
    df["scale"] = str(scale)

    scale_df.append(df)
df = pd.concat(scale_df, axis=0)
df = df.reset_index(drop=True)
# Violin plot of per-run GRN AUROC per method, grouped by complexity scale.
with mplscience.style_context():
    sns.set_style(style="whitegrid")
    fig, ax = plt.subplots(figsize=(14, 2.5))
    sns.violinplot(
        data=df,
        y="correlation",
        x="scale",
        hue="method",
        hue_order=GRN_METHODS,
        palette=GRN_METHOD_PALETTE,
        ax=ax,
    )

    ax.set(ylabel="AUROC", xlabel="Scale", yticks=ax.get_yticks())

    if SAVE_FIGURES:
        out_path = FIG_DIR / DATASET / "GRN_benchmark.svg"
        fig.savefig(out_path, format="svg", transparent=True, bbox_inches="tight")

    plt.show()
../_images/bfcb3ac69709bb4c7e80671272ef6d3500b8b183928d7e1fb003107a23c60646.png

Latent time

# Build a long-format DataFrame of per-run latent-time correlations and run a
# paired significance test (regvelo vs velovi) at each complexity scale.
scale_df = []
for scale in range(1, 5):
    time_df = []
    for method in TIME_METHODS:
        df = pd.read_parquet(DATA_DIR / DATASET / f"complexity_{scale}" / "results" / f"{method}_correlation.parquet")
        df.columns = f"{method}_" + df.columns
        time_df.append(df)

    time_df = pd.concat(time_df, axis=1)

    # Paired one-sided t-test: is regvelo's time correlation greater than velovi's?
    ttest_res = ttest_rel(
        time_df["regvelo_time"],
        time_df["velovi_time"],
        alternative="greater",
    )
    print(ttest_res)
    # NOTE(review): the significance label is computed but never annotated on
    # the figure below — kept for parity with the printed t-test result.
    significance = get_significance(pvalue=ttest_res.pvalue)

    # Keep only runs where every method produced a value.
    time_df_sub = time_df.dropna()

    # Long format: one row per (method, run).
    df = time_df_sub.loc[:, time_df_sub.columns.str.contains("time")]
    df.columns = df.columns.str.removesuffix("_time")
    df = pd.melt(df, var_name="method", value_name="correlation")
    df["scale"] = str(scale)

    scale_df.append(df)
TtestResult(statistic=2.819809533651689, pvalue=0.004287626705348869, df=29)
TtestResult(statistic=6.817578655205317, pvalue=8.6925544495783e-08, df=29)
TtestResult(statistic=7.716392336845313, pvalue=8.252679776252025e-09, df=29)
TtestResult(statistic=5.498852429313436, pvalue=3.1735361936536186e-06, df=29)
df = pd.concat(scale_df, axis=0).reset_index(drop=True)
# Violin plot of per-run latent-time correlation per method, grouped by scale.
with mplscience.style_context():
    sns.set_style(style="whitegrid")
    fig, ax = plt.subplots(figsize=(10, 2.5))
    sns.violinplot(
        data=df,
        y="correlation",
        x="scale",
        hue="method",
        hue_order=TIME_METHODS,
        palette=TIME_METHOD_PALETTE,
        ax=ax,
    )

    ax.set(ylabel="Spearman correlation", xlabel="Scale", yticks=ax.get_yticks())
    # Fixed y-range so all scales are visually comparable.
    ax.set_ylim(-0.3, 0.6)

    if SAVE_FIGURES:
        out_path = FIG_DIR / DATASET / "time_benchmark.svg"
        fig.savefig(out_path, format="svg", transparent=True, bbox_inches="tight")

    plt.show()
../_images/419f9dc8ff2c5ac0a9db2c5577c40960b1e67cd35bd48e77b1394dfe13bfb17f.png