Using veloVI as baseline for RegVelo identifiability test#
This notebook runs the preprocessing pipeline and uses veloVI as the baseline method.
Library imports#
import numpy as np
import pandas as pd
import scanpy as sc
import scvelo as scv
import scvi
from velovi import preprocess_data, VELOVI
from rgv_tools import DATA_DIR
from rgv_tools.benchmarking import (
get_time_correlation,
get_velocity_correlation,
set_output,
)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_csv from `anndata` is deprecated. Import anndata.io.read_csv instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_loom from `anndata` is deprecated. Import anndata.io.read_loom instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_text from `anndata` is deprecated. Import anndata.io.read_text instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing CSCDataset from `anndata.experimental` is deprecated. Import anndata.abc.CSCDataset instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing CSRDataset from `anndata.experimental` is deprecated. Import anndata.abc.CSRDataset instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_elem from `anndata.experimental` is deprecated. Import anndata.io.read_elem instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_csv from `anndata` is deprecated. Import anndata.io.read_csv instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_excel from `anndata` is deprecated. Import anndata.io.read_excel instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_hdf from `anndata` is deprecated. Import anndata.io.read_hdf instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_loom from `anndata` is deprecated. Import anndata.io.read_loom instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_mtx from `anndata` is deprecated. Import anndata.io.read_mtx instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_text from `anndata` is deprecated. Import anndata.io.read_text instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/regvelo_test/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_umi_tools from `anndata` is deprecated. Import anndata.io.read_umi_tools instead.
warnings.warn(msg, FutureWarning)
General settings#
# Fix the scvi-tools global random seed so that model training below is reproducible.
scvi.settings.seed = 0
[rank: 0] Seed set to 0
Constants#
# Dataset identifier used to build all input/output paths.
DATASET = "dyngen"

# Toggle for writing processed data and benchmark results to disk.
SAVE_DATA = True
if SAVE_DATA:
    for subdir in ("processed", "results"):
        (DATA_DIR / DATASET / subdir).mkdir(parents=True, exist_ok=True)
Function definitions#
def update_data(adata) -> None:
    """Standardize a dyngen-simulated AnnData object in place.

    Renames layers/columns to the conventional spliced/unspliced scheme,
    prefixes ground-truth quantities with ``true_``, and drops everything
    that is not needed downstream.
    """
    # Spliced counts become the main expression matrix (read before the pop below).
    adata.X = adata.layers["counts_spliced"]

    # Move raw counts into the conventional layer names.
    for new_name, old_name in [("unspliced", "counts_unspliced"), ("spliced", "counts_spliced")]:
        adata.layers[new_name] = adata.layers.pop(old_name)
    # Densify the simulated ground-truth velocities.
    adata.layers["true_velocity"] = adata.layers.pop("rna_velocity").toarray()
    # Keep untouched copies of the counts before any normalization happens.
    adata.layers["unspliced_raw"] = adata.layers["unspliced"].copy()
    adata.layers["spliced_raw"] = adata.layers["spliced"].copy()
    for obsolete in ("counts_protein", "logcounts"):
        del adata.layers[obsolete]

    # Simulation time is the ground-truth latent time; drop bookkeeping columns.
    adata.obs.rename(columns={"sim_time": "true_time"}, inplace=True)
    adata.obs.drop(columns=["step_ix", "simulation_i"], inplace=True)

    # Ground-truth kinetic rates per gene.
    rate_map = {
        "transcription_rate": "true_alpha",
        "splicing_rate": "true_beta",
        "mrna_decay_rate": "true_gamma",
    }
    adata.var.rename(columns=rate_map, inplace=True)
    keep = ["true_alpha", "true_beta", "true_gamma", "is_tf"]
    adata.var.drop(columns=adata.var.columns.difference(keep), inplace=True)

    # Keep only the GRN-related uns slots, prefixed with ``true_``; delete the rest.
    ground_truth_slots = {"network", "regulatory_network", "skeleton", "regulators", "targets"}
    for slot in list(adata.uns.keys()):
        if slot in ground_truth_slots:
            adata.uns[f"true_{slot}"] = adata.uns.pop(slot)
        else:
            del adata.uns[slot]

    adata.obsm["true_sc_network"] = adata.obsm.pop("regulatory_network_sc")
    del adata.obsm["dimred"]

    # Normalize cell naming, e.g. ``cell1`` -> ``cell_1``.
    adata.obs_names = adata.obs_names.str.replace("cell", "cell_")
def get_sc_grn(adata) -> np.ndarray:
    """Compute cell-specific GRN matrices from the dyngen ground truth.

    For each cell, the per-cell edge weights stored in
    ``adata.obsm["true_sc_network"]`` are placed into a dense
    ``(n_vars, n_vars)`` regulator-by-target matrix.

    Returns
    -------
    A dense array of shape ``(n_vars, n_vars, n_obs)`` with one GRN per cell
    stacked along the last axis.
    """
    # The edge list (regulator/target pairs) is identical for every cell;
    # only the edge weights differ. Hoist the invariant column selection
    # out of the loop instead of re-slicing and copying it per cell.
    edges = adata.uns["true_regulatory_network"][["regulator", "target"]]

    true_sc_grn = []
    for cell_id in range(adata.n_obs):
        df = edges.copy()
        # Per-cell edge weights; stored as a (sparse) row in obsm.
        df["value"] = adata.obsm["true_sc_network"][cell_id, :].toarray().squeeze()
        # Wide regulator x target matrix; missing edges get weight 0.
        pivot = pd.pivot(df, index="regulator", columns="target", values="value").fillna(0)

        grn = np.zeros([adata.n_vars, adata.n_vars])
        rows = adata.var_names.get_indexer(pivot.index)
        cols = adata.var_names.get_indexer(pivot.columns)
        grn[np.ix_(rows, cols)] = pivot.values
        true_sc_grn.append(grn)
    return np.dstack(true_sc_grn)
