Early driver perturbation prediction#
Notebook for analyzing the dynamics of early driver genes, including nr2f5, sox9b, twist1b, and ets1.
Library imports#
import pandas as pd
import scipy
import matplotlib.pyplot as plt
import mplscience
import seaborn as sns
import cellrank as cr
import scanpy as sc
import scvi
from regvelo import REGVELOVI
from rgv_tools import DATA_DIR, FIG_DIR
from rgv_tools.benchmarking import set_output
from rgv_tools.perturbation import abundance_test, DEG, in_silico_block_simulation
from rgv_tools.plotting import bar_scores
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_csv from `anndata` is deprecated. Import anndata.io.read_csv instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_text from `anndata` is deprecated. Import anndata.io.read_text instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_excel from `anndata` is deprecated. Import anndata.io.read_excel instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_mtx from `anndata` is deprecated. Import anndata.io.read_mtx instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_loom from `anndata` is deprecated. Import anndata.io.read_loom instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_hdf from `anndata` is deprecated. Import anndata.io.read_hdf instead.
warnings.warn(msg, FutureWarning)
/home/icb/weixu.wang/miniconda3/envs/dynamo/lib/python3.10/site-packages/anndata/utils.py:429: FutureWarning: Importing read_umi_tools from `anndata` is deprecated. Import anndata.io.read_umi_tools instead.
warnings.warn(msg, FutureWarning)
General settings#
# Fix scvi's global seed so that repeated runs of this notebook start from the same random state.
scvi.settings.seed = 0
[rank: 0] Seed set to 0
Constants#
# Dataset identifier; doubles as the sub-directory name below DATA_DIR and FIG_DIR.
DATASET = "zebrafish"

# When enabled, make sure the data output directories exist.
SAVE_DATA = True
if SAVE_DATA:
    for _subdir in ("processed", "results"):
        (DATA_DIR / DATASET / _subdir).mkdir(parents=True, exist_ok=True)

# When enabled, make sure the figure directory exists; the plotting cells also use this flag to decide on saving.
SAVE_FIGURES = True
if SAVE_FIGURES:
    (FIG_DIR / DATASET).mkdir(parents=True, exist_ok=True)

# Genes analysed throughout this notebook.
DRIVERS = ["nr2f5", "sox9b", "twist1b", "ets1"]

# Macrostates that are declared terminal in every CellRank run below.
TERMINAL_STATES = ["mNC_head_mesenchymal", "mNC_arch2", "mNC_hox34", "Pigment"]
Data loading#
# Read the AnnData object stored at <DATA_DIR>/<DATASET>/processed/adata_run_regvelo.h5ad.
# NOTE(review): presumably written by an earlier RegVelo training notebook — confirm.
adata = sc.read_h5ad(DATA_DIR / DATASET / "processed" / "adata_run_regvelo.h5ad")
CellRank pipeline#
# Build both kernels on `adata` and blend them 80 % velocity / 20 % connectivity.
vk = cr.kernels.VelocityKernel(adata).compute_transition_matrix()
ck = cr.kernels.ConnectivityKernel(adata)
ck.compute_transition_matrix()
kernel = 0.8 * vk + 0.2 * ck

# Fate probabilities towards the predefined terminal states.
estimator = cr.estimators.GPCCA(kernel)
estimator.compute_macrostates(n_states=8, cluster_key="cell_type")
estimator.set_terminal_states(TERMINAL_STATES)
estimator.compute_fate_probabilities()

