Tutorial 3: Train NicheTrans on SMA data
[1]:
import os, time, datetime, warnings
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from model.nicheTrans_img import *
from datasets.data_manager_SMA import SMA
from utils.utils import *
from utils.utils_training_SMA import train, test
from utils.utils_dataloader import *
warnings.filterwarnings("ignore")
Initialize the arguments, select the GPU, and fix the random seeds
[ ]:
%run ./args/args_SMA.py
args = args
set_seed(args.seed)
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_devices
print("==========\nArgs:{}\n==========".format(args))
==========
Args:Namespace(dropout_rate=0.1, eval_step=1, gamma=0.1, gpu_devices='0', img_size=256, lr=0.0003, max_epoch=40, msi_path='/home/wzk/ST_data/SMA_data/Processed_data_v4', n_source=3000, n_target=50, noise_rate=0.2, optimizer='adam', path_img='/home/wzk/ST_data/SMA_data/Processed/patches', rna_path='/home/wzk/ST_data/SMA_data/Zhikang', save_dir='./log', seed=1, stepsize=20, test_batch=32, train_batch=32, weight_decay=0.0005, workers=4)
==========
Initialize the dataloaders and the NicheTrans model
[3]:
# Build the SMA dataset (RNA as source modality, MSI as target modality)
# and split it into train/test loaders.
dataset = SMA(
    path_img=args.path_img,
    rna_path=args.rna_path,
    msi_path=args.msi_path,
    n_top_genes=args.n_source,
    n_top_targets=args.n_target,
)
trainloader, testloader = sma_dataloader(args, dataset)

# The model's input width is the RNA feature count and its output width is the
# MSI feature count, both taken from the loaded dataset.
source_dimension = dataset.rna_length
target_dimension = dataset.msi_length
model = NicheTrans(
    source_length=source_dimension,
    target_length=target_dimension,
    noise_rate=args.noise_rate,
    dropout_rate=args.dropout_rate,
)

# Wrap for multi-GPU data parallelism and move the parameters onto the GPU.
model = nn.DataParallel(model).cuda()
------Calculating spatial graph...
The graph contains 12134 edges, 3120 cells.
3.8891 neighbors per cell on average.
------Calculating spatial graph...
The graph contains 24190 edges, 3120 cells.
7.7532 neighbors per cell on average.
------Calculating spatial graph...
The graph contains 11322 edges, 2918 cells.
3.8801 neighbors per cell on average.
------Calculating spatial graph...
The graph contains 22578 edges, 2918 cells.
7.7375 neighbors per cell on average.
------Calculating spatial graph...
The graph contains 10360 edges, 2675 cells.
3.8729 neighbors per cell on average.
------Calculating spatial graph...
The graph contains 20628 edges, 2675 cells.
7.7114 neighbors per cell on average.
=> SMA loaded
Dataset statistics:
------------------------------
subset | # num |
------------------------------
train | Without filtering 6038 spots from 2 slides
test | Without filtering 2675 spots from 1 slides
train | After filting 6005 spots from 2 slides
test | After filting 2655 spots from 1 slides
------------------------------
Initialize the loss function (criterion), the optimizer, and the learning-rate scheduler
[4]:
# Mean-squared-error loss, passed to train() in the training loop.
criterion = nn.MSELoss()

# Match the optimizer name case-insensitively, so 'adam'/'Adam' and 'SGD'/'sgd'
# are all accepted. Any other value is a configuration error.
optimizer_name = str(args.optimizer).lower()
if optimizer_name == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
elif optimizer_name == 'sgd':
    # NOTE(review): args.weight_decay is not applied to SGD (unchanged behaviour); confirm intended.
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
else:
    # Fail fast instead of continuing with `optimizer` undefined, or silently
    # reusing one left over from an earlier run in this kernel.
    raise ValueError(
        "unexpected optimizer: {!r} (expected 'adam' or 'SGD')".format(args.optimizer)
    )

# Step decay: the training loop calls scheduler.step() once per epoch, so the LR is
# multiplied by `gamma` every `stepsize` epochs. Disabled (None) when stepsize <= 0.
scheduler = (
    lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)
    if args.stepsize > 0
    else None
)
Model training and testing
[ ]:
# Train for args.max_epoch epochs. Evaluate every args.eval_step epochs and always
# on the final epoch, then save the final weights.
start_time = time.time()
for epoch in range(args.max_epoch):
    epoch_num = epoch + 1
    last_epoch = epoch_num == args.max_epoch
    print("==> Epoch {}/{}".format(epoch_num, args.max_epoch))

    train(args, model, criterion, optimizer, trainloader, dataset.target_panel, use_img=False)
    # A scheduler exists only when step decay is enabled (stepsize > 0).
    if args.stepsize > 0:
        scheduler.step()

    # Always evaluate on the last epoch too. test() is passed last_epoch, so a final
    # evaluation is presumably expected even when eval_step does not divide max_epoch.
    # The eval_step > 0 guard avoids a modulo-by-zero error.
    periodic_eval = args.eval_step > 0 and epoch_num % args.eval_step == 0
    if periodic_eval or last_epoch:
        # Most recent test() result, kept for inspection (presumably Pearson correlation).
        pearson = test(args, model, testloader, dataset.target_panel, last_epoch, use_img=False)

    if last_epoch:
        # Saved from the nn.DataParallel wrapper, so state_dict keys carry a 'module.' prefix.
        torch.save(model.state_dict(), 'NicheTrans_SMA_last.pth')

elapsed = round(time.time() - start_time)
elapsed = str(datetime.timedelta(seconds=elapsed))
print("Finished. Total elapsed time (h:m:s): {}".format(elapsed))