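Tutorial 5: Visualize results

This tutorial loads a trained NicheTrans model and predicts tau and plaque labels for the cells of the STARmap PLUS test set. It then plots the predicted positive cells next to the ground-truth positive cells on the tissue coordinates.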

[1]:
import os, warnings, torch

import torch.nn as nn
import scanpy as sc
import pandas as pd

from model.nicheTrans_img import *
from datasets.data_manager_STARmap_PLUS import AD_Mouse

from utils.utils import *
from utils.utils_dataloader import *
import matplotlib.pyplot as plt

warnings.filterwarnings("ignore")


from palettable.cartocolors.diverging import *
from palettable.scientific.diverging import *

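Load dataset

Read the AnnData file whose spatial coordinates (`obsm['spatial']`) are used for the plots at the end of the tutorial.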

[2]:
# Directory holding the preprocessed STARmap PLUS AnnData files — adjust to your local copy.
DATA_DIR = '/home/wzk/ST_data/AD_mouse2/norm/AD_mouses_adata'
# Single source of truth for the file name; `path` is derived from it so the two cannot drift.
target_file_name = '13months-disease-replicate_2_random.h5ad'
path = os.path.join(DATA_DIR, target_file_name)
rna_adata = sc.read_h5ad(path)

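Load arguments

Run the STARmap PLUS argument script, which defines `args` (data paths and model hyper-parameters used below).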

[3]:
%run ./args/args_STARmap_PLUS.py
args = args

Create dataloaders

Build the dataset and its train/test dataloaders; only the test loader is used in this tutorial.

[4]:
# Build the AD_Mouse dataset from the file paths stored in `args`, then split it into
# train/test dataloaders.
dataset = AD_Mouse(
    AD_adata_path=args.AD_adata_path,
    Wild_type_adata_path=args.Wild_type_adata_path,
    label_path=args.label_path,
    n_top_genes=args.n_top_genes,
)
trainloader, testloader = ad_mouse_dataloader(args, dataset)
------Calculating spatial graph...
The graph contains 124464 edges, 10372 cells.
12.0000 neighbors per cell on average.
------Calculating spatial graph...
The graph contains 115608 edges, 9634 cells.
12.0000 neighbors per cell on average.
------Calculating spatial graph...
The graph contains 96408 edges, 8034 cells.
12.0000 neighbors per cell on average.
=> AD Mouse loaded
Dataset statistics:
  ------------------------------
  subset   | # num |
  ------------------------------
  train    |  10372 spots, 894.0 positive tao, 291.0 positive plaque
  test     |   9634 spots, 620.0 positive tao, 195.0 positive plaque
  ------------------------------

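Model initialization

Build NicheTrans with the dataset's input/output sizes and load the trained weights from `NicheTrans_STARmap_PLUS.pth`.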

[5]:
# Build NicheTrans with the input/output sizes reported by the dataset and the
# regularisation settings from `args`.
source_dimension, target_dimension = dataset.rna_length, dataset.target_length
model = NicheTrans(source_length=source_dimension, target_length=target_dimension,
                   noise_rate=args.noise_rate, dropout_rate=args.dropout_rate)
# Wrap in DataParallel *before* loading the weights: the checkpoint's keys presumably carry
# the `module.` prefix that DataParallel adds (strict loading succeeds this way) — TODO confirm.
model = nn.DataParallel(model).cuda()

# map_location='cpu' lets the checkpoint load even if it was saved on a GPU index that is not
# available here; load_state_dict then copies the weights into the model's CUDA parameters.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source
# (recent PyTorch versions also support weights_only=True).
state_dict = torch.load('NicheTrans_STARmap_PLUS.pth', map_location='cpu')
model.load_state_dict(state_dict)

# Inference mode: disables dropout and makes BatchNorm use its running statistics.
# The trailing `;` suppresses the long module repr in the cell output.
model.eval();
[5]:
DataParallel(
  (module): NicheTrans(
    (encoder): NetBlock(
      (noise_dropout): Dropout(p=0.5, inplace=False)
      (linear_list): ModuleList(
        (0): Linear(in_features=1719, out_features=512, bias=True)
        (1): Linear(in_features=512, out_features=256, bias=True)
      )
      (bn_list): ModuleList(
        (0): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (activation_list): ModuleList(
        (0-1): 2 x LeakyReLU(negative_slope=0.01)
      )
      (dropout_list): ModuleList(
        (0): Dropout(p=0.25, inplace=False)
      )
    )
    (fusion_omic): Self_Attention(
      (to_q): Sequential(
        (0): Linear(in_features=256, out_features=256, bias=False)
      )
      (to_k): Sequential(
        (0): Linear(in_features=256, out_features=256, bias=False)
      )
      (to_v): Linear(in_features=256, out_features=256, bias=False)
      (to_out): Linear(in_features=256, out_features=256, bias=True)
      (dropout): Dropout(p=0.25, inplace=False)
    )
    (ffn_omic): FeedForward(
      (net): Sequential(
        (0): Linear(in_features=256, out_features=1024, bias=True)
        (1): GEGLU()
        (2): Linear(in_features=512, out_features=256, bias=True)
        (3): Dropout(p=0.0, inplace=False)
      )
    )
    (ln1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (ln2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (predict_layers): ModuleList(
      (0-1): 2 x Sequential(
        (0): Linear(in_features=256, out_features=128, bias=True)
        (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): LeakyReLU(negative_slope=0.01)
        (3): Linear(in_features=128, out_features=1, bias=True)
      )
    )
    (non_linear): Sequential(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (dropout): Dropout(p=0.25, inplace=False)
    (dropout_5): Dropout(p=0.5, inplace=False)
  )
)

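Model inference

Predict on the test set, then binarise predictions and ground-truth labels at 0.5 and store them on `rna_adata.obs`.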

[6]:
# Run the trained model over the test split and collect predictions alongside targets.
pd_value, gt_value = [], []

# Re-assert inference mode so this cell is safe to re-run on its own.
model.eval()
with torch.no_grad():
    # Each batch is (rna, protein, cell, rna_neighbors, cell_neighbor, _). Only the RNA
    # features of a cell and of its neighbours are model inputs; `protein` is the target.
    for rna, protein, _cell, rna_neighbors, _cell_neighbor, _ in testloader:
        source, source_neighbors = rna.cuda(), rna_neighbors.cuda()
        outputs = model(source, source_neighbors)

        # Move results off the GPU per batch so GPU memory does not grow with the test set;
        # the target never needs to visit the GPU at all.
        pd_value.append(outputs.cpu())
        gt_value.append(protein.cpu())

pd_value = torch.cat(pd_value, dim=0).numpy()
gt_value = torch.cat(gt_value, dim=0).numpy()
[7]:
# Binarise predicted scores and ground-truth labels and store them as 0/1 columns on
# `rna_adata.obs`. Column i is named after dataset.target_panel[i] (ground truth) and
# 'pd_' + dataset.target_panel[i] (prediction) — the names the plotting cells look up.
# NOTE(review): the printed model has no Sigmoid after the last Linear of predict_layers;
# if forward() does not apply one, pd_value holds logits and 0.5 is not a 50% cut-off — confirm.
THRESHOLD = 0.5

# The arrays are attached by position, so there must be one row per cell of `rna_adata`.
# NOTE(review): this also assumes `testloader` yields cells in the same order as
# `rna_adata` (no shuffling, same underlying file) — TODO confirm in ad_mouse_dataloader.
assert pd_value.shape[0] == gt_value.shape[0] == rna_adata.n_obs, (
    f"{pd_value.shape[0]} predictions / {gt_value.shape[0]} labels "
    f"vs {rna_adata.n_obs} cells in rna_adata"
)

for idx, name in enumerate(dataset.target_panel):
    rna_adata.obs['pd_' + name] = (pd_value[:, idx] > THRESHOLD).astype(int)
    rna_adata.obs[name] = (gt_value[:, idx] > THRESHOLD).astype(int)

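Model evaluation

Visual comparison: the first plot marks the cells predicted positive by NicheTrans, and the second marks the ground-truth positive cells. In both plots, green marks the first target in `dataset.target_panel` and black marks the second.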

[8]:
# obs column names written by the thresholding step: 'pd_<target>' holds predictions,
# '<target>' holds ground truth. Index 0 is drawn in green, index 1 in black.
key_pd0, key_gt0 = 'pd_' + dataset.target_panel[0], dataset.target_panel[0]
key_pd1, key_gt1 = 'pd_' + dataset.target_panel[1], dataset.target_panel[1]


def plot_spatial_overlay(adata, overlay_keys, title, colors=('green', 'black'), sizes=(0.1, 1)):
    """Plot every cell at its spatial position and highlight cells flagged in ``obs`` columns.

    Parameters
    ----------
    adata : AnnData
        Must provide ``obsm['spatial']`` (x in column 0, y in column 1) and, in ``obs``,
        the 0/1 columns named in ``overlay_keys``.
    overlay_keys : sequence of str
        ``obs`` columns to overlay, in drawing order; cells where the column equals 1 are drawn.
    title : str
        Axes title.
    colors, sizes : sequence
        Marker colour and marker size paired by position with ``overlay_keys``
        (extra entries on either side are ignored, as with ``zip``).

    Returns
    -------
    (Figure, Axes)
    """
    fig, ax = plt.subplots(1, figsize=(4, 4), dpi=200)
    # Background layer: all cells, drawn by scanpy.
    sc.pl.embedding(adata, basis='spatial', ax=ax, show=False, s=5)

    coords = adata.obsm['spatial']
    for key, color, size in zip(overlay_keys, colors, sizes):
        positive = (adata.obs[key] == 1).to_numpy()
        ax.scatter(coords[positive, 0], coords[positive, 1], color=color, s=size, label=key)

    # Set the title on this Axes explicitly rather than via the pyplot state machine.
    ax.set_title(title)
    ax.legend(loc='upper right', fontsize=5, markerscale=6, frameon=False)
    return fig, ax


# Trailing `;` suppresses the (Figure, Axes) repr; the figure itself is still rendered.
plot_spatial_overlay(rna_adata, [key_pd0, key_pd1], 'NicheTrans');

[8]:
Text(0.5, 1.0, 'NicheTrans')
_images/Tutorial_5:_Visualize_results_14_1.png
[9]:
# Plot all cells at their spatial positions and overlay the ground-truth positives:
# the first target (key_gt0) in green, the second (key_gt1) in black.
fig, ax = plt.subplots(1, figsize=(4, 4), dpi=200)
sc.pl.embedding(rna_adata, basis='spatial', ax=ax, show=False, s=5)

coords = rna_adata.obsm['spatial']
for key, color, size in [(key_gt0, 'green', 0.1), (key_gt1, 'black', 1)]:
    positive = (rna_adata.obs[key] == 1).to_numpy()
    ax.scatter(coords[positive, 0], coords[positive, 1], color=color, s=size, label=key)

# Explicit Axes API instead of plt.title; the trailing `;` suppresses the Legend repr.
ax.set_title('Ground Truth')
ax.legend(loc='upper right', fontsize=5, markerscale=6, frameon=False);
[9]:
Text(0.5, 1.0, 'Ground Truth')
_images/Tutorial_5:_Visualize_results_15_1.png