CellOracle-based perturbation prediction#

Notebook for predicting TF perturbation effects with CellOracle.

Library imports#

import numpy as np
import pandas as pd

import celloracle as co
import scanpy as sc
import scvelo as scv
from celloracle.applications import Gradient_calculator

from rgv_tools import DATA_DIR
from rgv_tools.perturbation import (
    get_list_name,
    Multiple_TFScanning_perturbation_co,
    split_elements,
    TFScanning_perturbation_co,
)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_csv from `anndata` is deprecated. Import anndata.io.read_csv instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_text from `anndata` is deprecated. Import anndata.io.read_text instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_excel from `anndata` is deprecated. Import anndata.io.read_excel instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_mtx from `anndata` is deprecated. Import anndata.io.read_mtx instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_loom from `anndata` is deprecated. Import anndata.io.read_loom instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_hdf from `anndata` is deprecated. Import anndata.io.read_hdf instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_umi_tools from `anndata` is deprecated. Import anndata.io.read_umi_tools instead.
  warnings.warn(msg, FutureWarning)

Constants#

# Name of the dataset; also the sub-directory of DATA_DIR used for inputs and results.
DATASET = "zebrafish"

# When enabled, result tables are written to disk and the results directory
# is created up front so the save step at the end cannot fail on a missing path.
SAVE_DATA = True
if SAVE_DATA:
    (DATA_DIR / DATASET / "results").mkdir(exist_ok=True, parents=True)

# Cell-type labels handed to the perturbation scans as terminal states.
TERMINAL_STATES = ["mNC_head_mesenchymal", "mNC_arch2", "mNC_hox34", "Pigment"]

# Candidate genes for the single knock-out scan.
single_ko = [
    "elk3",
    "erf",
    "fli1a",
    "mitfa",
    "nr2f5",
    "rarga",
    "rxraa",
    "smarcc1a",
    "tfec",
    "nr2f2",
]

# Combined knock-outs; each entry lists its genes joined by underscores.
multiple_ko = [
    "fli1a_elk3",
    "mitfa_tfec",
    "tfec_mitfa_bhlhe40",
    "fli1a_erf_erfl3",
    "erf_erfl3",
]

Data loading#

# Both inputs are read from the dataset's "raw" directory.
# eRegulon metadata table; its first CSV column becomes the index.
df = pd.read_csv(
    DATA_DIR / DATASET / "raw" / "eRegulon_metadata_all.csv",
    index_col=0,
)
# Preprocessed AnnData object for the dataset.
adata = sc.read_h5ad(
    DATA_DIR / DATASET / "raw" / "adata_zebrafish_preprocessed.h5ad"
)

Data processing#

# Label every regulatory link as "<Region>*<Gene>" so that one matrix row
# corresponds to one peak-gene pair. The links are kept in a new frame instead
# of overwriting `df`: the original cell dropped the "Region" column from `df`,
# so re-running it raised a KeyError until the CSV was reloaded.
tf_links = df.assign(
    Target=[f"{region}*{gene}" for region, gene in zip(df["Region"].tolist(), df["Gene"].tolist())]
).loc[:, ["TF", "Gene", "Target"]]

# Unique TFs (columns) and peak-gene labels (rows), in order of first appearance.
regulators = tf_links["TF"].unique()
targets = tf_links["Target"].unique()

# Binary incidence matrix: entry is 1 when the TF is linked to the peak-gene
# pair, 0 otherwise. A single vectorised assignment replaces the row-by-row
# `iterrows()` loop; duplicate links simply set the same cell to 1 again.
incidence = np.zeros((len(targets), len(regulators)), dtype=np.int64)
incidence[
    pd.Index(targets).get_indexer(tf_links["Target"]),
    pd.Index(regulators).get_indexer(tf_links["TF"]),
] = 1
binary_matrix = pd.DataFrame(incidence, index=targets, columns=regulators)

# Recover the peak and gene parts from the "<Region>*<Gene>" row labels.
original_list = binary_matrix.index.tolist()
peak = [item.split("*")[0] for item in original_list]
target = [item.split("*")[1] for item in original_list]

