cell2fate benchmark on cell cycle data#

This notebook benchmarks velocity estimation, latent time inference, and cross-boundary correctness using cell2fate on the cell cycle data.

Note that cell2fate requires anndata==0.8.0 and scvi-tools==0.16.1.
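Because the pins above are strict, it can help to verify the environment before running the rest of the notebook. The following is a minimal sketch using only the standard library; the helper name `find_mismatches` and the `PINNED` dictionary are illustrative and not part of cell2fate or rgv_tools.

```python
from importlib.metadata import PackageNotFoundError, version

# Version pins taken from the note above.
PINNED = {"anndata": "0.8.0", "scvi-tools": "0.16.1"}


def find_mismatches(pins):
    """Return {package: installed version or None} for every pin that is not met."""
    mismatches = {}
    for package, expected in pins.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            installed = None  # the package is not installed at all
        if installed != expected:
            mismatches[package] = installed
    return mismatches


if __name__ == "__main__":
    for package, installed in find_mismatches(PINNED).items():
        print(f"{package}: expected {PINNED[package]}, found {installed}")
```

An empty result means both pins are satisfied; any printed line names a package to reinstall at the pinned version.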

Library imports#

import pandas as pd

import anndata as ad
import cell2fate as c2f
import scvelo as scv
from cellrank.kernels import VelocityKernel

from rgv_tools import DATA_DIR
from rgv_tools.benchmarking import get_time_correlation


DATASET = "cell_cycle"
STATE_TRANSITIONS = [("G1", "S"), ("S", "G2M")]
(DATA_DIR / DATASET / "results").mkdir(parents=True, exist_ok=True)

Data loading#

The original count data are used to train the cell2fate model.

adata_raw = ad.read_h5ad(DATA_DIR / DATASET / "processed" / "adata.h5ad")
# Read the processed object once and reuse its gene selection and UMAP embedding.
adata_processed = ad.read_h5ad(DATA_DIR / DATASET / "processed" / "adata_processed.h5ad")
adata = adata_raw[:, adata_processed.var_names].copy()
adata.obsm["X_umap"] = adata_processed.obsm["X_umap"].copy()
AnnData object with n_obs × n_vars = 1146 × 395
    obs: 'phase', 'fucci_time'
    var: 'ensum_id'
    obsm: 'X_umap'
    layers: 'spliced', 'total', 'unspliced'

Velocity pipeline#

clusters_to_remove = []
adata.obs["clusters"] = "0"

adata = c2f.utils.get_training_data(
    adata,
    cells_per_cluster=10**5,
    cluster_column="clusters",
    remove_clusters=clusters_to_remove,
)
Keeping at most 100000 cells per cluster
Skip filtering by dispersion since number of variables are less than `n_top_genes`.
c2f.Cell2fate_DynamicalModel.setup_anndata(adata, spliced_label="spliced", unspliced_label="unspliced")
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
n_modules = c2f.utils.get_max_modules(adata)
Leiden clustering ...
WARNING: You’re trying to run this on 63 dimensions of `.X`, if you really want this, set `use_rep='X'`.
         Falling back to preprocessing with `sc.pp.pca` and default params.
Number of Leiden Clusters: 13
Maximal Number of Modules: 14
mod = c2f.Cell2fate_DynamicalModel(adata, n_modules=n_modules)
mod.train()
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Set SLURM handle signals.
Epoch 500/500: 100%|██████████| 500/500 [04:40<00:00,  1.78it/s, v_num=1, elbo_train=1.78e+6]
adata = mod.export_posterior(
    adata, sample_kwargs={"batch_size": None, "num_samples": 30, "return_samples": True, "use_gpu": False}
)
Sampling local variables, batch: 100%|██████████| 1/1 [00:08<00:00,  8.30s/it]
Sampling global variables, sample: 100%|██████████| 29/29 [00:04<00:00,  5.93it/s]
Warning: Saving ALL posterior samples. Specify "return_samples: False" to save just summary statistics.
mod.compute_and_plot_total_velocity(adata, delete=False)
Computing total RNAvelocity ...
adata.layers["velocity"] = adata.layers["Velocity"].numpy()
adata.layers["Ms"] = adata.layers["spliced"].copy()
time_correlation = [get_time_correlation(ground_truth=adata.obs["fucci_time"], estimated=adata.obs["Time (hours)"])]
scv.tl.velocity_graph(adata, vkey="velocity", n_jobs=1)
scv.tl.velocity_confidence(adata, vkey="velocity")
computing velocity graph (using 1/112 cores)
    finished (0:00:03) --> added 
    'velocity_graph', sparse matrix with cosine correlations (adata.uns)
--> added 'velocity_length' (adata.obs)
--> added 'velocity_confidence' (adata.obs)
--> added 'velocity_confidence_transition' (adata.obs)
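
The time correlation computed above compares the FUCCI-derived ground truth (`fucci_time`) with the latent time inferred by cell2fate (`Time (hours)`). The implementation of `get_time_correlation` is not shown in this notebook; the sketch below assumes it is a rank correlation of the Spearman type, which is a common choice for comparing orderings. The function names here are illustrative only.

```python
def average_ranks(values):
    """Rank values from 1 to n, giving tied values the mean of their ranks."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    start = 0
    while start < len(order):
        end = start
        # Extend the run while the next value ties with the current one.
        while end + 1 < len(order) and values[order[end + 1]] == values[order[start]]:
            end += 1
        mean_rank = (start + end) / 2 + 1
        for position in range(start, end + 1):
            ranks[order[position]] = mean_rank
        start = end + 1
    return ranks


def pearson(x, y):
    """Pearson correlation of two equally long, non-constant sequences."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    return cov / (var_x * var_y) ** 0.5


def spearman(ground_truth, estimated):
    """Spearman correlation: Pearson correlation of the rank-transformed inputs."""
    return pearson(average_ranks(ground_truth), average_ranks(estimated))
```

Under this assumption, a value near 1 means the inferred time orders cells the same way as the ground truth, regardless of the units in which either is measured.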

Cross-boundary correctness#

vk = VelocityKernel(adata, vkey="velocity", xkey="spliced").compute_transition_matrix()

cluster_key = "phase"
rep = "X_pca"

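score_df = []
for source, target in STATE_TRANSITIONS:
    cbc = vk.cbc(source=source, target=target, cluster_key=cluster_key, rep=rep)

    # Collect one row per scored cell, labelled with the transition it belongs to.
    score_df.append(
        pd.DataFrame(
            {
                "State transition": [f"{source} - {target}"] * len(cbc),
                "CBC": cbc,
            }
        )
    )
score_df = pd.concat(score_df)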
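
To give some intuition for what the score measures, the following toy sketch follows the commonly cited definition of cross-boundary correctness: for a cell in the source state, average the cosine similarity between its velocity and the displacements to its neighbors in the target state. This is an illustration under that assumption, not CellRank's implementation of `cbc`, which may differ in its details. All names and coordinates below are made up.

```python
import math


def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def cross_boundary_score(position, velocity, target_neighbors):
    """Mean cosine similarity between a cell's velocity and its displacements to target-state neighbors."""
    scores = []
    for neighbor in target_neighbors:
        displacement = [n - p for n, p in zip(neighbor, position)]
        scores.append(cosine(velocity, displacement))
    return sum(scores) / len(scores)


# Toy 2D example: a "G1" cell at the origin with two "S" neighbors to its right.
g1_cell = [0.0, 0.0]
s_neighbors = [[1.0, 0.0], [2.0, 0.0]]

toward = cross_boundary_score(g1_cell, [1.0, 0.0], s_neighbors)  # velocity points at the S cells
away = cross_boundary_score(g1_cell, [-1.0, 0.0], s_neighbors)  # velocity points away from them
```

Scores close to 1 indicate velocities that point across the boundary in the expected direction (here G1 to S, then S to G2M), while negative scores indicate velocities pointing the wrong way.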

Data saving#

# time_correlation holds a single value, so no per-cell index is attached.
pd.DataFrame({"time": time_correlation}).to_parquet(
    path=DATA_DIR / DATASET / "results" / "cell2fate_correlation.parquet"
)
adata.obs[["velocity_confidence"]].to_parquet(path=DATA_DIR / DATASET / "results" / "cell2fate_confidence.parquet")
score_df.to_parquet(path=DATA_DIR / DATASET / "results" / "cell2fate_cbc.parquet")