# Lineage-driver table indexed by gene; its "<state>_corr" columns feed the driver-ranking plots further down.
df = estimator.compute_lineage_drivers(cluster_key="cell_type")
WARNING: Unable to import `petsc4py` or `slepc4py`. Using `method='brandts'`
WARNING: For `method='brandts'`, dense matrix is required. Densifying
2024-11-25 16:44:06,080 - INFO - Using pre-computed Schur decomposition
WARNING: Unable to import petsc4py. For installation, please refer to: https://petsc4py.readthedocs.io/en/stable/install.html.
Defaulting to `'gmres'` solver.
Plotting#
Gene expression#
# UMAP coloured by the expression of each driver gene, one panel per gene.
with mplscience.style_context():
    # Derive the panel count from DRIVERS instead of hard-coding 4, so no gene is silently dropped by `zip`
    # when the list changes. `squeeze=False` keeps a 2-D axes array even when there is a single gene.
    n_panels = len(DRIVERS)
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 3), squeeze=False)

    for panel, gene in zip(axes[0], DRIVERS):
        # The colour scale is clipped to the 1st-99th percentile of each gene.
        sc.pl.umap(
            adata,
            color=[gene],
            vmin="p1",
            vmax="p99",
            frameon=False,
            legend_fontsize=14,
            show=False,
            ax=panel,
        )

    if SAVE_FIGURES:
        fig.savefig(FIG_DIR / DATASET / "gene_expression.svg", format="svg", transparent=True, bbox_inches="tight")
    plt.show()
# Order in which the `stage` groups are laid out on the x-axis of the violin plots.
order = ["3ss", "6-7ss", "10ss", "12-13ss", "17-18ss", "21-22ss"]

# Violin plot of each driver gene's expression per stage, one panel per gene with a shared y-axis.
with mplscience.style_context():
    sns.set(style="whitegrid")

    # Derive the panel count from DRIVERS instead of hard-coding 4, so no gene is silently dropped by `zip`
    # when the list changes. `squeeze=False` keeps a 2-D axes array even when there is a single gene.
    n_panels = len(DRIVERS)
    fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 3), sharey=True, squeeze=False)

    for panel, gene in zip(axes[0], DRIVERS):
        sc.pl.violin(
            adata,
            [gene],
            groupby="stage",
            stripplot=False,  # remove the internal dots
            inner="box",  # adds a boxplot inside violins
            order=order,
            palette=["darkgrey"],
            show=False,
            ax=panel,
        )

    if SAVE_FIGURES:
        fig.savefig(FIG_DIR / DATASET / "gene_expression_time.svg", format="svg", transparent=True, bbox_inches="tight")
    plt.show()
# Per-gene (min, max, loc) settings forwarded to `bar_scores`, listed in the same order as DRIVERS.
# NOTE(review): presumably axis limits and annotation position — confirm against `bar_scores`.
coord = [[-0.3, 0.35, 0.25], [-3.5, 3.5, 2.5], [-3.5, 3.5, 2.5], [-0.9, 1, 0.7]]

# Cell types shown in the DEG bar plot: the terminal states plus the "NPB_nohox" population.
deg_groups = [*TERMINAL_STATES, "NPB_nohox"]

for gene, (score_min, score_max, label_loc) in zip(DRIVERS, coord):
    # `.toarray()` replaces the `.A` shorthand, which SciPy removed from sparse matrices in 1.14.
    Gep = adata[:, gene].X.toarray()

    res = DEG(Gep, adata.obs["cell_type"].tolist())
    # Index by cell type first (uses the column name returned by `DEG`), then normalise the column names.
    res.index = res.loc[:, "cell_type"]
    res.columns = ["gene", "cell_type", "coefficient", "pvalue"]
    res = res.loc[deg_groups, :].copy()

    bar_scores(
        res,
        adata,
        "cell_type",
        gene,
        figsize=(2, 2),
        title="DEG test",
        min=score_min,
        max=score_max,
        loc=label_loc,
    )

    with mplscience.style_context():
        sns.set(style="whitegrid")
        if SAVE_FIGURES:
            plt.savefig(FIG_DIR / DATASET / f"{gene}_DEG.svg", format="svg", transparent=True, bbox_inches="tight")
        plt.show()
# Per-gene (min, max, loc) settings forwarded to `bar_scores`, listed in the same order as DRIVERS.
# NOTE(review): presumably axis limits and annotation position — confirm against `bar_scores`.
coord = [[-0.25, 0.5, 0.4], [-0.25, 0.5, 0.4], [-0.4, 0.5, 0.4], [-0.25, 0.5, 0.4]]

# Fate probabilities are the same for every gene, so convert them once outside the loop.
fate_probs = pd.DataFrame(adata.obsm["lineages_fwd"])
lineage_names = adata.obsm["lineages_fwd"].names.tolist()
n_lineages = adata.obsm["lineages_fwd"].shape[1]