# Place the two annotation columns first, followed by one column per TF,
# and drop the label index in favour of a plain integer index.
binary_matrix.insert(0, "peak_id", peak)
binary_matrix.insert(1, "gene_short_name", target)
binary_matrix = binary_matrix.reset_index(drop=True)
# Compute the neighbor graph and the moments of un/spliced abundances
# (adds 'Ms' and 'Mu' to adata.layers).
scv.pp.moments(
    adata,
    n_pcs=50,
    n_neighbors=30,
)
computing neighbors
    finished (0:00:03) --> added 
    'distances' and 'connectivities', weighted adjacency matrices (adata.obsp)
computing moments based on connectivities
    finished (0:00:00) --> added 
    'Ms' and 'Mu', moments of un/spliced abundances (adata.layers)

CellOracle pipeline#

# Use a copy of the "matrix" layer as the expression matrix given to CellOracle
# (presumably raw counts, since it is imported via import_anndata_as_raw_count — TODO confirm).
adata.X = adata.layers["matrix"].copy()
oracle = co.Oracle()
# Clusters are taken from the "cell_type" column and the embedding from "X_umap".
oracle.import_anndata_as_raw_count(adata=adata, cluster_column_name="cell_type", embedding_name="X_umap")

# Supply the peak/gene-by-TF binary matrix as CellOracle's TF information.
oracle.import_TF_data(TF_info_matrix=binary_matrix)
8012 genes were found in the adata. Note that Celloracle is intended to use around 1000-3000 genes, so the behavior with this number of genes may differ from what is expected.
oracle.perform_PCA()

# Pick the number of PCA components: find the first position where the
# per-component gain in cumulative explained variance switches sides of the
# 0.002 threshold, then cap the result at 50.
cumulative_variance = np.cumsum(oracle.pca.explained_variance_ratio_)
gain_above_threshold = np.diff(cumulative_variance) > 0.002
switch_positions = np.where(np.diff(gain_above_threshold))[0]
n_comps = min(switch_positions[0], 50)

n_cell = oracle.adata.shape[0]
print(f"cell number is :{n_cell}")

# Neighborhood size: 2.5% of the number of cells, truncated to an integer.
k = int(0.025 * n_cell)
print(f"Auto-selected k is :{k}")
cell number is :697
Auto-selected k is :17
# Balanced kNN imputation on the first n_comps PCA dimensions with the
# auto-selected k; b_sight and b_maxl are set to 8x and 4x k respectively.
oracle.knn_imputation(n_pca_dims=n_comps, k=k, balanced=True, b_sight=k * 8, b_maxl=k * 4, n_jobs=4)
# Infer GRN links using "cell_type" as the GRN unit (one inference per cell type).
links = oracle.get_links(cluster_name_for_GRN_unit="cell_type", alpha=10, verbose_level=10)
Inferring GRN for NC_trunk...
Inferring GRN for NPB_hox3...
Inferring GRN for NPB_nohox...
Inferring GRN for Pigment...
Inferring GRN for dNC_hox34...
Inferring GRN for dNC_hoxa2b...
Inferring GRN for dNC_nohox...
Inferring GRN for mNC_arch1...
Inferring GRN for mNC_arch2...
Inferring GRN for mNC_head_mesenchymal...
Inferring GRN for mNC_hox34...
Inferring GRN for mNC_nohox...
Inferring GRN for mNC_vagal...
# First pass: filter links with p=0.001, weight="coef_abs" and
# threshold_number=2000, then compute network scores
# (presumably on this filtered selection — TODO confirm).
links.filter_links(p=0.001, weight="coef_abs", threshold_number=2000)
links.get_network_score()
# NOTE(review): this second call relies on the library defaults. Presumably it
# re-filters the links and supersedes the threshold_number=2000 selection above
# for the steps below — confirm against the CellOracle API that this is intended.
links.filter_links()
# Build cluster-specific TF dictionaries from the links object and fit the
# GRN used for simulation with them.
oracle.get_cluster_specific_TFdict_from_Links(links_object=links)
oracle.fit_GRN_for_simulation(alpha=10, use_cluster_specific_TFdict=True)

Reference vector field definition#

