#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Apr 12 16:08:19 2019

@author: sayan
"""
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib
import math
import scipy.io as sio
from sklearn.metrics import accuracy_score
from scipy.io import loadmat
import shutil
import os
from net_train_joint import net_train_joint
sys.path.append("../pretraining/")
from model_random import Gene_ontology_network
from model_joint  import joint_network
from convert_to_gpu import gpu
from convert_to_gpu_and_tensor import gpu_t
from convert_to_gpu_scalar import gpu_ts
from convert_to_cpu import cpu
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from sklearn import metrics
import collections
import torch.nn.functional as F

#################################################################################################################################################
# Helper definitions

def initialize_check_losses(check_losses,s,sps_N, num, nn):
    """Fill ``check_losses`` in place with one preallocated array per key in ``s``.

    Per-sample keys (``'class_pred'`` / ``'class_true'``) get an array of -1
    with shape ``(sps_N, num, nn)``. Every other key gets a zero array of shape
    ``(sps_N, num)``.

    Args:
        check_losses: dict to populate (mutated in place).
        s: iterable of key names.
        sps_N: size of the first axis (number of thresholds).
        num: size of the second axis (number of epochs).
        nn: size of the third axis for the per-sample keys (number of subjects).

    Returns:
        None.
    """
    per_sample_keys = ('class_pred', 'class_true')
    check_losses.update({
        key: (-np.ones((sps_N, num, nn)) if key in per_sample_keys
              else np.zeros((sps_N, num)))
        for key in s
    })

def helper_test(net, x_g_j, x_n_j, x_s_j, y_j, x_g_n, x_n, y_n, x_g_s, x_s, y_s,  prob, D, typ, hyp, sp_N, e, metric, T):
    """Evaluate ``net`` on one split and record its losses and metrics.

    The split is given as three subject groups:
      * ``*_j``: subjects with all three modalities,
      * ``*_n``: subjects with N-back imaging and gene data,
      * ``*_s``: subjects with SDMT imaging and gene data.
    Each group is scored with ``helper`` and the per-group losses are summed.

    Args:
        net: the joint network to evaluate.
        x_g_*, x_n*, x_s*, y_*: gene input, N-back input, SDMT input and
            labels for each group.
        prob: forwarded to ``helper``, which does not use it.
        D: nested dict ``D[typ][name]`` of numpy arrays, filled in place at
            ``[sp_N, e]`` (``[sp_N, e, :]`` for ``class_true``/``class_pred``).
        typ: split key into ``D`` and ``metric`` (e.g. ``'test'``/``'val'``).
        hyp: loss weights forwarded to ``helper``.
        sp_N: row index (threshold index) to write.
        e: column index (epoch) to write.
        metric: nested dict ``metric[typ]['AUC'|'Acc']`` filled in place.
        T: temperature forwarded to ``helper``.

    Note:
        Uses the module-level ``criterion_recon`` and ``criterion_class``.
    """
    net.eval()
    with torch.no_grad():

        # Subjects for whom all modalities are present.
        loss_j, loss1_j, loss2_j, loss3_j, loss4_j, loss4_j_c, loss4_j_d, y_j_hat = helper(
            net, x_g_j.detach(), x_n_j.detach(), x_s_j.detach(), y_j.detach(), prob, 'j',
            criterion_recon, criterion_class, hyp, T)

        # Subjects with N-back (no SDMT): 0 is a placeholder for the missing input.
        loss_n, loss1_n, loss2_n, loss3_n, loss4_n, loss4_n_c, loss4_n_d, y_n_hat = helper(
            net, x_g_n.detach(), x_n.detach(), 0, y_n.detach(), prob, 'N',
            criterion_recon, criterion_class, hyp, T)

        # Subjects with SDMT (no N-back): 0 is a placeholder for the missing input.
        loss_s, loss1_s, loss2_s, loss3_s, loss4_s, loss4_s_c, loss4_s_d, y_s_hat = helper(
            net, x_g_s.detach(), 0, x_s.detach(), y_s.detach(), prob, 'S',
            criterion_recon, criterion_class, hyp, T)

        # Store the summed losses.
        D[typ]['gene_recon_loss'][sp_N, e]    = loss1_j   + loss1_n   + loss1_s
        D[typ]['nback_recon_loss'][sp_N, e]   = loss2_j   + loss2_n   + loss2_s
        D[typ]['sdmt_recon_loss'][sp_N, e]    = loss3_j   + loss3_n   + loss3_s
        D[typ]['class_loss'][sp_N, e]         = loss4_j   + loss4_n   + loss4_s
        D[typ]['control_class_loss'][sp_N, e] = loss4_j_c + loss4_n_c + loss4_s_c
        D[typ]['disease_class_loss'][sp_N, e] = loss4_j_d + loss4_n_d + loss4_s_d
        D[typ]['total_loss'][sp_N, e]         = loss_j    + loss_n    + loss_s

        Y     = torch.cat((y_j, y_n, y_s), dim=0)
        Y_hat = torch.cat((y_j_hat, y_n_hat, y_s_hat), dim=0)

        D[typ]['class_true'][sp_N, e, :] = cpu(Y.squeeze()).data.numpy()
        D[typ]['class_pred'][sp_N, e, :] = cpu(Y_hat.squeeze()).data.numpy()

        # Flatten once; sklearn expects 1-D label/score vectors.
        y_true  = cpu(Y).data.numpy().ravel()
        y_score = cpu(Y_hat).data.numpy().ravel()

        fpr_test, tpr_test, _ = metrics.roc_curve(y_true, y_score, pos_label=1)
        metric[typ]['AUC'][sp_N, e] = metrics.auc(fpr_test, tpr_test)

        # Hard predictions by thresholding the scores at 0.5; sklearn's order is (y_true, y_pred).
        metric[typ]['Acc'][sp_N, e] = accuracy_score(y_true, (y_score > 0.5).astype(int))

    return
        

def helper_plot(M, e, sp_N, f, st, sps):
    """Save one line plot per per-epoch curve stored in ``M``.

    Args:
        M: dict mapping a name to a numpy array indexed ``[threshold, epoch]``.
            The per-sample entries ``'class_pred'`` and ``'class_true'`` are skipped.
        e: current epoch. Values for epochs ``0..e`` (inclusive) are plotted.
        sp_N: threshold row to plot.
        f: fold index, used in the output folder name.
        st: filename prefix (e.g. ``'test'`` / ``'val'``).
        sps: list of threshold values; ``sps[sp_N]`` is appended to the filename.

    Writes ``data_<m_n+1>/model<f>/<st>_<name>_<sps[sp_N]>.pdf`` (uses the
    module-level ``m_n``).
    """
    for name, values in M.items():
        if name in ('class_pred', 'class_true'):
            continue  # per-sample predictions, not a per-epoch curve
        # Slice through e inclusive: epoch e has already been recorded, and a
        # [0:e] slice would omit it (and plot nothing at e == 0).
        plt.plot(values[sp_N, :e + 1].tolist())
        plt.savefig('data_'+str(m_n+1)+'/model'+str(f)+'/'+st+'_'+name+'_'+str(sps[sp_N])+'.pdf')
        plt.close()
    return

def train_helper_plot(M, e, sp_N, f, st):
    """Save one line plot per training-loss history in ``M``.

    Args:
        M: dict mapping a loss name to a per-epoch sequence of values.
        e: current epoch. Entries ``0..e`` (inclusive) are plotted.
        sp_N: unused; kept for call-site compatibility.
        f: fold index, used in the output folder name.
        st: filename prefix (e.g. ``'train'``).

    Writes ``data_<m_n+1>/model<f>/<st>_<name>.pdf`` (uses the module-level ``m_n``).
    """
    for name, history in M.items():
        # Slice through e inclusive so the epoch just appended is plotted
        # (a [0:e] slice would omit it and plot nothing at e == 0).
        plt.plot(history[:e + 1])
        plt.savefig('data_'+str(m_n+1)+'/model'+str(f)+'/'+st+'_'+name+'.pdf')
        plt.close()
      
                        

       

def helper(net, x_g, x_n, x_s, y, prob, mode, criterion_recon, criterion_class, hyp, T):
    """Run ``net`` in eval mode on one subject group and return its losses.

    During evaluation the imaging inputs are multiplied by the learned
    per-feature probabilities (column 1 of ``softmax(net.bias_*[0], dim=1)``)
    before being passed through the network.

    Args:
        net: joint network exposing ``bias_n`` and ``bias_s``.
        x_g: gene input.
        x_n: N-back imaging input (used only when ``mode`` is ``'j'`` or ``'N'``).
        x_s: SDMT imaging input (used only when ``mode`` is ``'j'`` or ``'S'``).
        y: labels. Non-zero entries count as disease, zeros as control.
        prob: unused; kept for call-site compatibility.
        mode: ``'j'`` (all modalities), ``'N'`` (N-back + gene); any other
            value selects network mode 2 (SDMT + gene).
        criterion_recon, criterion_class: loss modules summed here. The
            script constructs them with ``reduction='none'``.
        hyp: loss weights. ``hyp[0..3]`` scale the gene, N-back and SDMT
            reconstruction losses and the class loss.
        T: temperature forwarded to ``net``.

    Returns:
        Tuple of CPU tensors ``(loss, gene_recon, nback_recon, sdmt_recon,
        class, control_class, disease_class, y_hat)``.
    """
    use_nback = mode == 'j' or mode == 'N'
    use_sdmt  = mode == 'j' or mode == 'S'

    net.eval()
    with torch.no_grad():
        # Learned keep-probability for each imaging feature (column 1 of the softmax).
        imp_o_N = gpu(F.softmax(net.bias_n[0], dim=1))[:, 1]
        imp_o_S = gpu(F.softmax(net.bias_s[0], dim=1))[:, 1]

        # Scale present modalities by their probabilities; absent ones become a scalar 0.
        x_n_t = x_n.detach() * torch.unsqueeze(imp_o_N, 0) if use_nback else torch.tensor(0).float()
        x_s_t = x_s.detach() * torch.unsqueeze(imp_o_S, 0) if use_sdmt else torch.tensor(0).float()

        if mode == 'j':
            mode_in = gpu_ts(0)
        elif mode == 'N':
            mode_in = gpu_ts(1)
        else:
            mode_in = gpu_ts(2)

        surrogate_ig, y_hat, _ = net(x_g, x_n_t, x_s_t, T, mode_in)

        # Reconstruction targets are the unscaled inputs.
        loss1 = hyp[0] * torch.sum(criterion_recon(surrogate_ig[0], x_g))
        loss2 = hyp[1] * torch.sum(criterion_recon(surrogate_ig[1], x_n)) if use_nback else torch.tensor(0).float()
        loss3 = hyp[2] * torch.sum(criterion_recon(surrogate_ig[2], x_s)) if use_sdmt else torch.tensor(0).float()

        # Per-sample class loss, computed once and reused for the total and
        # for the control/disease breakdown.
        class_loss = criterion_class(y_hat, y)
        loss4 = hyp[3] * torch.sum(class_loss)

        loss = loss1 + loss2 + loss3 + loss4

        # 1 where the label is non-zero (disease), 0 elsewhere; the control
        # mask is its complement. Built on-device, without a numpy round trip.
        weight_d = gpu((y != 0).float())
        weight_c = 1 - weight_d

        loss4_control = hyp[3] * torch.sum(class_loss * weight_c)
        loss4_disease = hyp[3] * torch.sum(class_loss * weight_d)

    return cpu(loss), cpu(loss1), cpu(loss2), cpu(loss3), cpu(loss4), cpu(loss4_control), cpu(loss4_disease), cpu(y_hat)

def save_results(net, metric_losses, check_losses, train_losses, f, prob_d):
    """Dump loss/metric histories and feature probabilities to ``vis.mat``.

    Keys are the split name prefixed to the loss/metric name, e.g.
    ``'traintotal_loss'``, ``'testAUC'`` and ``'valAcc'``.

    Args:
        net: unused; kept for call-site compatibility.
        metric_losses: dict with ``'train'``, ``'test'`` and ``'val'`` metric dicts.
        check_losses: dict with ``'test'`` and ``'val'`` loss dicts.
        train_losses: dict of per-epoch training-loss histories.
        f: fold index, used in the output path.
        prob_d: dict with ``'Nback'`` and ``'SDMT'`` probability histories.

    Writes ``data_<m_n+1>/setting/model<f>/vis.mat`` (uses the module-level ``m_n``).
    """
    vis = {}

    for name, history in train_losses.items():
        vis['train' + name] = history

    for name in check_losses['test']:
        vis['test' + name] = check_losses['test'][name]
        vis['val' + name]  = check_losses['val'][name]

    for name in metric_losses['test']:
        # Training metrics come from metric_losses['train'] (previously the
        # test metrics were stored under the 'train' key by mistake).
        vis['train' + name] = metric_losses['train'][name]
        vis['test' + name]  = metric_losses['test'][name]
        vis['val' + name]   = metric_losses['val'][name]

    vis['prob_Nback'] = prob_d['Nback']
    vis['prob_SDMT']  = prob_d['SDMT']

    sio.savemat('data_'+str(m_n+1)+'/setting/model'+str(f)+'/vis'+'.mat', vis)
    
    
    

##############################################################################################################################################3

# Command-line input: sys.argv[1] indexes the list of incomplete runs stored in
# incomplete_runs.mat. Each run code packs two indices as 10*m_n + id_fold:
#   m_n     - data distribution (the data were randomly redistributed 10 times),
#   id_fold - fold number within that distribution's 10-fold CV.
idd = loadmat('incomplete_runs.mat')['arr'][0]
run_code = idd[int(sys.argv[1])]
id_fold = math.floor(run_code % 10)
m_n = run_code // 10
plt.ioff()

##################################################################################################################################################################################

# Loading data.
#
# Every split/modality array for this (fold, data distribution) pair lives in
# one .mat file. Read it once and pull the arrays out of the resulting dict,
# rather than re-reading the file from disk for every variable.
fold_data = loadmat('../../data/multiple_runs/train_test_val_imaging_data_folds_with_5_graph_layers'+str(id_fold+1)+'_3_'+str(m_n+1)+'.mat')

# training data

# Subjects who have all 3 data modalities (Nback, SDMT and SNP).
I_train_Nback_j = gpu_t(fold_data['I_train_Nback_j'])
I_train_SDMT_j  = gpu_t(fold_data['I_train_SDMT_j'])
G_train_j       = gpu_t(fold_data['G_train_j'])
Y_train_j       = gpu_t(fold_data['Y_train_j'])

# Subjects who have Nback data and SNP data.
I_train_Nback   = gpu_t(fold_data['I_train_Nback'])
G_train_Nback   = gpu_t(fold_data['G_train_Nback'])
Y_train_Nback   = gpu_t(fold_data['Y_train_Nback'])

# Subjects who have SDMT data and SNP data.
I_train_SDMT    = gpu_t(fold_data['I_train_SDMT'])
G_train_SDMT    = gpu_t(fold_data['G_train_SDMT'])
Y_train_SDMT    = gpu_t(fold_data['Y_train_SDMT'])


# Validation data (same three subject groups).

I_val_Nback_j = gpu_t(fold_data['I_val_Nback_j'])
I_val_SDMT_j  = gpu_t(fold_data['I_val_SDMT_j'])
G_val_j       = gpu_t(fold_data['G_val_j'])
Y_val_j       = gpu_t(fold_data['Y_val_j'])

I_val_Nback   = gpu_t(fold_data['I_val_Nback'])
G_val_Nback   = gpu_t(fold_data['G_val_Nback'])
Y_val_Nback   = gpu_t(fold_data['Y_val_Nback'])

I_val_SDMT    = gpu_t(fold_data['I_val_SDMT'])
G_val_SDMT    = gpu_t(fold_data['G_val_SDMT'])
Y_val_SDMT    = gpu_t(fold_data['Y_val_SDMT'])


# Test data (same three subject groups).
I_test_Nback_j = gpu_t(fold_data['I_test_Nback_j'])
I_test_SDMT_j  = gpu_t(fold_data['I_test_SDMT_j'])
G_test_j       = gpu_t(fold_data['G_test_j'])
Y_test_j       = gpu_t(fold_data['Y_test_j'])

I_test_Nback   = gpu_t(fold_data['I_test_Nback'])
G_test_Nback   = gpu_t(fold_data['G_test_Nback'])
Y_test_Nback   = gpu_t(fold_data['Y_test_Nback'])

I_test_SDMT    = gpu_t(fold_data['I_test_SDMT'])
G_test_SDMT    = gpu_t(fold_data['G_test_SDMT'])
Y_test_SDMT    = gpu_t(fold_data['Y_test_SDMT'])

# Drop the dict's reference to the raw arrays now that they have been converted.
del fold_data

# Total number of test / validation subjects across the three groups.
test_n = Y_test_j.size()[0] + Y_test_Nback.size()[0] + Y_test_SDMT.size()[0]
val_n  = Y_val_j.size()[0]  + Y_val_Nback.size()[0]  + Y_val_SDMT.size()[0]

# Gene-ontology adjacency matrix A_p, transposed before conversion so that
# A[i, j] = 1 encodes a connection j --> i. Stored as a coalesced sparse tensor.
A = gpu(torch.tensor(loadmat('../../data/ontology_adjacency.mat')['A']).float().t().to_sparse().coalesce())

# Binary gene-score-to-pathway mask A_g, also as a coalesced sparse tensor.
A_g = gpu(torch.tensor(loadmat('../../data/gene_to_pathways.mat')['mask_gene']).float().to_sparse().coalesce())

# Number of nodes discarded after each GCN layer. For example, pool_dim[0] = 1000
# means that after the 1st GCN layer, graph nodes 0..999 are discarded.
pool_dim = loadmat('../../data/pool_layers.mat')['pool_n'].tolist()

############################################################################################################################################################################
e_losses = []  # NOTE(review): never read in this script

num_epochs = 1000 # maximum number of epochs

# Lowest validation total loss seen so far; a checkpoint is written whenever it
# improves. Starts at +inf so any finite loss counts as an improvement.
temp = float('inf')

num_s = 0  # NOTE(review): never read in this script
num_f = 5  # plot every num_f epochs

iter_n = 25  # forwarded to net_train_joint

# Thresholding probabilities (one evaluation pass per entry).
sps_N = [0]

# Batch size
batch_size = 32

# Cost weights (lambda): gene recon, nback recon, sdmt recon, class loss, sparsity
lambda_0 = [gpu_ts(0.00001), gpu_ts(0.003), gpu_ts(0.003), gpu_ts(1), gpu_ts(1)]

# Sparsity reference probabilities (Nback, SDMT)
prob_ref = [gpu_ts(0.001), gpu_ts(0.001)]

# Gumbel temperature parameter
temperature = gpu_ts(0.1)

# Latent dimension
l_dim = 50

# Warm-start iteration number, forwarded to net_train_joint
iter_ths = 0

# Loss functions (element-wise; callers sum/weight the results)
criterion_class = nn.BCELoss(reduction='none')
criterion_recon = nn.MSELoss(reduction='none')

# Output folders for this data distribution: data_<m_n+1>/ holds per-fold model
# folders, and data_<m_n+1>/setting/ holds the saved results. exist_ok avoids the
# check-then-create race when jobs for different folds of the same distribution
# (selected via sys.argv) run at the same time.
os.makedirs('data_'+str(m_n+1), exist_ok=True)
os.makedirs('data_'+str(m_n+1)+'/setting', exist_ok=True)
    
  
import traceback  # used to report evaluation/plotting errors instead of silently swallowing them

# Train and evaluate fold id_fold of data distribution m_n.
for f in range(id_fold, id_fold+1):
    # Start from empty output folders for this fold: model<f>/ receives
    # checkpoints and plots, setting/model<f>/ receives vis.mat.
    if os.path.exists('data_'+str(m_n+1)+'/model'+str(f)):
        shutil.rmtree('data_'+str(m_n+1)+'/model'+str(f))

    print('fold - '+str(f))
    os.mkdir('data_'+str(m_n+1)+'/model'+str(f))
    if os.path.exists('data_'+str(m_n+1)+'/setting/model'+str(f)):
        shutil.rmtree('data_'+str(m_n+1)+'/setting/model'+str(f))

    os.mkdir('data_'+str(m_n+1)+'/setting/model'+str(f))

    # Load the pretrained gene-ontology module (../../pretrained/models/model253.pth).
    gene_ont = Gene_ontology_network(A_g, A, 2, 4, [5,5,5,5], pool_dim, l_dim)
    gene_ont.load_state_dict(torch.load('../../pretrained/models/model'+str(253)+'.pth'))

    # Couple the pretrained gene module into the joint network.
    net = gpu(joint_network(gpu(gene_ont), 246, 50, 246, 50, l_dim))

    # Optimizer. StepLR multiplies the learning rate by gamma=0.7 every 50 epochs.
    opt_j       = optim.Adam(net.parameters(), lr=0.0002, betas=(0.9, 0.999), weight_decay=0)
    scheduler_j = torch.optim.lr_scheduler.StepLR(opt_j, step_size=50, gamma = 0.7)

    # Per-epoch training losses (one list entry per epoch).
    train_losses = {'total_loss':[], 'gene_recon_loss':[], 'nback_recon_loss':[], 'sdmt_recon_loss':[], 'class_loss':[], 'disease_class_loss':[], 'control_class_loss':[],'sparsity_loss':[]}

    # Validation/test losses as arrays indexed [threshold, epoch]
    # ([threshold, epoch, subject] for class_pred / class_true).
    check_losses = collections.defaultdict(dict)
    l = ['total_loss', 'gene_recon_loss', 'nback_recon_loss', 'sdmt_recon_loss', 'class_loss', 'control_class_loss', 'disease_class_loss', 'class_pred', 'class_true']
    check_losses['val']   = {}
    check_losses['test']  = {}
    initialize_check_losses(check_losses['val'] , l, len(sps_N), num_epochs, val_n)
    initialize_check_losses(check_losses['test'], l, len(sps_N), num_epochs, test_n)

    # Classification performance metrics, indexed [threshold, epoch].
    metric_losses = collections.defaultdict(dict)
    l = ['Acc', 'AUC']
    metric_losses['train']   = {}
    metric_losses['test']    = {}
    metric_losses['val']     = {}
    initialize_check_losses(metric_losses['train'], l, len(sps_N), num_epochs, test_n)
    initialize_check_losses(metric_losses['val']  , l, len(sps_N), num_epochs, val_n)
    initialize_check_losses(metric_losses['test'] , l, len(sps_N), num_epochs, test_n)

    # History of net.prob[0] / net.prob[1], one row per epoch.
    prob_data = {}
    prob_data['Nback'] = np.zeros((num_epochs,246))
    prob_data['SDMT']  = np.zeros((num_epochs,246))

    for e in range(num_epochs):
        print(e)

        # Train over subjects who have all 3 data modalities.
        (l_j, l_g_j, l_i_j_N, l_i_j_S, l_c_j, l_s_j,
         l_j_control, l_j_disease, y_j_hat, y_j) = net_train_joint(
            net, G_train_j, I_train_Nback_j, I_train_SDMT_j, Y_train_j,
            opt_j, batch_size, temperature, lambda_0, prob_ref, criterion_class,
            criterion_recon, e, iter_n, 'j', iter_ths)

        # Train over subjects who have Nback and SNP data (SDMT input is a placeholder 0).
        (l_N_O, l_g_N_O, l_i_N_O, NA, l_c_N_O, l_s_N_O,
         l_N_control, l_N_disease, y_N_hat, y_N) = net_train_joint(
            net, G_train_Nback, I_train_Nback, gpu_ts(0), Y_train_Nback,
            opt_j, batch_size, temperature, lambda_0, prob_ref, criterion_class,
            criterion_recon, e, iter_n, 'N', iter_ths)

        # Train over subjects who have SDMT and SNP data (Nback input is a placeholder 0).
        (l_S_O, l_g_S_O, NA, l_i_S_O, l_c_S_O, l_s_S_O,
         l_S_control, l_S_disease, y_S_hat, y_S) = net_train_joint(
            net, G_train_SDMT, gpu_ts(0), I_train_SDMT, Y_train_SDMT,
            opt_j, batch_size, temperature, lambda_0, prob_ref, criterion_class,
            criterion_recon, e, iter_n, 'S', iter_ths)

        scheduler_j.step()

        # Store train losses.
        train_losses['total_loss'].append(l_j + l_N_O + l_S_O)
        train_losses['nback_recon_loss'].append(l_i_j_N + l_i_N_O)
        train_losses['sdmt_recon_loss'].append(l_i_j_S + l_i_S_O)
        train_losses['gene_recon_loss'].append(l_g_j + l_g_N_O + l_g_S_O)
        train_losses['class_loss'].append(l_c_j + l_c_N_O + l_c_S_O)
        train_losses['disease_class_loss'].append(l_j_disease + l_N_disease + l_S_disease)
        train_losses['control_class_loss'].append(l_j_control + l_N_control + l_S_control)
        train_losses['sparsity_loss'].append(l_s_j+l_s_N_O+l_s_S_O)

        # Training performance.
        Y_tr     = torch.cat((y_j,y_N,y_S),dim=0)
        Y_hat_tr = torch.cat((y_j_hat,y_N_hat,y_S_hat),dim=0)

        y_tr_true  = cpu(Y_tr).data.numpy().ravel()
        y_tr_score = cpu(Y_hat_tr).data.numpy().ravel()

        fpr_test, tpr_test, thresholds_test = metrics.roc_curve(y_tr_true, y_tr_score, pos_label=1)
        metric_losses['train']['AUC'][0,e]  = metrics.auc(fpr_test, tpr_test)

        # Threshold scores at 0.5; sklearn's argument order is (y_true, y_pred).
        metric_losses['train']['Acc'][0,e] = accuracy_score(y_tr_true, (y_tr_score > 0.5).astype(int))

        # Testing and validation.
        try:
            for sp_N in range(len(sps_N)):
                # NOTE(review): net is left in eval mode here; confirm that
                # net_train_joint switches it back to train mode.
                net.eval()
                with torch.no_grad():
                    helper_test(net, G_test_j, I_test_Nback_j, I_test_SDMT_j, Y_test_j,
                                G_test_Nback, I_test_Nback, Y_test_Nback,
                                G_test_SDMT, I_test_SDMT, Y_test_SDMT,
                                sps_N[sp_N], check_losses, 'test', lambda_0, sp_N, e, metric_losses, temperature)

                    helper_test(net, G_val_j, I_val_Nback_j, I_val_SDMT_j, Y_val_j,
                                G_val_Nback, I_val_Nback, Y_val_Nback,
                                G_val_SDMT, I_val_SDMT, Y_val_SDMT,
                                sps_N[sp_N], check_losses, 'val', lambda_0, sp_N, e, metric_losses, temperature)

                    # Models are saved only after a 50-epoch burn-in period, and
                    # then only when the validation loss decreases.
                    if e>=50 and check_losses['val']['total_loss'][sp_N,e] < temp:
                        temp = check_losses['val']['total_loss'][sp_N,e]
                        torch.save(net.state_dict(), 'data_'+str(m_n+1)+'/model'+str(f)+'/model'+str(e)+'.pth')

                    prob_data['Nback'][e,:] = cpu(net.prob[0]).data.numpy()
                    prob_data['SDMT'][e,:]  = cpu(net.prob[1]).data.numpy()

                # Save results every 5 epochs.
                if e%5==0:
                    save_results(net, metric_losses, check_losses, train_losses, f, prob_data)

                # Plot results every num_f epochs.
                if e%num_f == 0:

                    helper_plot( metric_losses['test'], e, sp_N, f, 'test',sps_N)
                    helper_plot( check_losses['test'],  e, sp_N, f, 'test',sps_N)

                    helper_plot( metric_losses['val'], e, sp_N, f, 'val',sps_N)
                    helper_plot( check_losses['val'],  e, sp_N, f, 'val',sps_N)

                    train_helper_plot( train_losses,   e, sp_N, f, 'train')

                    gen_pred = cpu(net.prob[0]).data.numpy()
                    plt.stem(list(range(np.shape(gen_pred)[0])),gen_pred)
                    plt.savefig('data_'+str(m_n+1)+'/model'+str(f)+'/prob_pred_Nback.pdf')
                    plt.close()

                    gen_pred = cpu(net.prob[1]).data.numpy()
                    plt.stem(list(range(np.shape(gen_pred)[0])),gen_pred)
                    plt.savefig('data_'+str(m_n+1)+'/model'+str(f)+'/prob_pred_SDMT.pdf')
                    plt.close()
        except Exception:
            # Evaluation, checkpointing and plotting are best-effort: keep
            # training, but report the failure instead of hiding it. (A bare
            # `except:` also swallowed KeyboardInterrupt/SystemExit.)
            traceback.print_exc()
            continue
                
            
            
        