# Pairing genes with their settings via `zip` fixes an index-shadowing bug: the previous inner
# `for i in range(...)` overwrote the outer gene index, so every gene was drawn with the settings of the
# entry at the last lineage index instead of its own.
for gene, (score_min, score_max, label_loc) in zip(DRIVERS, coord):
    # `.toarray()` replaces the `.A` shorthand, which SciPy removed from sparse matrices in 1.14.
    Gep = adata[:, gene].X.toarray().reshape(-1)

    # Pearson correlation between the gene's expression and each lineage's fate probability
    # (one call per lineage; element 0 is the coefficient, element 1 the p-value).
    stats = [scipy.stats.pearsonr(fate_probs.iloc[:, lineage_idx], Gep) for lineage_idx in range(n_lineages)]

    test_result = pd.DataFrame(
        {
            "coefficient": [stat[0] for stat in stats],
            "pvalue": [stat[1] for stat in stats],
        }
    )
    test_result.index = lineage_names
    test_result = test_result.loc[TERMINAL_STATES, :].copy()

    bar_scores(
        test_result,
        adata,
        "cell_type",
        gene,
        figsize=(2, 2),
        min=score_min,
        max=score_max,
        loc=label_loc,
    )

    with mplscience.style_context():
        sns.set(style="whitegrid")
        if SAVE_FIGURES:
            plt.savefig(FIG_DIR / DATASET / f"{gene}_cor.svg", format="svg", transparent=True, bbox_inches="tight")
        plt.show()
Driver ranking#
# For every lineage and driver gene, show where the gene's lineage-driver correlation falls within the
# distribution of correlations over all genes in `df`.
for ts in adata.obsm["lineages_fwd"].names.tolist():
    corr_column = f"{ts}_corr"
    for gene in DRIVERS:
        # Background histogram: correlations of all genes with this lineage.
        sns.histplot(df.loc[:, corr_column], color="skyblue", binwidth=0.05)

        # Dashed red marker at the correlation value of the current driver gene.
        gene_corr = df.loc[gene, corr_column]
        plt.axvline(x=gene_corr, color="red", linestyle="--")

        plt.xlabel("Correlation")
        plt.ylabel("Frequency")
        plt.title(ts)

        if SAVE_FIGURES:
            plt.savefig(FIG_DIR / DATASET / f"{gene}_{ts}.svg", format="svg", transparent=True, bbox_inches="tight")
        plt.show()

        # Release the figure before drawing the next one.
        plt.close()
Applying RegVelo for perturbation prediction#
# Location of the saved RegVelo model: <DATA_DIR>/<DATASET>/processed/rgv_model.
model = DATA_DIR / DATASET / "processed" / "rgv_model"
# Restore the model from disk together with the AnnData loaded above.
vae = REGVELOVI.load(model, adata)
# NOTE(review): `set_output` is a project helper from rgv_tools.benchmarking; presumably it writes the
# model-derived outputs into `adata` using 30 samples in a single batch — confirm against its definition.
set_output(adata, vae, n_samples=30, batch_size=adata.n_obs)
INFO File /ictstr01/home/icb/weixu.wang/regulatory_velo/data/zebrafish/processed/rgv_model/model.pt already
downloaded
# Rebuild the kernels on `adata` after `set_output` and blend them 80 % velocity / 20 % connectivity.
vk = cr.kernels.VelocityKernel(adata).compute_transition_matrix()
ck = cr.kernels.ConnectivityKernel(adata)
ck.compute_transition_matrix()
kernel = 0.8 * vk + 0.2 * ck

# Fate probabilities of the unperturbed model; these are the reference for the perturbation screen below.
estimator = cr.estimators.GPCCA(kernel)
estimator.compute_macrostates(n_states=8, cluster_key="cell_type")
estimator.set_terminal_states(TERMINAL_STATES)
estimator.compute_fate_probabilities()