## calculate pseudotime
# Fit scVelo's splicing dynamics for every gene in adata.var_names (4 jobs),
# compute velocities in "dynamical" mode, then derive latent time
# (stored as adata.obs["latent_time"]).
scv.tl.recover_dynamics(adata, var_names=adata.var_names, n_jobs=4)
scv.tl.velocity(adata, mode="dynamical")
# NOTE(review): the recorded run reports "No root cells detected" at this step;
# scVelo suggests specifying root cells to improve the latent time prediction.
scv.tl.latent_time(adata, min_likelihood=None)
recovering dynamics (using 4/128 cores)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_csv from `anndata` is deprecated. Import anndata.io.read_csv instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_text from `anndata` is deprecated. Import anndata.io.read_text instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_excel from `anndata` is deprecated. Import anndata.io.read_excel instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_mtx from `anndata` is deprecated. Import anndata.io.read_mtx instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_loom from `anndata` is deprecated. Import anndata.io.read_loom instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_csv from `anndata` is deprecated. Import anndata.io.read_csv instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_hdf from `anndata` is deprecated. Import anndata.io.read_hdf instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_text from `anndata` is deprecated. Import anndata.io.read_text instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_excel from `anndata` is deprecated. Import anndata.io.read_excel instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_mtx from `anndata` is deprecated. Import anndata.io.read_mtx instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_loom from `anndata` is deprecated. Import anndata.io.read_loom instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_hdf from `anndata` is deprecated. Import anndata.io.read_hdf instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_csv from `anndata` is deprecated. Import anndata.io.read_csv instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_text from `anndata` is deprecated. Import anndata.io.read_text instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_excel from `anndata` is deprecated. Import anndata.io.read_excel instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_mtx from `anndata` is deprecated. Import anndata.io.read_mtx instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_loom from `anndata` is deprecated. Import anndata.io.read_loom instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_hdf from `anndata` is deprecated. Import anndata.io.read_hdf instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_csv from `anndata` is deprecated. Import anndata.io.read_csv instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_text from `anndata` is deprecated. Import anndata.io.read_text instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_excel from `anndata` is deprecated. Import anndata.io.read_excel instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_mtx from `anndata` is deprecated. Import anndata.io.read_mtx instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_loom from `anndata` is deprecated. Import anndata.io.read_loom instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_hdf from `anndata` is deprecated. Import anndata.io.read_hdf instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_umi_tools from `anndata` is deprecated. Import anndata.io.read_umi_tools instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_umi_tools from `anndata` is deprecated. Import anndata.io.read_umi_tools instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_umi_tools from `anndata` is deprecated. Import anndata.io.read_umi_tools instead.
  warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_umi_tools from `anndata` is deprecated. Import anndata.io.read_umi_tools instead.
  warnings.warn(msg, FutureWarning)
    finished (0:04:24) --> added 
    'fit_pars', fitted parameters for splicing dynamics (adata.var)
computing velocities
    finished (0:00:03) --> added 
    'velocity', velocity vectors for each individual cell (adata.layers)
computing velocity graph (using 1/128 cores)
    finished (0:00:00) --> added 
    'velocity_graph', sparse matrix with cosine correlations (adata.uns)
computing terminal states
    identified 0 region of root cells and 1 region of end points .
    finished (0:00:00) --> added
    'root_cells', root cells of Markov diffusion process (adata.obs)
    'end_points', end points of Markov diffusion process (adata.obs)
WARNING: No root cells detected. Consider specifying root cells to improve latent time prediction.
computing latent time using root_cells as prior
    finished (0:00:01) --> added 
    'latent_time', shared time (adata.obs)
## use the velocity latent time inferred by scVelo to create gradient field
# Grid resolution and mass threshold passed to the gradient calculator below.
n_grid = 40
min_mass = 1.5
# Copy scVelo's latent time into the Oracle's AnnData under "Pseudotime",
# the key the gradient calculator is told to read.
oracle.adata.obs["Pseudotime"] = adata.obs["latent_time"].copy()
gradient = Gradient_calculator(oracle_object=oracle, pseudotime_key="Pseudotime")
gradient.calculate_p_mass(smooth=0.8, n_grid=n_grid, n_neighbors=30)
gradient.calculate_mass_filter(min_mass=min_mass, plot=True)
# Transfer the pseudotime data onto the grid using the "polynomial" method with n_poly=3.
gradient.transfer_data_into_grid(args={"method": "polynomial", "n_poly": 3}, plot=True)
gradient.calculate_gradient()
# Scale passed to the visualization of the resulting gradient field.
scale_dev = 40
gradient.visualize_results(scale=scale_dev, s=5)
# Positional indices (rows of oracle.adata.obs) of the cells belonging to each
# of the four cell types used as terminal states.
cell_type_labels = oracle.adata.obs["cell_type"]

