veloVAE (fullvb) benchmark on cell cycle data
This notebook benchmarks velocity estimation, latent time inference, and cross-boundary correctness (CBC) using veloVAE (fullvb) on the cell cycle data.
import velovae as vv
import numpy as np
import pandas as pd
import torch
import anndata as ad
import scvelo as scv
from cellrank.kernels import VelocityKernel
from rgv_tools import DATA_DIR
from rgv_tools.benchmarking import get_time_correlation
General settings
scv.settings.verbosity = 3
Constants
torch.manual_seed(0)
np.random.seed(0)
DATASET = "cell_cycle"
STATE_TRANSITIONS = [("G1", "S"), ("S", "G2M")]
SAVE_DATA = True
if SAVE_DATA:
    (DATA_DIR / DATASET / "results").mkdir(parents=True, exist_ok=True)
    (DATA_DIR / DATASET / "processed" / "velovae_vae").mkdir(parents=True, exist_ok=True)
Data loading
adata = ad.io.read_h5ad(DATA_DIR / DATASET / "processed" / "adata_processed.h5ad")
adata
AnnData object with n_obs × n_vars = 1146 × 395
    obs: 'phase', 'fucci_time', 'initial_size_unspliced', 'initial_size_spliced', 'initial_size', 'n_counts'
    var: 'ensum_id', 'gene_count_corr', 'means', 'dispersions', 'dispersions_norm', 'highly_variable', 'velocity_gamma', 'velocity_qreg_ratio', 'velocity_r2', 'velocity_genes'
    uns: 'log1p', 'neighbors', 'pca', 'umap', 'velocity_params'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs', 'true_skeleton'
    layers: 'Ms', 'Mu', 'spliced', 'total', 'unspliced', 'velocity'
    obsp: 'connectivities', 'distances'
Velocity pipeline
rate_prior = {"alpha": (0.0, 1.0), "beta": (0.0, 0.5), "gamma": (0.0, 0.5)}
full_vb = vv.VAE(adata, tmax=20, dim_z=5, device="cuda:0", full_vb=True, rate_prior=rate_prior)
config = {}
full_vb.train(adata, config=config, plot=False, embed="pca")
# Write the model output to the AnnData object under the "fullvb" key and save it to disk
full_vb.save_anndata(adata, "fullvb", DATA_DIR / DATASET / "processed" / "velovae_vae", file_name="velovae_fullvb.h5ad")
Estimating ODE parameters...
Detected 338 velocity genes.
Estimating the variance...
Initialization using the steady-state and dynamical models.
Reinitialize the regular ODE parameters based on estimated global latent time.
3 clusters detected based on gene co-expression.
(0.51, 0.7800862095516045), (0.49, 0.3450951793158637)
(0.50, 0.7100356892462681), (0.50, 0.3549154711787042)
KS-test result: [0. 0. 1.]
Initial induction: 230, repression: 165/395
Learning Rate based on Data Sparsity: 0.0000
--------------------------- Train a VeloVAE ---------------------------
********* Creating Training/Validation Datasets *********
********* Finished. *********
********* Creating optimizers *********
********* Finished. *********
********* Start training *********
********* Stage 1 *********
Total Number of Iterations Per Epoch: 7, test iteration: 12
********* Stage 1: Early Stop Triggered at epoch 7. *********
********* Stage 2 *********
********* Velocity Refinement Round 1 *********
Percentage of Invalid Sets: 0.031
Average Set Size: 23
********* Round 1: Early Stop Triggered at epoch 122. *********
Change in noise variance: 0.0599
********* Velocity Refinement Round 2 *********
********* Round 2: Early Stop Triggered at epoch 131. *********
Change in noise variance: 0.0011
Change in x0: 0.8054
********* Velocity Refinement Round 3 *********
********* Round 3: Early Stop Triggered at epoch 140. *********
Change in noise variance: 0.0003
Change in x0: 0.6686
********* Velocity Refinement Round 4 *********
********* Round 4: Early Stop Triggered at epoch 149. *********
Change in noise variance: 0.0000
Change in x0: 0.5644
********* Velocity Refinement Round 5 *********
********* Round 5: Early Stop Triggered at epoch 158. *********
Change in noise variance: 0.0000
Change in x0: 0.5022
********* Velocity Refinement Round 6 *********
********* Round 6: Early Stop Triggered at epoch 167. *********
Change in noise variance: 0.0000
Change in x0: 0.4278
********* Velocity Refinement Round 7 *********
********* Round 7: Early Stop Triggered at epoch 176. *********
Change in noise variance: 0.0000
Change in x0: 0.4016
********* Velocity Refinement Round 8 *********
********* Round 8: Early Stop Triggered at epoch 185. *********
Change in noise variance: 0.0000
Change in x0: 0.3699
********* Velocity Refinement Round 9 *********
********* Round 9: Early Stop Triggered at epoch 194. *********
Change in noise variance: 0.0000
Change in x0: 0.3596
********* Velocity Refinement Round 10 *********
********* Round 10: Early Stop Triggered at epoch 203. *********
Change in noise variance: 0.0000
Change in x0: 0.3600
********* Velocity Refinement Round 11 *********
Stage 2: Early Stop Triggered at round 10.
********* Finished. Total Time = 0 h : 0 m : 24 s *********
Final: Train ELBO = -18012.699, Test ELBO = -18553.311
adata.layers["velocity"] = adata.layers["fullvb_velocity"].copy()
time_correlation = [get_time_correlation(ground_truth=adata.obs["fucci_time"], estimated=adata.obs["fullvb_time"])]
scv.tl.velocity_graph(adata, vkey="velocity", n_jobs=1)
scv.tl.velocity_confidence(adata, vkey="velocity")
computing velocity graph (using 1/112 cores)
finished (0:00:01) --> added
'velocity_graph', sparse matrix with cosine correlations (adata.uns)
--> added 'velocity_length' (adata.obs)
--> added 'velocity_confidence' (adata.obs)
--> added 'velocity_confidence_transition' (adata.obs)
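The latent time comparison above relies on `get_time_correlation` from `rgv_tools`, whose internals are not shown in this notebook. As a rough illustration of the idea only, the sketch below computes a rank (Spearman-style) correlation between a ground-truth time and an estimated time with plain NumPy. The helper `rank_correlation_sketch`, the choice of a rank correlation, and the toy data are assumptions for illustration, not the `rgv_tools` implementation.

```python
import numpy as np


def rank_correlation_sketch(ground_truth, estimated):
    """Rank correlation between two 1-D time vectors (hypothetical helper).

    Assumes there are no tied values; ties would be ranked in order of appearance.
    """
    ground_truth = np.asarray(ground_truth, dtype=float)
    estimated = np.asarray(estimated, dtype=float)
    if ground_truth.shape != estimated.shape or ground_truth.ndim != 1:
        raise ValueError("Both inputs must be 1-D arrays of equal length.")

    def to_ranks(values):
        # Position of each value in the sorted order
        order = np.argsort(values, kind="stable")
        ranks = np.empty(len(values), dtype=float)
        ranks[order] = np.arange(len(values))
        return ranks

    # Pearson correlation of the ranks
    return float(np.corrcoef(to_ranks(ground_truth), to_ranks(estimated))[0, 1])


# Toy data: a monotone, nonlinear rescaling of the true time preserves the ranks
rng = np.random.default_rng(0)
true_time = rng.uniform(0, 1, size=200)
latent_time = 20 * true_time**2

rho_monotone = rank_correlation_sketch(true_time, latent_time)
rho_reversed = rank_correlation_sketch(true_time, -latent_time)
```

A rank-based measure is insensitive to the scale of the inferred time (here up to `tmax=20`), which is why a monotone rescaling still yields a correlation of one in the toy example, while a reversed ordering yields minus one.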
Cross-boundary correctness
vk = VelocityKernel(adata, vkey="velocity").compute_transition_matrix()
cluster_key = "phase"
rep = "X_pca"
score_df = []
for source, target in STATE_TRANSITIONS:
    cbc = vk.cbc(source=source, target=target, cluster_key=cluster_key, rep=rep)
    score_df.append(
        pd.DataFrame(
            {
                "State transition": [f"{source} - {target}"] * len(cbc),
                "CBC": cbc,
            }
        )
    )
score_df = pd.concat(score_df)
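To give some intuition for what the CBC scores collected above measure, the following is a simplified sketch of the idea: for each cell in the source cluster, compare its velocity direction with the directions towards its nearest cells in the target cluster. This uses a cosine-similarity variant on toy 2-D data. The function `cbc_sketch`, its `n_neighbors` parameter, and the toy data are hypothetical, and the code is not CellRank's `VelocityKernel.cbc` implementation, which works from the transition matrix.

```python
import numpy as np


def cbc_sketch(positions, velocities, labels, source, target, n_neighbors=5):
    """Simplified cross-boundary score per source cell (illustrative only)."""
    positions = np.asarray(positions, dtype=float)
    velocities = np.asarray(velocities, dtype=float)
    labels = np.asarray(labels)

    source_idx = np.flatnonzero(labels == source)
    target_idx = np.flatnonzero(labels == target)

    scores = []
    for i in source_idx:
        # Displacements from the source cell to every target cell
        displacement = positions[target_idx] - positions[i]
        distance = np.linalg.norm(displacement, axis=1)
        # Keep only the nearest target cells
        nearest = np.argsort(distance)[:n_neighbors]
        disp = displacement[nearest]
        # Cosine similarity between the velocity and each displacement
        denom = np.linalg.norm(disp, axis=1) * np.linalg.norm(velocities[i]) + 1e-12
        scores.append(np.mean(disp @ velocities[i] / denom))
    return np.array(scores)


# Toy data: "G1" cells around (0, 0), "S" cells around (1, 0)
rng = np.random.default_rng(0)
positions = np.vstack(
    [rng.normal([0.0, 0.0], 0.1, size=(30, 2)), rng.normal([1.0, 0.0], 0.1, size=(30, 2))]
)
labels = np.array(["G1"] * 30 + ["S"] * 30)
velocities = np.tile([1.0, 0.0], (60, 1))  # every cell moves in the +x direction

forward = cbc_sketch(positions, velocities, labels, source="G1", target="S")
backward = cbc_sketch(positions, -velocities, labels, source="G1", target="S")
```

In this toy setup, velocities pointing from G1 towards S give scores close to one, and reversed velocities give the negated scores, which matches the interpretation that higher CBC indicates better agreement with the known transition.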
Data saving
if SAVE_DATA:
    pd.DataFrame({"time": time_correlation}, index=adata.obs_names).to_parquet(
        path=DATA_DIR / DATASET / "results" / "velovae_fullvb_correlation.parquet"
    )
    adata.obs[["velocity_confidence"]].to_parquet(
        path=DATA_DIR / DATASET / "results" / "velovae_fullvb_confidence.parquet"
    )
    score_df.to_parquet(path=DATA_DIR / DATASET / "results" / "velovae_fullvb_cbc.parquet")
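One possible way to inspect the saved CBC table is to aggregate the per-cell scores by state transition. The sketch below does this on a small hand-made frame with the same two columns as `score_df`; the numbers are toy values for illustration and are not results from this benchmark.

```python
import pandas as pd

# Toy stand-in with the same columns as `score_df` (values are made up)
toy_scores = pd.DataFrame(
    {
        "State transition": ["G1 - S", "G1 - S", "G1 - S", "S - G2M", "S - G2M"],
        "CBC": [0.8, 0.6, 0.7, 0.2, 0.4],
    }
)

# Mean, median, and number of cells per state transition
summary = toy_scores.groupby("State transition")["CBC"].agg(["mean", "median", "count"])
print(summary)
```

The same `groupby` call should work on the table read back with `pd.read_parquet`, assuming the column names are unchanged.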