# One panel per terminal state.
estimator.plot_fate_probabilities(same_plot=False)
WARNING: Unable to import `petsc4py` or `slepc4py`. Using `method='brandts'`
WARNING: For `method='brandts'`, dense matrix is required. Densifying
2024-11-25 16:44:53,597 - INFO - Using pre-computed Schur decomposition
# Transcription factors to block in silico, one at a time.
cand_list = ["ets1", "nr2f2", "nr2f5", "sox9b", "twist1a", "twist1b", "sox10", "mitfa", "tfec", "tfap2b"]

# Number of macrostates; constant across perturbations, so it is set once instead of on every iteration.
n_states = 8

# Fate-probability tables of the perturbed systems, in the same order as `cand_list`.
fate_prob_perturb = []

for TF in cand_list:
    # Only the perturbed AnnData is needed here; the perturbed model returned alongside it is not used.
    adata_target_perturb, _ = in_silico_block_simulation(model, adata, TF)

    # Same CellRank recipe as for the unperturbed data: 80 % velocity / 20 % connectivity kernel.
    vk = cr.kernels.VelocityKernel(adata_target_perturb)
    vk.compute_transition_matrix()
    ck = cr.kernels.ConnectivityKernel(adata_target_perturb).compute_transition_matrix()
    kernel = 0.8 * vk + 0.2 * ck

    estimator = cr.estimators.GPCCA(kernel)
    estimator.compute_macrostates(n_states=n_states, cluster_key="cell_type")
    estimator.set_terminal_states(TERMINAL_STATES)
    estimator.compute_fate_probabilities()

    # Store the perturbed fate probabilities as a plain DataFrame with one column per lineage.
    perturbed_lineages = adata_target_perturb.obsm["lineages_fwd"]
    fate_prob_perturb.append(pd.DataFrame(perturbed_lineages, columns=perturbed_lineages.names.tolist()))
INFO File /ictstr01/home/icb/weixu.wang/regulatory_velo/data/zebrafish/processed/rgv_model/model.pt already
downloaded
WARNING: Unable to import `petsc4py` or `slepc4py`. Using `method='brandts'`
WARNING: For `method='brandts'`, dense matrix is required. Densifying
2024-11-25 16:45:02,645 - INFO - Using pre-computed Schur decomposition
INFO File /ictstr01/home/icb/weixu.wang/regulatory_velo/data/zebrafish/processed/rgv_model/model.pt already
downloaded
WARNING: Unable to import `petsc4py` or `slepc4py`. Using `method='brandts'`
WARNING: For `method='brandts'`, dense matrix is required. Densifying
2024-11-25 16:45:11,720 - INFO - Using pre-computed Schur decomposition
INFO File /ictstr01/home/icb/weixu.wang/regulatory_velo/data/zebrafish/processed/rgv_model/model.pt already
downloaded
WARNING: Unable to import `petsc4py` or `slepc4py`. Using `method='brandts'`
WARNING: For `method='brandts'`, dense matrix is required. Densifying
2024-11-25 16:45:20,519 - INFO - Using pre-computed Schur decomposition
INFO File /ictstr01/home/icb/weixu.wang/regulatory_velo/data/zebrafish/processed/rgv_model/model.pt already
downloaded
WARNING: Unable to import `petsc4py` or `slepc4py`. Using `method='brandts'`
WARNING: For `method='brandts'`, dense matrix is required. Densifying
2024-11-25 16:45:29,399 - INFO - Using pre-computed Schur decomposition
INFO File /ictstr01/home/icb/weixu.wang/regulatory_velo/data/zebrafish/processed/rgv_model/model.pt already
downloaded
WARNING: Unable to import `petsc4py` or `slepc4py`. Using `method='brandts'`
WARNING: For `method='brandts'`, dense matrix is required. Densifying
2024-11-25 16:45:38,035 - INFO - Using pre-computed Schur decomposition
INFO File /ictstr01/home/icb/weixu.wang/regulatory_velo/data/zebrafish/processed/rgv_model/model.pt already
downloaded
WARNING: Unable to import `petsc4py` or `slepc4py`. Using `method='brandts'`
WARNING: For `method='brandts'`, dense matrix is required. Densifying
2024-11-25 16:45:46,417 - INFO - Using pre-computed Schur decomposition
INFO File /ictstr01/home/icb/weixu.wang/regulatory_velo/data/zebrafish/processed/rgv_model/model.pt already
downloaded
WARNING: Unable to import `petsc4py` or `slepc4py`. Using `method='brandts'`
WARNING: For `method='brandts'`, dense matrix is required. Densifying
2024-11-25 16:45:55,109 - INFO - Using pre-computed Schur decomposition
INFO File /ictstr01/home/icb/weixu.wang/regulatory_velo/data/zebrafish/processed/rgv_model/model.pt already
downloaded
WARNING: Unable to import `petsc4py` or `slepc4py`. Using `method='brandts'`
WARNING: For `method='brandts'`, dense matrix is required. Densifying
2024-11-25 16:46:04,024 - INFO - Using pre-computed Schur decomposition
INFO File /ictstr01/home/icb/weixu.wang/regulatory_velo/data/zebrafish/processed/rgv_model/model.pt already
downloaded
WARNING: Unable to import `petsc4py` or `slepc4py`. Using `method='brandts'`
WARNING: For `method='brandts'`, dense matrix is required. Densifying
2024-11-25 16:46:13,294 - INFO - Using pre-computed Schur decomposition
INFO File /ictstr01/home/icb/weixu.wang/regulatory_velo/data/zebrafish/processed/rgv_model/model.pt already
downloaded
WARNING: Unable to import `petsc4py` or `slepc4py`. Using `method='brandts'`
WARNING: For `method='brandts'`, dense matrix is required. Densifying
2024-11-25 16:46:22,320 - INFO - Using pre-computed Schur decomposition
# Fate probabilities of the unperturbed model, used as the reference condition in every test.
cond2_df = pd.DataFrame(adata.obsm["lineages_fwd"], columns=adata.obsm["lineages_fwd"].names.tolist())