cell_idx_pigment = np.flatnonzero(cell_type_labels.isin(["Pigment"]).to_numpy())
cell_idx_hox34 = np.flatnonzero(cell_type_labels.isin(["mNC_hox34"]).to_numpy())
cell_idx_arch2 = np.flatnonzero(cell_type_labels.isin(["mNC_arch2"]).to_numpy())
cell_idx_mesenchymal = np.flatnonzero(cell_type_labels.isin(["mNC_head_mesenchymal"]).to_numpy())

# Lookup from cell-type label to the index array of its cells.
index_dictionary = dict(
    zip(
        ("Pigment", "mNC_hox34", "mNC_arch2", "mNC_head_mesenchymal"),
        (cell_idx_pigment, cell_idx_hox34, cell_idx_arch2, cell_idx_mesenchymal),
    )
)

Perturbation prediction#

Single knock-out#

n_neighbors = 30

# Keep only the requested TFs that are present in adata.var_names.
# The filter walks the list in its declared order instead of going through a
# set intersection: set iteration order for strings depends on the
# interpreter's hash seed, which made the scan order (and presumably the row
# order of the saved results) change between kernel restarts.
missing_single_ko = [tf for tf in single_ko if tf not in adata.var_names]
if missing_single_ko:
    # Report dropped TFs instead of discarding them silently.
    print(f"Skipping TFs absent from adata.var_names: {missing_single_ko}")
# dict.fromkeys removes duplicates (as the set did) while preserving order.
single_ko = list(dict.fromkeys(tf for tf in single_ko if tf in adata.var_names))

## celloracle perturbation
# NOTE(review): the meaning of the positional argument `8` is defined by
# TFScanning_perturbation_co in rgv_tools — confirm and consider naming it.
d = TFScanning_perturbation_co(
    adata, 8, "cell_type", TERMINAL_STATES, single_ko, oracle, gradient, index_dictionary, n_neighbors
)
2024-11-24 16:22:56,120 - INFO - Done fli1a
2024-11-24 16:23:21,777 - INFO - Done mitfa
2024-11-24 16:23:47,753 - INFO - Done nr2f2
2024-11-24 16:24:12,948 - INFO - Done rarga
2024-11-24 16:24:39,190 - INFO - Done tfec
2024-11-24 16:25:04,528 - INFO - Done rxraa
2024-11-24 16:25:30,067 - INFO - Done erf
2024-11-24 16:25:55,981 - INFO - Done elk3
2024-11-24 16:26:21,029 - INFO - Done nr2f5
# Assemble the single knock-out scan into a table: one row per TF, with column
# names derived from the first coefficient entry via get_list_name.
# Note: `d` must still hold the single knock-out result here; the same name is
# reused for the multiple knock-out scan further down.
coef = pd.DataFrame(
    np.array(d["coefficient"]),
    index=d["TF"],
    columns=get_list_name(d["coefficient"][0]),
)

Multiple knock-out#

# Convert the combined knock-out labels with split_elements (presumably into
# their member genes — confirm against rgv_tools) and run the multi-TF scan
# with the same settings as the single knock-out scan.
multiple_ko_list = split_elements(multiple_ko)
d = Multiple_TFScanning_perturbation_co(
    adata,
    8,
    "cell_type",
    TERMINAL_STATES,
    multiple_ko_list,
    oracle,
    gradient,
    index_dictionary,
    n_neighbors,
)
2024-11-24 16:26:47,444 - INFO - Done fli1a_elk3
2024-11-24 16:27:13,422 - INFO - Done mitfa_tfec
2024-11-24 16:27:40,270 - INFO - Done tfec_mitfa_bhlhe40
2024-11-24 16:28:06,428 - INFO - Done fli1a_erf_erfl3
2024-11-24 16:28:32,149 - INFO - Done erf_erfl3
# Assemble the multiple knock-out scan into a table: one row per knock-out
# entry in d["TF"], with column names derived from the first coefficient entry.
coef_multiple = pd.DataFrame(
    np.array(d["coefficient"]),
    index=d["TF"],
    columns=get_list_name(d["coefficient"][0]),
)

Save dataset#

# Write both result tables to the dataset's results directory when saving is enabled.
if SAVE_DATA:
    for frame, filename in (
        (coef, "celloracle_perturb_single.csv"),
        (coef_multiple, "celloracle_perturb_multiple.csv"),
    ):
        frame.to_csv(DATA_DIR / DATASET / "results" / filename)