Integration with cellrank

In this notebook we show how schist and cellrank can be used together, using one of the cellrank example datasets (Zebrafish). Note that the number of clusters generated by schist may differ if you rerun this code, since the inference is stochastic. This notebook uses cellrank version 2.

First load some libraries that will be required.

import cellrank as cr
import scanpy as sc
import scvelo as scv
import schist as scs
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.pyplot import *
import scipy.stats as st
import scipy.spatial as ssp
import graph_tool as gt
import seaborn as sns
sc.set_figure_params()
rcParams['axes.grid'] = False
%matplotlib inline

We first download the Zebrafish dataset and preprocess it as shown in the cellrank tutorials.

adata = cr.datasets.zebrafish()
sc.pp.filter_genes(adata, min_cells=10)
scv.pp.normalize_per_cell(adata)
sc.pp.log1p(adata)

sc.pp.highly_variable_genes(adata)

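# CytoTRACE does not use RNA velocity; copying X into both layers only lets
# scv.pp.moments compute the kNN-smoothed layers that the kernel works with
adata.layers["spliced"] = adata.X
adata.layers["unspliced"] = adata.X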
n_neighbors = int(np.sqrt(adata.shape[0]/4))
scv.pp.moments(adata, n_pcs=30, n_neighbors=n_neighbors)
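`scv.pp.moments` produces kNN-smoothed expression (the `Ms` layer). Conceptually, the first-order moment of a cell is its expression averaged over its neighborhood. The sketch below illustrates that idea on a toy matrix. It assumes a binary kNN adjacency that includes each cell itself, and plain averaging. scvelo’s actual implementation works on its own connectivities graph and may differ in details. The function name `knn_moments` is hypothetical.

```python
import numpy as np
from scipy import sparse


def knn_moments(X, knn_indices):
    """Average each cell's expression over itself and its k nearest neighbors.

    X: (n_cells, n_genes) array; knn_indices: (n_cells, k) neighbor indices.
    """
    X = np.asarray(X, dtype=float)
    knn_indices = np.asarray(knn_indices)
    n_cells, k = knn_indices.shape
    rows = np.repeat(np.arange(n_cells), k)
    cols = knn_indices.ravel()
    # binary kNN adjacency plus self-loops
    A = sparse.csr_matrix((np.ones(rows.size), (rows, cols)), shape=(n_cells, n_cells))
    A = (A + sparse.identity(n_cells, format="csr")).tocsr()
    A.data[:] = 1.0  # binarize in case a cell already lists itself as a neighbor
    # row-normalize so each row averages over the neighborhood
    C = sparse.diags(1.0 / np.asarray(A.sum(axis=1)).ravel()) @ A
    return C @ X


X = np.array([[0.0, 0.0], [2.0, 4.0], [4.0, 8.0]])
knn = np.array([[1], [2], [1]])  # one neighbor per cell (toy data)
Ms = knn_moments(X, knn)  # [[1, 2], [3, 6], [3, 6]]
```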

We then run a nested model in schist as usual. Plotting the modularity profile allows us to identify the most likely level of interest, level 2 in this case.

scs.inference.nested_model(adata, max_iter = np.inf)

plot(adata.uns['schist']['nsbm']['stats']['modularity'])
[Figure: modularity at each level of the hierarchy]
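We assume the values stored under `stats['modularity']` follow Newman’s standard definition. For a partition c of an undirected graph with adjacency matrix A, degrees k_i and m edges, that definition is Q = (1/2m) sum_ij [A_ij - k_i k_j / 2m] delta(c_i, c_j). The sketch below computes Q for a label vector. It is an illustration of the formula, not schist’s code.

```python
import numpy as np


def modularity(A, labels):
    """Newman modularity of a partition of an undirected graph."""
    A = np.asarray(A, dtype=float)
    labels = np.asarray(labels)
    k = A.sum(axis=1)                           # node degrees
    two_m = k.sum()                             # 2m = sum of degrees
    same = labels[:, None] == labels[None, :]   # delta(c_i, c_j)
    return float(((A - np.outer(k, k) / two_m) * same).sum() / two_m)


# two triangles joined by a single edge
A = np.zeros((6, 6))
for i, j in [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5), (2, 3)]:
    A[i, j] = A[j, i] = 1
q = modularity(A, [0, 0, 0, 1, 1, 1])  # 5/14, about 0.357
```

Putting every node in a single group gives Q = 0. Values well above zero indicate more within-group edges than expected by chance.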

Remove the color map stored for the developmental stages, then plot the stages alongside levels 2, 3 and 4 of the hierarchy.

CC = adata.uns.pop('Stage_colors')
sc.pl.embedding(adata, color=['Stage', 'nsbm_level_2',
                             'nsbm_level_3', 'nsbm_level_4'],
                basis='X_force_directed', ncols=2,
                legend_loc='on data', frameon=False)
[Figure: force-directed embedding colored by Stage and by nsbm levels 2, 3 and 4]

We can already see that some of the groups identified by schist at level 2 correspond to the developmental stages annotated in this dataset. Now we run cellrank as in its tutorials, using the CytoTRACEKernel.

ctk = cr.kernels.CytoTRACEKernel(adata)
ctk.compute_cytotrace()
ctk.compute_transition_matrix(threshold_scheme="soft", nu=0.5)
g_fwd = cr.estimators.GPCCA(ctk)
g_fwd.compute_schur(n_components=20)
g_fwd.plot_spectrum(real_only=True)
[Figure: real part of the spectrum (20 components)]
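The number of macrostates is usually read off gaps in this spectrum. A common rule, the eigengap heuristic, picks n at the largest gap between consecutive sorted eigenvalues. As far as we know, cellrank applies a similar rule when the number of states is not given. The sketch below is our own minimal version, run on hypothetical eigenvalues.

```python
import numpy as np


def eigengap_n_states(eigenvalues):
    """Return n such that the gap after the n-th largest real eigenvalue is largest."""
    ev = np.sort(np.real(np.asarray(eigenvalues)))[::-1]  # descending order
    gaps = ev[:-1] - ev[1:]
    return int(np.argmax(gaps)) + 1


example = [1.0, 0.99, 0.98, 0.70, 0.65]  # hypothetical eigenvalues
n = eigengap_n_states(example)  # 3
```

In real spectra several gaps can be of similar size. This is why choosing between 3 and 5 or 6 macrostates here remains a judgment call.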

Although the spectrum suggests there could be 5 or 6 macrostates, we select the top 3, as in the original tutorial.

g_fwd.compute_macrostates(n_states=3, cluster_key="lineages")
g_fwd.plot_macrostates('all',
    discrete=True, legend_loc="right", size=100, basis="force_directed"
)
[Figure: the 3 macrostates on the force-directed embedding]

We highlight three of the seven level-2 schist groups, which seem to correspond to the macrostates.

scs_groups = ['6', '5', '4']
sc.pl.embedding(adata, color='nsbm_level_2', basis='X_force_directed', groups=scs_groups)
[Figure: schist groups 6, 5 and 4 (level 2) highlighted on the embedding]
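Here the correspondence between schist groups and macrostates is judged visually. One way to quantify it is a row-normalized contingency table between the two labelings: for each macrostate, the fraction of its cells that fall into each schist group. The sketch uses toy labels. With real data one would pass `adata.obs['nsbm_level_2']` and the macrostate assignment. We assume the latter is available in cellrank 2 as `g_fwd.macrostates`, a Series indexed like `adata.obs_names` with NaN for unassigned cells. The helper name `best_matching_groups` is hypothetical.

```python
import pandas as pd


def best_matching_groups(macrostates, groups):
    """For each macrostate, the schist group holding most of its cells and that fraction.

    macrostates, groups: pandas Series indexed by the same cell names.
    """
    keep = macrostates.notna() & groups.notna()  # ignore cells outside any macrostate
    table = pd.crosstab(macrostates[keep], groups[keep], normalize="index")
    return pd.DataFrame({"group": table.idxmax(axis=1), "fraction": table.max(axis=1)})


macro = pd.Series(["A", "A", "A", "B", "B", None])   # toy macrostate labels
groups = pd.Series(["6", "6", "5", "5", "5", "4"])   # toy schist groups
res = best_matching_groups(macro, groups)
```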

We now continue processing the macrostates, computing their fate probabilities and lineage drivers.

g_fwd.predict_terminal_states()
g_fwd.compute_fate_probabilities()
g_fwd.plot_fate_probabilities(basis='force_directed')
cr_drivers = g_fwd.compute_lineage_drivers()
[Figure: fate probabilities on the force-directed embedding]
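Fate probabilities are absorption probabilities of a Markov chain: the probability that a random walk started in a cell reaches a given terminal state before any other. Write the transition matrix in canonical form, with a transient-to-transient block Q and a transient-to-absorbing block R. The absorption probabilities then solve B = (I - Q)^-1 R. The dense sketch below shows this on a toy chain. cellrank itself works with sparse solvers and treats each terminal state as a set of cells.

```python
import numpy as np


def absorption_probabilities(T, absorbing):
    """Probability that a walk from each transient state ends in each absorbing state."""
    T = np.asarray(T, dtype=float)
    absorbing = np.asarray(absorbing)
    transient = np.setdiff1d(np.arange(T.shape[0]), absorbing)
    Q = T[np.ix_(transient, transient)]  # transient -> transient
    R = T[np.ix_(transient, absorbing)]  # transient -> absorbing
    B = np.linalg.solve(np.eye(len(transient)) - Q, R)
    return transient, B


# random walk on the path 0-1-2-3 with absorbing ends 0 and 3
T = np.array([
    [1.0, 0.0, 0.0, 0.0],
    [0.5, 0.0, 0.5, 0.0],
    [0.0, 0.5, 0.0, 0.5],
    [0.0, 0.0, 0.0, 1.0],
])
transient, B = absorption_probabilities(T, absorbing=[0, 3])
# B == [[2/3, 1/3], [1/3, 2/3]] for transient states 1 and 2
```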

When schist infers the best model, it computes cell marginals by default: the probability of each cell being assigned to each group. We now want to see whether these marginals can be used as lineage assignments, similar to the fate probabilities computed by cellrank. To do so, we instantiate a cr.Lineage object and use cellrank’s internal utilities (private functions, whose interface may change between versions) to compute lineage drivers.

scs_lineage = cr.Lineage(adata.obsm['CM_nsbm_level_2'],
                         names=adata.obs['nsbm_level_2'].cat.categories)
scs_drivers = cr._utils._utils._correlation_test(
    adata.X,
    scs_lineage,
    gene_names=adata.var_names,
    method=cr._utils._utils.TestMethod.FISHER,
    n_perms=1000,
    confidence_level=0.95,
)
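The Fisher method scores each gene by the Pearson correlation between its expression and a lineage’s probabilities. It adds a confidence interval derived from Fisher’s z-transformation. The sketch below reproduces that idea with numpy and scipy on dense toy data, with p-values from the t-distribution. It is a simplification, not cellrank’s private `_correlation_test`, whose p-value and multiple-testing handling may differ. The function name `correlation_drivers` is hypothetical.

```python
import numpy as np
from scipy import stats


def correlation_drivers(X, probs, confidence_level=0.95):
    """Per-gene Pearson r with a lineage probability vector, a two-sided p-value
    (t-distribution, n - 2 degrees of freedom) and a Fisher-z confidence interval."""
    X = np.asarray(X, dtype=float)      # cells x genes
    p = np.asarray(probs, dtype=float)  # one probability per cell
    n = X.shape[0]
    Xc = X - X.mean(axis=0)
    pc = p - p.mean()
    # constant genes give 0 / 0 = nan here
    r = (Xc.T @ pc) / (np.linalg.norm(Xc, axis=0) * np.linalg.norm(pc))
    r = np.clip(r, -1 + 1e-12, 1 - 1e-12)  # keep arctanh and the t statistic finite
    t = r * np.sqrt((n - 2) / (1 - r**2))
    pval = 2 * stats.t.sf(np.abs(t), df=n - 2)
    z = np.arctanh(r)                       # Fisher z-transform
    half_width = stats.norm.ppf(0.5 + confidence_level / 2) / np.sqrt(n - 3)
    return r, pval, np.tanh(z - half_width), np.tanh(z + half_width)


rng = np.random.default_rng(0)
probs = rng.random(200)
X = np.column_stack([probs + 0.1 * rng.normal(size=200),  # gene tracking the lineage
                     rng.normal(size=200)])                # unrelated gene
r, pval, ci_low, ci_high = correlation_drivers(X, probs)
```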

Now let’s check whether the lineage drivers are consistent. We compare drivers through the correlation coefficients computed by the tests above. For Early Blastomeres, matched to group 6, the match is almost perfect.

X = scs_drivers['6_corr']
Y = cr_drivers['Early Blastomeres_corr'].loc[X.index]
scatter(X, Y, s=1)
xlabel("6")
ylabel("Early Blastomeres")
rr = st.pearsonr(X, Y)
title(f"r={rr[0]:.3f} p={rr[1]:.3e}")
[Figure: driver correlations, group 6 versus Early Blastomeres]

For the Prechordal Plate (group 5) the results are again fairly good.

X = scs_drivers['5_corr']
Y = cr_drivers['Prechordal Plate_corr'].loc[X.index]
scatter(X, Y, s=1)
xlabel("5")
ylabel("Prechordal Plate")
rr = st.pearsonr(X, Y)
title(f"r={rr[0]:.3f} p={rr[1]:.3e}")
[Figure: driver correlations, group 5 versus Prechordal Plate]

The Notochord case is different: the gene scores from the two methods diverge somewhat, as if two subgroups were mixed.

X = scs_drivers['4_corr']
Y = cr_drivers['Notochord_corr'].loc[X.index]
scatter(X, Y, s=1)
xlabel("4")
ylabel("Notochord")
rr = st.pearsonr(X, Y)
title(f"r={rr[0]:.3f} p={rr[1]:.3e}")
[Figure: driver correlations, group 4 versus Notochord]

Plotting the actual probabilities makes clear that the Notochord macrostate and group 4 do not match completely: the macrostate is larger and includes more cells.

g_fwd.plot_fate_probabilities(same_plot=False, basis='force_directed')
[Figure: fate probabilities, one panel per terminal state]
# store each schist marginal (one per level-2 group) in adata.obs for plotting
for g in scs_lineage.names:
    adata.obs[f'CM_{g}'] = scs_lineage[g].X.squeeze()
scv.pl.scatter(adata, color=['CM_6', 'CM_5', 'CM_4'], basis='force_directed',
               cmap='viridis', perc=[2, 98])
[Figure: schist marginals for groups 6, 5 and 4]

Since the Schur spectrum suggested a larger number of macrostates, we repeat the cellrank analysis with more of them.

g_fwd.compute_macrostates(n_states=6, cluster_key="lineages")
g_fwd.plot_macrostates('all', basis="force_directed", discrete=True)
[Figure: the 6 macrostates on the force-directed embedding]

The coarse-grained transition matrix suggests that 6 macrostates are well supported in this dataset.

g_fwd.plot_coarse_T()
[Figure: coarse-grained transition matrix]
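For reference, GPCCA projects the cell-level transition matrix P onto the macrostates using the membership matrix chi (cells x macrostates) and a distribution eta over cells. The projection is P_c = (chi^T D chi)^-1 chi^T D P chi, with D = diag(eta). We believe this is the construction used by pyGPCCA and cellrank, but the sketch is illustrative and assumes a uniform eta. With crisp memberships it reduces to the average transition between groups, and its rows sum to 1 whenever P and chi are row-stochastic.

```python
import numpy as np


def coarse_grain(P, chi, eta=None):
    """Project a row-stochastic P (n x n) onto k macrostates with memberships chi (n x k)."""
    P = np.asarray(P, dtype=float)
    chi = np.asarray(chi, dtype=float)
    n = P.shape[0]
    eta = np.full(n, 1.0 / n) if eta is None else np.asarray(eta, dtype=float)
    D = np.diag(eta)
    W = chi.T @ D @ chi           # k x k overlap of memberships
    A = chi.T @ D @ P @ chi       # k x k weighted transitions
    return np.linalg.solve(W, A)  # W^-1 A


P = np.array([
    [0.5, 0.3, 0.2, 0.0],
    [0.4, 0.4, 0.1, 0.1],
    [0.0, 0.1, 0.6, 0.3],
    [0.2, 0.0, 0.3, 0.5],
])
chi = np.array([[1, 0], [1, 0], [0, 1], [0, 1]])  # crisp toy memberships
Pc = coarse_grain(P, chi)  # [[0.8, 0.2], [0.15, 0.85]]
```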

We proceed with the analysis, extracting lineage drivers for all 6 macrostates.

g_fwd.predict(stability_threshold=0.8)
g_fwd.compute_fate_probabilities()
g_fwd.plot_fate_probabilities(basis='force_directed')
cr_drivers = g_fwd.compute_lineage_drivers()
[Figure: fate probabilities for the 6-macrostate model]

Let’s compute and visualize all pairwise correlations between the driver scores obtained with the two methods.

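# per-gene correlation scores, one column per lineage; align both tables on the same genes
A = scs_drivers.filter(like='_corr').sort_index()
B = cr_drivers.filter(like='_corr').loc[A.index]

# 1 - correlation distance = Pearson correlation between columns (computed across genes)
DM = 1 - ssp.distance.cdist(A.T, B.T, metric='correlation')
DM = pd.DataFrame(DM,
                  index=A.columns.str.replace('_corr$', '', regex=True),
                  columns=B.columns.str.replace('_corr$', '', regex=True))
sns.clustermap(DM, cmap='RdYlBu_r', vmin=-1, vmax=1)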
[Figure: clustered heatmap of driver correlations, schist groups versus cellrank lineages]
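`1 - cdist(A.T, B.T, metric='correlation')` gives the Pearson correlation between every column of `A` (schist groups) and every column of `B` (macrostates), computed across genes. The check below confirms this equivalence on random toy data. Genes with undefined scores (NaN) in either table would propagate NaN values, so in practice it may help to drop them first.

```python
import numpy as np
from scipy.spatial.distance import cdist

rng = np.random.default_rng(0)
A = rng.normal(size=(50, 3))  # 50 genes x 3 schist groups (toy data)
B = rng.normal(size=(50, 4))  # 50 genes x 4 macrostates (toy data)

DM = 1 - cdist(A.T, B.T, metric="correlation")
# the same block, taken from the full correlation matrix of all 7 columns
R = np.corrcoef(np.hstack([A, B]).T)[:3, 3:]
```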

Allowing more macrostates causes the Notochord state to “fragment” and, in fact, the match with group 4 is now fairly consistent. In the heatmap we can also spot a correspondence between group 0 and “Early_Blastomeres_3”, a state that was not identified with only 3 macrostates.

X = scs_drivers['4_corr']
Y = cr_drivers['Notochord_corr'].loc[X.index]
scatter(X, Y, s=1)
xlabel("4")
ylabel("Notochord")
rr = st.pearsonr(X, Y)
title(f"r={rr[0]:.3f} p={rr[1]:.3e}")
[Figure: driver correlations, group 4 versus Notochord (6 macrostates)]

Cell Trajectories

The following section is highly experimental and still under study. schist (actually graph-tool) can estimate the affinity of each cell for its group by computing the gain (or loss) of information obtained by moving the cell out of its group and putting it back. This can be computed at every level of the hierarchy and used as a proxy to define terminal states. Here we choose level 1, just below the level used to define the groups.

scs.tools.calculate_affinity(adata, level=1, back_prob=True)
M = adata.obsm['CA_nsbm_level_1']
E = np.exp(M)
adata.obs['scs_terminal_states'] = np.max(E, axis=1) / np.max(E)
sc.pl.embedding(adata, color='scs_terminal_states', basis='force_directed')
[Figure: terminal-state score at level 1 on the embedding]
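The score above exponentiates the affinity matrix, takes each cell’s maximum and rescales by the global maximum. Given the `np.exp`, we assume the matrix stores log-values per cell and group. Because exp is monotonic, the same score equals exp(max_j M_ij - max_ij M). This form avoids overflow or underflow when the log-values are large in magnitude. A minimal sketch on toy values (the function name `terminal_score` is hypothetical):

```python
import numpy as np


def terminal_score(M):
    """Per-cell max of exp(M), rescaled so the highest-scoring cell gets 1."""
    M = np.asarray(M, dtype=float)
    row_max = M.max(axis=1)
    return np.exp(row_max - row_max.max())


M = np.log(np.array([[0.2, 0.8], [0.5, 0.5], [0.9, 0.1]]))
score = terminal_score(M)  # [0.8/0.9, 0.5/0.9, 1.0]
naive = np.max(np.exp(M), axis=1) / np.max(np.exp(M))

# very negative log-values would make the naive version return 0/0
big = terminal_score(np.array([[-1000.0, -1001.0], [-1002.0, -1003.0]]))  # [1, exp(-2)]
```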

Interestingly, the terminal states identified this way mostly correspond to the biologically relevant ones. As mentioned, the same score can be computed at every level.

for k in adata.uns['schist']['nsbm']['blocks'].keys():
    scs.tools.calculate_affinity(adata, level=int(k), back_prob=True)
    M = adata.obsm[f'CA_nsbm_level_{k}']
    E = np.exp(M)
    adata.obs[f'scs_terminal_states_{k}'] = np.max(E, axis=1) / np.max(E)

To identify trajectories, we start from the cells with the lowest affinity, interpreting them as the ones that “won’t stay” in a group (that is, those more likely to transition).

first_cell = adata.obs.sort_values('scs_terminal_states_2', ascending=False).index[0]
adata.uns['iroot'] = np.where(adata.obs_names == first_cell)[0][0]
sc.tl.diffmap(adata)

We then exploit the graph topology to compute the shortest-path distance from the first cell to every other cell in the dataset. This again recapitulates the main finding for this dataset, namely the differentiation from Blastomeres to Notochord and Prechordal Plate.

state = scs.tools.state_from_blocks(adata)
tour = gt.shortest_distance(state.g, source=adata.uns['iroot'])
adata.obs['shortest_path_dist'] = np.array(tour.a / np.max(tour.a))
sc.pl.embedding(adata, color='shortest_path_dist', basis='force_directed')
[Figure: normalized shortest-path distance from the root cell]
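On an unweighted graph, `gt.shortest_distance` from a single source amounts to a breadth-first search that counts hops. The sketch below computes the same quantity with SciPy on a sparse adjacency matrix of a toy graph. With real data one could pass, for instance, `adata.obsp['connectivities']`, but that is not necessarily the same graph as the one stored in the schist state. Unreachable cells get `inf` in SciPy. As far as we know, graph-tool instead reports the maximum value of the distance type, so masking unreachable cells before normalizing is worthwhile. The helper name is hypothetical.

```python
import numpy as np
from scipy import sparse
from scipy.sparse.csgraph import shortest_path


def normalized_hop_distance(adjacency, root):
    """Hop distance from root to every node, scaled to [0, 1]; unreachable nodes become NaN."""
    d = np.ravel(shortest_path(adjacency, directed=False, unweighted=True, indices=root))
    d = np.where(np.isfinite(d), d, np.nan)
    return d / np.nanmax(d)


# path 0-1-2-3 plus an isolated node 4
edges = np.array([[0, 1], [1, 2], [2, 3]])
adj = sparse.coo_matrix((np.ones(len(edges)), (edges[:, 0], edges[:, 1])), shape=(5, 5)).tocsr()
dist = normalized_hop_distance(adj, root=0)  # [0, 1/3, 2/3, 1, nan]
```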

Lastly, we can take the matrix of edge counts between groups from the BlockState at level 2, normalize each row to obtain transition probabilities, and visualize it as a coarse-grained matrix.

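# edge counts between level-2 groups as a dense array, row-normalized into transition probabilities
M = state.get_levels()[2].get_matrix().toarray()
sns.clustermap(M / np.sum(M, 1)[:, None], cmap='viridis', fmt=".2f",
               annot=True, figsize=(6, 6), row_cluster=False, col_cluster=False)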
[Figure: row-normalized block matrix at level 2]
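The matrix returned by `get_matrix()` holds edge counts between groups. By graph-tool’s convention, the diagonal counts within-group edges twice for undirected graphs. Dividing each row by its sum gives the probability that a random walker in a group steps into each group. The sketch below rebuilds such a matrix from a toy adjacency matrix and a label vector as Z^T A Z, where Z is the one-hot membership matrix. It is illustrative and does not use graph-tool; the function name is hypothetical.

```python
import numpy as np


def block_transition_matrix(A, labels):
    """Edge counts between groups (Z^T A Z) and their row-normalized version."""
    A = np.asarray(A, dtype=float)
    groups, idx = np.unique(labels, return_inverse=True)
    Z = np.eye(len(groups))[idx]   # one-hot membership, cells x groups
    counts = Z.T @ A @ Z           # diagonal = 2 x within-group edges
    return counts, counts / counts.sum(axis=1, keepdims=True)


# two triangles joined by a single edge
A = np.zeros((6, 6))
for i, j in [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5), (2, 3)]:
    A[i, j] = A[j, i] = 1
counts, T = block_transition_matrix(A, [0, 0, 0, 1, 1, 1])
# counts == [[6, 1], [1, 6]], T == [[6/7, 1/7], [1/7, 6/7]]
```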