Data loading#
# Load the raw dyngen simulation (counts, ground-truth kinetics, and GRN slots).
adata = sc.read_h5ad(DATA_DIR / DATASET / "raw" / "dataset_dyngen_sim.h5ad")
adata
AnnData object with n_obs × n_vars = 1000 × 200
obs: 'step_ix', 'simulation_i', 'sim_time'
var: 'module_id', 'basal', 'burn', 'independence', 'color', 'is_tf', 'is_hk', 'transcription_rate', 'splicing_rate', 'translation_rate', 'mrna_halflife', 'protein_halflife', 'mrna_decay_rate', 'protein_decay_rate', 'max_premrna', 'max_mrna', 'max_protein', 'mol_premrna', 'mol_mrna', 'mol_protein'
uns: 'network', 'regulators', 'regulatory_network', 'regulatory_network_regulators', 'regulatory_network_targets', 'skeleton', 'targets', 'traj_dimred_segments', 'traj_milestone_network', 'traj_progressions'
obsm: 'dimred', 'regulatory_network_sc'
layers: 'counts_protein', 'counts_spliced', 'counts_unspliced', 'logcounts', 'rna_velocity'
Preprocessing pipeline#
# Standardize naming/slots and attach the per-cell ground-truth GRN stack.
update_data(adata=adata)
adata.uns["true_sc_grn"] = get_sc_grn(adata=adata)

# Standard RNA-velocity preprocessing: gene filtering and normalization
# (log transform applied separately below), PCA, KNN graph, and first-order
# moments ("Ms"/"Mu") of spliced/unspliced counts.
scv.pp.filter_and_normalize(adata, min_shared_counts=10, log=False)
sc.pp.log1p(adata)
sc.tl.pca(adata)
sc.pp.neighbors(adata, n_neighbors=30, n_pcs=30)
scv.pp.moments(adata)
# veloVI-specific preprocessing; filter_on_r2=True drops genes by R^2 fit
# (presumably of the steady-state regression — see velovi.preprocess_data).
adata = preprocess_data(adata, filter_on_r2=True)

# Gene filtering above removed genes; subset the square ground-truth GRN
# objects to the genes that survived (mask applied to both axes).
mask = pd.Index(adata.uns["true_regulators"]).isin(adata.var_names)
for uns_key in ["network", "skeleton", "sc_grn"]:
    adata.uns[f"true_{uns_key}"] = adata.uns[f"true_{uns_key}"][np.ix_(mask, mask)]

adata.write_zarr(DATA_DIR / DATASET / "processed" / "processed_sim.zarr")
Filtered out 3 genes that are detected 10 counts (shared).
Normalized count data: X, spliced, unspliced.
2024-12-14 22:59:38.918377: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1734213579.898155 42271 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1734213580.312006 42271 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
computing moments based on connectivities
finished (0:00:00) --> added
'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)
computing velocities
finished (0:00:00) --> added
'velocity', velocity vectors for each individual cell (adata.layers)
Velocity pipeline#
# Train veloVI five times (different random initializations via the trainer)
# and record, per run, the correlation of estimated vs. ground-truth
# velocities and latent time.
velocity_correlation = []
time_correlation = []
for i in range(5):
    VELOVI.setup_anndata(adata, spliced_layer="Ms", unspliced_layer="Mu")
    vae = VELOVI(adata)
    vae.train(max_epochs=1500)
    # Write the model's velocity/time estimates back into `adata`
    # (n_samples=30 — see rgv_tools.benchmarking.set_output).
    set_output(adata, vae, n_samples=30)
    # Cell-wise velocity correlation against the simulated ground truth,
    # aggregated with the mean.
    velocity_correlation.append(
        get_velocity_correlation(
            ground_truth=adata.layers["true_velocity"], estimated=adata.layers["velocity"], aggregation=np.mean
        )
    )

    # Per-gene correlation between true time and each gene's fitted time,
    # averaged over genes.
    # NOTE(review): the comprehension's `i` shadows the outer run index, but a
    # comprehension has its own scope in Python 3, so the outer loop is unaffected.
    time_corr = [
        get_time_correlation(ground_truth=adata.obs["true_time"], estimated=adata.layers["fit_t"][:, i])
        for i in range(adata.layers["fit_t"].shape[1])
    ]
    time_correlation.append(np.mean(time_corr))
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Monitored metric elbo_validation did not improve in the last 45 records. Best score: -423.979. Signaling Trainer to stop.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Monitored metric elbo_validation did not improve in the last 45 records. Best score: -426.753. Signaling Trainer to stop.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Monitored metric elbo_validation did not improve in the last 45 records. Best score: -422.398. Signaling Trainer to stop.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Monitored metric elbo_validation did not improve in the last 45 records. Best score: -425.915. Signaling Trainer to stop.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Monitored metric elbo_validation did not improve in the last 45 records. Best score: -423.788. Signaling Trainer to stop.
# Persist per-run correlations for downstream comparison against RegVelo.
if SAVE_DATA:
    pd.DataFrame({"velocity": velocity_correlation, "time": time_correlation}).to_parquet(
        path=DATA_DIR / DATASET / "results" / "velovi_correlation.parquet"
    )