# Compare each perturbation with the reference. `cand_list` and `fate_prob_perturb` share the same order,
# so they are paired directly instead of through a positional index.
score_frames = []
for tf_name, perturbed_df in zip(cand_list, fate_prob_perturb):
    test_res = abundance_test(cond2_df, perturbed_df)
    score_frames.append(
        pd.DataFrame(
            {
                "Score": test_res.iloc[:, 0].tolist(),
                "p-value": test_res.iloc[:, 1].tolist(),
                "Terminal state": test_res.index.tolist(),
                "TF": [tf_name] * test_res.shape[0],
            }
        )
    )

# NOTE: `df` is rebound here; the lineage-driver table computed in the CellRank section is no longer
# reachable under this name afterwards.
df = pd.concat(score_frames)

# Flip the score around 0.5; it is plotted below as the "Depletion score".
# NOTE(review): presumably the raw score is an AUROC-like statistic in [0, 1] — confirm against `abundance_test`.
df["Score"] = 0.5 - df["Score"]
# Grouped bar plot of the depletion score per TF, with one bar per terminal state.
with mplscience.style_context():
    sns.set(style="whitegrid")
    fig, ax = plt.subplots(figsize=(10, 3))

    # Map every `cell_type` category to the colour stored for it on `adata`; the bars are coloured by
    # terminal state through this mapping.
    color_label = "cell_type"
    palette = dict(zip(adata.obs[color_label].cat.categories, adata.uns[f"{color_label}_colors"]))

    sns.barplot(x="TF", y="Score", hue="Terminal state", data=df, ax=ax, palette=palette, dodge=True)

    # Dashed separators half-way between neighbouring TF groups (groups sit at x = 0, 1, 2, ...).
    for boundary in range(df["TF"].nunique() - 1):
        ax.axvline(x=boundary + 0.5, color="gray", linestyle="--")

    # Axis labels and tick labels share the same font size.
    ax.set_ylabel("Depletion score", fontsize=14)
    ax.set_xlabel("TF", fontsize=14)
    ax.tick_params(axis="both", labelsize=14)

    # Place the legend below the axes in three columns.
    ax.legend(loc="lower center", fontsize=14, bbox_to_anchor=(0.5, -0.6), ncol=3)

    if SAVE_FIGURES:
        fig.savefig(
            FIG_DIR / DATASET / "driver_perturbation_simulation_all.svg",
            format="svg",
            transparent=True,
            bbox_inches="tight",
        )
    plt.show()