veloVAE (VAE) benchmark on cell cycle data#

This notebook benchmarks velocity and latent time inference, as well as cross-boundary correctness, using veloVAE (VAE) on cell cycle data.

import velovae as vv

import numpy as np
import pandas as pd
import torch

import anndata as ad
import scvelo as scv
from cellrank.kernels import VelocityKernel

from rgv_tools import DATA_DIR
from rgv_tools.benchmarking import get_time_correlation

General settings#

scv.settings.verbosity = 3

Constants#

torch.manual_seed(0)
np.random.seed(0)
DATASET = "cell_cycle"
STATE_TRANSITIONS = [("G1", "S"), ("S", "G2M")]
SAVE_DATA = True
if SAVE_DATA:
    (DATA_DIR / DATASET / "results").mkdir(parents=True, exist_ok=True)
    (DATA_DIR / DATASET / "processed" / "velovae_vae").mkdir(parents=True, exist_ok=True)

Data loading#

adata = ad.io.read_h5ad(DATA_DIR / DATASET / "processed" / "adata_processed.h5ad")
adata
AnnData object with n_obs × n_vars = 1146 × 395
    obs: 'phase', 'fucci_time', 'initial_size_unspliced', 'initial_size_spliced', 'initial_size', 'n_counts'
    var: 'ensum_id', 'gene_count_corr', 'means', 'dispersions', 'dispersions_norm', 'highly_variable', 'velocity_gamma', 'velocity_qreg_ratio', 'velocity_r2', 'velocity_genes'
    uns: 'log1p', 'neighbors', 'pca', 'umap', 'velocity_params'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs', 'true_skeleton'
    layers: 'Ms', 'Mu', 'spliced', 'total', 'unspliced', 'velocity'
    obsp: 'connectivities', 'distances'
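
As an optional sanity check (not part of the benchmark itself), one can confirm that the count and moment layers used by the steps below are present in the loaded object:

# Optional sanity check: confirm the layers used downstream are available.
required_layers = ["spliced", "unspliced", "Ms", "Mu"]
missing = [layer for layer in required_layers if layer not in adata.layers]
assert not missing, f"Missing layers: {missing}"
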

Velocity pipeline#

vae = vv.VAE(adata, tmax=20, dim_z=5, device="cuda:0")
config = {}
vae.train(adata, config=config, plot=False, embed="pca")

# Write the model outputs back to the AnnData object and save it to disk
vae.save_anndata(adata, "vae", DATA_DIR / DATASET / "processed" / "velovae_vae", file_name="velovae.h5ad")
Estimating ODE parameters...
Detected 338 velocity genes.
Estimating the variance...
Initialization using the steady-state and dynamical models.
Reinitialize the regular ODE parameters based on estimated global latent time.
3 clusters detected based on gene co-expression.
(0.51, 0.7800862095516045), (0.49, 0.3450951793158637)
(0.50, 0.7100356892462681), (0.50, 0.3549154711787042)
KS-test result: [0. 0. 1.]
Initial induction: 230, repression: 165/395
Learning Rate based on Data Sparsity: 0.0000
--------------------------- Train a VeloVAE ---------------------------
*********        Creating Training/Validation Datasets        *********
*********                      Finished.                      *********
*********                 Creating optimizers                 *********
*********                      Finished.                      *********
*********                    Start training                   *********
*********                      Stage  1                       *********
Total Number of Iterations Per Epoch: 7, test iteration: 12
*********       Stage 1: Early Stop Triggered at epoch 7.       *********
*********                      Stage  2                       *********
*********             Velocity Refinement Round 1             *********
Percentage of Invalid Sets: 0.026
Average Set Size: 23
*********     Round 1: Early Stop Triggered at epoch 83.    *********
Change in noise variance: 0.0614
*********             Velocity Refinement Round 2             *********
*********     Round 2: Early Stop Triggered at epoch 92.    *********
Change in noise variance: 0.0010
Change in x0: 0.8648
*********             Velocity Refinement Round 3             *********
*********     Round 3: Early Stop Triggered at epoch 101.    *********
Change in noise variance: 0.0003
Change in x0: 0.6694
*********             Velocity Refinement Round 4             *********
*********     Round 4: Early Stop Triggered at epoch 110.    *********
Change in noise variance: 0.0000
Change in x0: 0.5544
*********             Velocity Refinement Round 5             *********
*********     Round 5: Early Stop Triggered at epoch 119.    *********
Change in noise variance: 0.0000
Change in x0: 0.5010
*********             Velocity Refinement Round 6             *********
*********     Round 6: Early Stop Triggered at epoch 128.    *********
Change in noise variance: 0.0000
Change in x0: 0.4104
*********             Velocity Refinement Round 7             *********
*********     Round 7: Early Stop Triggered at epoch 137.    *********
Change in noise variance: 0.0000
Change in x0: 0.4021
*********             Velocity Refinement Round 8             *********
Stage 2: Early Stop Triggered at round 7.
*********              Finished. Total Time =   0 h :  0 m : 14 s             *********
Final: Train ELBO = -18372.410,	Test ELBO = -18900.932
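
`vae.save_anndata` writes the model outputs back into `adata` under the `"vae"` prefix; the fields used below (`vae_velocity` in `adata.layers`, `vae_time` in `adata.obs`) come from this step. An optional inspection of what was added:

# Optional: list the fields written by vae.save_anndata under the "vae" prefix.
print([key for key in adata.layers if key.startswith("vae")])
print([col for col in adata.obs.columns if col.startswith("vae")])
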
adata.layers["velocity"] = adata.layers["vae_velocity"].copy()
time_correlation = [get_time_correlation(ground_truth=adata.obs["fucci_time"], estimated=adata.obs["vae_time"])]
scv.tl.velocity_graph(adata, vkey="velocity", n_jobs=1)
scv.tl.velocity_confidence(adata, vkey="velocity")
computing velocity graph (using 1/112 cores)
    finished (0:00:01) --> added 
    'velocity_graph', sparse matrix with cosine correlations (adata.uns)
--> added 'velocity_length' (adata.obs)
--> added 'velocity_confidence' (adata.obs)
--> added 'velocity_confidence_transition' (adata.obs)
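
For reference, a plain rank correlation between experimental FUCCI time and the inferred latent time gives a comparable readout; this is only a sketch and assumes `get_time_correlation` is Spearman-based, which may differ from the helper used above.

from scipy.stats import spearmanr

# Rank correlation between FUCCI time and veloVAE latent time (optional check).
rho, _ = spearmanr(adata.obs["fucci_time"], adata.obs["vae_time"])
print(f"Spearman correlation: {rho:.3f}")
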

Cross-boundary correctness#

vk = VelocityKernel(adata, vkey="velocity").compute_transition_matrix()

cluster_key = "phase"
rep = "X_pca"

score_df = []
for source, target in STATE_TRANSITIONS:
    cbc = vk.cbc(source=source, target=target, cluster_key=cluster_key, rep=rep)

    score_df.append(
        pd.DataFrame(
            {
                "State transition": [f"{source} - {target}"] * len(cbc),
                "CBC": cbc,
            }
        )
    )
score_df = pd.concat(score_df)
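
An optional per-transition summary of the CBC scores (not written to disk by the benchmark):

# Mean cross-boundary correctness per state transition.
print(score_df.groupby("State transition")["CBC"].mean())
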

Data saving#

if SAVE_DATA:
    pd.DataFrame({"time": time_correlation}, index=adata.obs_names).to_parquet(
        path=DATA_DIR / DATASET / "results" / "velovae_vae_correlation.parquet"
    )
    adata.obs[["velocity_confidence"]].to_parquet(
        path=DATA_DIR / DATASET / "results" / "velovae_vae_confidence.parquet"
    )
    score_df.to_parquet(path=DATA_DIR / DATASET / "results" / "velovae_vae_cbc.parquet")
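
The saved parquet files can later be read back with pandas when aggregating results across methods; a minimal sketch using the paths written above:

# Optional: reload the saved results for downstream aggregation.
results_dir = DATA_DIR / DATASET / "results"
time_corr = pd.read_parquet(results_dir / "velovae_vae_correlation.parquet")
confidence = pd.read_parquet(results_dir / "velovae_vae_confidence.parquet")
cbc_scores = pd.read_parquet(results_dir / "velovae_vae_cbc.parquet")
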