#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Apr 12 16:08:19 2019

@author: sayan
"""
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib
import math
import scipy.io as sio
from sklearn.metrics import accuracy_score
from scipy.io import loadmat
import shutil
import os
from model_random import Gene_ontology_network
from model_joint  import joint_network
from convert_to_gpu import gpu
from convert_to_gpu_and_tensor import gpu_t
from convert_to_gpu_scalar import gpu_ts
from convert_to_cpu import cpu
#matplotlib.use('Agg')
import matplotlib.pyplot as plt
from sklearn import metrics
import collections
import torch.nn.functional as F
import shap

#################################################################################################################################################
# Necessary definitions (helpers used by the evaluation script below)

def initialize_check_losses(check_losses, s, sps_N, num, nn):
    """Populate ``check_losses`` in place with one pre-allocated array per name.

    Parameters
    ----------
    check_losses : dict
        Mapping that receives the arrays (mutated in place).
    s : iterable of str
        Names to allocate.
    sps_N, num : int
        Leading dimensions shared by every array.
    nn : int
        Trailing dimension, used only for the per-sample entries
        ``'class_pred'`` and ``'class_true'``.

    Returns
    -------
    None
    """
    per_sample_names = ('class_pred', 'class_true')
    for name in s:
        if name in per_sample_names:
            # Per-sample buffers start at -1 (presumably a "not yet filled" marker).
            check_losses[name] = np.full((sps_N, num, nn), -1.0)
        else:
            # Scalar-per-(setting, epoch) buffers start at 0.
            check_losses[name] = np.zeros((sps_N, num))

def helper_test(net, x_g_j, x_n_j, x_s_j, y_j, x_g_n, x_n, y_n, x_g_s, x_s, y_s, prob, D, typ, hyp, sp_N, e, metric, T, K, shap_n, shap_s):
    """Evaluate ``net`` on the three modality subsets and record losses/metrics.

    The subsets are: subjects with all modalities (``x_g_j``, ``x_n_j``,
    ``x_s_j``, ``y_j``), subjects with N-back imaging only (``x_g_n``, ``x_n``,
    ``y_n``) and subjects with SDMT imaging only (``x_g_s``, ``x_s``, ``y_s``).
    Each subset is scored with ``helper``; ``prob``, ``hyp``, ``T``, ``K``,
    ``shap_n`` and ``shap_s`` are forwarded to it unchanged.

    Results are written in place (nothing is returned):
      * ``D[typ][<loss name>][sp_N, e]``: losses summed over the three subsets;
      * ``D[typ]['class_true' | 'class_pred'][sp_N, e, :]``: concatenated
        labels and predicted scores;
      * ``metric[typ]['AUC'][sp_N, e]``: ROC-AUC of the scores;
      * ``metric[typ]['Acc'][sp_N, e]``: accuracy after thresholding at 0.5.

    Relies on the module-level ``criterion_recon`` / ``criterion_class`` and
    on ``helper`` / ``cpu``.
    """
    net.eval()
    with torch.no_grad():
        # All modalities present.
        loss_j, loss1_j, loss2_j, loss3_j, loss4_j, loss4_j_c, loss4_j_d, y_j_hat = helper(
            net, x_g_j.detach(), x_n_j.detach(), x_s_j.detach(), y_j.detach(), prob, 'j',
            criterion_recon, criterion_class, hyp, T, K, shap_n, shap_s)

        # Only N-back imaging present (SDMT input replaced by 0).
        loss_n, loss1_n, loss2_n, loss3_n, loss4_n, loss4_n_c, loss4_n_d, y_n_hat = helper(
            net, x_g_n.detach(), x_n.detach(), 0, y_n.detach(), prob, 'N',
            criterion_recon, criterion_class, hyp, T, K, shap_n, shap_s)

        # Only SDMT imaging present (N-back input replaced by 0).
        loss_s, loss1_s, loss2_s, loss3_s, loss4_s, loss4_s_c, loss4_s_d, y_s_hat = helper(
            net, x_g_s.detach(), 0, x_s.detach(), y_s.detach(), prob, 'S',
            criterion_recon, criterion_class, hyp, T, K, shap_n, shap_s)

        D[typ]['gene_recon_loss'][sp_N, e]    = loss1_j   + loss1_n   + loss1_s
        D[typ]['nback_recon_loss'][sp_N, e]   = loss2_j   + loss2_n   + loss2_s
        D[typ]['sdmt_recon_loss'][sp_N, e]    = loss3_j   + loss3_n   + loss3_s
        D[typ]['class_loss'][sp_N, e]         = loss4_j   + loss4_n   + loss4_s
        D[typ]['control_class_loss'][sp_N, e] = loss4_j_c + loss4_n_c + loss4_s_c
        D[typ]['disease_class_loss'][sp_N, e] = loss4_j_d + loss4_n_d + loss4_s_d
        D[typ]['total_loss'][sp_N, e]         = loss_j    + loss_n    + loss_s

        Y = torch.cat((y_j, y_n, y_s), dim=0)
        Y_hat = torch.cat((y_j_hat, y_n_hat, y_s_hat), dim=0)

        D[typ]['class_true'][sp_N, e, :] = cpu(Y.squeeze()).data.numpy()
        D[typ]['class_pred'][sp_N, e, :] = cpu(Y_hat.squeeze()).data.numpy()

        # Flatten to 1-D so sklearn receives plain label / score vectors.
        y_true = cpu(Y).data.numpy().reshape(-1)
        y_score = cpu(Y_hat).data.numpy().reshape(-1)

        fpr_test, tpr_test, _ = metrics.roc_curve(y_true, y_score, pos_label=1)
        metric[typ]['AUC'][sp_N, e] = metrics.auc(fpr_test, tpr_test)

        # Hard predictions at a 0.5 threshold; sklearn expects (y_true, y_pred).
        y_pred = (y_score > 0.5).astype(int)
        metric[typ]['Acc'][sp_N, e] = accuracy_score(y_true, y_pred)

    return
        

def helper_plot(M, e, sp_N, f, st, sps):
    """Save one PDF line plot per curve stored in ``M``.

    Every key except the per-sample entries ``'class_pred'`` and
    ``'class_true'`` is plotted as ``M[key][sp_N, 0:e]`` and written to
    ``data_<m_n+1>/model<f>/<st>_<key>_<sps[sp_N]>.pdf``. Reads the
    module-level ``m_n``.
    """
    skipped = {'class_pred', 'class_true'}
    for key in M:
        if key in skipped:
            continue
        curve = M[key][sp_N, 0:e].squeeze().tolist()
        plt.plot(curve)
        target = 'data_' + str(m_n + 1) + '/model' + str(f) + '/' + st + '_' + key + '_' + str(sps[sp_N]) + '.pdf'
        plt.savefig(target)
        plt.close()
    return

def train_helper_plot(M, e, sp_N, f, st):
    """Save one PDF line plot for each training curve in ``M``.

    Plots ``M[key][0:e]`` for every key and writes it to
    ``data_<m_n+1>/model<f>/<st>_<key>.pdf``. ``sp_N`` is accepted but not
    used. Reads the module-level ``m_n``.
    """
    for name in M:
        series = M[name]
        plt.plot(series[0:e])
        out_path = 'data_' + str(m_n + 1) + '/model' + str(f) + '/' + st + '_' + name + '.pdf'
        plt.savefig(out_path)
        plt.close()
      
                        

       

def _top_k_mask(importance, k):
    """Return a float32 0/1 mask with ones at the ``k`` largest ``importance`` entries.

    ``importance`` is assumed to be a 1-D vector of per-feature scores; the
    mask has the same number of entries.
    """
    scores = torch.tensor(importance)
    mask = torch.zeros(scores.numel())
    mask[torch.topk(scores, k).indices] = 1
    return mask


def helper(net, x_g, x_n, x_s, y, prob, mode, criterion_recon, criterion_class, hyp, T, K, shap_n, shap_s):
    """Run ``net`` on one modality subset and return its weighted losses.

    The imaging inputs are restricted to their top-``K`` features (ranked by
    ``shap_n`` for N-back and ``shap_s`` for SDMT) before being fed to the
    network. The reconstruction losses are still measured against the
    unmasked inputs.

    Parameters
    ----------
    net : callable returning ``(surrogate_ig, y_hat, _)``.
    x_g : genetic input.
    x_n, x_s : N-back / SDMT imaging inputs. They are ignored (and may be ``0``)
        when the modality is absent for ``mode``.
    y : labels (presumably 0/1 for a BCE-type ``criterion_class``). Nonzero
        entries count toward the "disease" loss, zeros toward "control".
    prob : unused; kept for call compatibility.
    mode : ``'j'`` for both imaging modalities, ``'N'`` for N-back only; any
        other value (e.g. ``'S'``) means SDMT only.
    criterion_recon, criterion_class : element-wise (unreduced) loss modules.
    hyp : weights; ``hyp[0..3]`` scale the gene / N-back / SDMT reconstruction
        and classification losses.
    T : forwarded to ``net`` as its fourth argument.
    K : number of top features kept per imaging modality.
    shap_n, shap_s : per-feature importance scores (assumed 1-D).

    Returns
    -------
    tuple of 8 tensors, each passed through ``cpu``: total loss, gene /
    N-back / SDMT reconstruction losses, classification loss, its control and
    disease parts, and the predictions ``y_hat``.
    """
    net.eval()
    with torch.no_grad():
        has_nback = mode in ('j', 'N')
        has_sdmt = mode in ('j', 'S')

        # Keep only the top-K most important imaging features of each modality.
        # The mask length follows the importance vector rather than a fixed 246.
        mask_n = _top_k_mask(shap_n, K)
        mask_s = _top_k_mask(shap_s, K)

        x_n_t = x_n.detach() * torch.unsqueeze(gpu(mask_n), 0) if has_nback else torch.tensor(0).float()
        x_s_t = x_s.detach() * torch.unsqueeze(gpu(mask_s), 0) if has_sdmt else torch.tensor(0).float()

        # Integer mode code expected by the network: 0 = both, 1 = N-back, 2 = SDMT.
        mode_in = gpu_ts({'j': 0, 'N': 1}.get(mode, 2))

        surrogate_ig, y_hat, _ = net(x_g, x_n_t, x_s_t, T, mode_in)

        loss1 = hyp[0] * torch.sum(criterion_recon(surrogate_ig[0], x_g))
        loss2 = hyp[1] * torch.sum(criterion_recon(surrogate_ig[1], x_n)) if has_nback else torch.tensor(0).float()
        loss3 = hyp[2] * torch.sum(criterion_recon(surrogate_ig[2], x_s)) if has_sdmt else torch.tensor(0).float()

        # Per-sample classification loss: computed once and reused for the splits below.
        class_loss = criterion_class(y_hat, y)
        loss4 = hyp[3] * torch.sum(class_loss)

        loss = loss1 + loss2 + loss3 + loss4

        # Split the classification loss by label: nonzero y -> disease, zero -> control.
        weight_d = (y != 0).float()
        weight_c = 1 - weight_d

        loss4_control = hyp[3] * torch.sum(class_loss * weight_c)
        loss4_disease = hyp[3] * torch.sum(class_loss * weight_d)

    return cpu(loss), cpu(loss1), cpu(loss2), cpu(loss3), cpu(loss4), cpu(loss4_control), cpu(loss4_disease), cpu(y_hat)

def save_results(net, metric_losses, check_losses, train_losses, f, prob_d):
    """Collect the loss/metric curves and save them to a ``.mat`` file.

    The file is written to ``data_<m_n+1>/setting/model<f>/vis.mat``, using
    the module-level ``m_n``. It contains:
      * ``'train'+name`` for every entry of ``train_losses``;
      * ``'test'+name`` / ``'val'+name`` for every entry of
        ``check_losses['test']`` (with the matching ``check_losses['val']``);
      * ``'test'+name`` / ``'val'+name`` for every metric in
        ``metric_losses['test']``, plus ``'train'+name`` when
        ``metric_losses`` provides a matching ``'train'`` entry;
      * ``'prob_Nback'`` / ``'prob_SDMT'`` taken from ``prob_d``.

    ``net`` is accepted for interface compatibility and is not used.
    """
    vis = {}

    for name in train_losses:
        vis['train' + name] = train_losses[name]

    for name in check_losses['test']:
        vis['test' + name] = check_losses['test'][name]
        vis['val' + name] = check_losses['val'][name]

    # Train metrics come from metric_losses['train'] only. Test metrics are
    # never stored under a 'train' key. .get() avoids creating the key on a
    # defaultdict.
    train_metrics = metric_losses.get('train') or {}
    for name in metric_losses['test']:
        if name in train_metrics:
            vis['train' + name] = train_metrics[name]
        vis['test' + name] = metric_losses['test'][name]
        vis['val' + name] = metric_losses['val'][name]

    vis['prob_Nback'] = prob_d['Nback']
    vis['prob_SDMT'] = prob_d['SDMT']

    sio.savemat('data_' + str(m_n + 1) + '/setting/model' + str(f) + '/vis' + '.mat', vis)
    
    
    

#####################################################################################################

# The single CLI argument encodes both indices as 10 * m_n + id_fold.
# m_n is presumably the repetition/run index (it selects data_<m_n+1>);
# id_fold is the cross-validation fold index (0-9).
run_id = int(sys.argv[1])
m_n, id_fold = divmod(run_id, 10)
plt.ioff()  # non-interactive plotting: figures are only saved to disk
L = loadmat('locations.mat')['L']
loc = L[m_n, id_fold]  # index of the saved model checkpoint to load for this (run, fold)

##################################################################################################################################################################################

# Loading data.
# All train/validation/test splits for this (run, fold) live in a single .mat
# file. Parse it once instead of re-reading it for every variable.
DATA_DIR = '/mnt/sdb1/sayan/pathway_analysis/data'
fold_mat = loadmat(DATA_DIR + '/multiple_runs/train_test_val_imaging_data_folds_with_5_graph_layers'
                   + str(id_fold + 1) + '_3_' + str(m_n + 1) + '.mat')

# Training data.
# *_j: subjects with all modalities; *_Nback / *_SDMT: subjects with genetics
# plus only that imaging modality.
I_train_Nback_j = gpu_t(fold_mat['I_train_Nback_j'])
I_train_SDMT_j  = gpu_t(fold_mat['I_train_SDMT_j'])
G_train_j       = gpu_t(fold_mat['G_train_j'])
Y_train_j       = gpu_t(fold_mat['Y_train_j'])

I_train_Nback   = gpu_t(fold_mat['I_train_Nback'])
I_train_SDMT    = gpu_t(fold_mat['I_train_SDMT'])
G_train_Nback   = gpu_t(fold_mat['G_train_Nback'])
G_train_SDMT    = gpu_t(fold_mat['G_train_SDMT'])
Y_train_Nback   = gpu_t(fold_mat['Y_train_Nback'])
Y_train_SDMT    = gpu_t(fold_mat['Y_train_SDMT'])

# Validation data.
I_val_Nback_j = gpu_t(fold_mat['I_val_Nback_j'])
I_val_SDMT_j  = gpu_t(fold_mat['I_val_SDMT_j'])
G_val_j       = gpu_t(fold_mat['G_val_j'])
Y_val_j       = gpu_t(fold_mat['Y_val_j'])

I_val_Nback   = gpu_t(fold_mat['I_val_Nback'])
I_val_SDMT    = gpu_t(fold_mat['I_val_SDMT'])
G_val_Nback   = gpu_t(fold_mat['G_val_Nback'])
G_val_SDMT    = gpu_t(fold_mat['G_val_SDMT'])
Y_val_Nback   = gpu_t(fold_mat['Y_val_Nback'])
Y_val_SDMT    = gpu_t(fold_mat['Y_val_SDMT'])

# Test data.
I_test_Nback_j = gpu_t(fold_mat['I_test_Nback_j'])
I_test_SDMT_j  = gpu_t(fold_mat['I_test_SDMT_j'])
G_test_j       = gpu_t(fold_mat['G_test_j'])
Y_test_j       = gpu_t(fold_mat['Y_test_j'])

I_test_Nback   = gpu_t(fold_mat['I_test_Nback'])
I_test_SDMT    = gpu_t(fold_mat['I_test_SDMT'])
G_test_Nback   = gpu_t(fold_mat['G_test_Nback'])
G_test_SDMT    = gpu_t(fold_mat['G_test_SDMT'])
Y_test_Nback   = gpu_t(fold_mat['Y_test_Nback'])
Y_test_SDMT    = gpu_t(fold_mat['Y_test_SDMT'])

del fold_mat  # drop the parsed-file dict; the variables above hold what is needed

# Number of test / validation subjects across the three modality subsets.
test_n = Y_test_j.size()[0] + Y_test_Nback.size()[0] + Y_test_SDMT.size()[0]
val_n  = Y_val_j.size()[0]  + Y_val_Nback.size()[0]  + Y_val_SDMT.size()[0]

# Masks (A, mask_gene) and pooling sizes used to build Gene_ontology_network.
# All three come from the same file, which is read once.
graph_mat = loadmat(DATA_DIR + '/train_test_data_gcn_5_layers_joint_model.mat')
A = gpu(torch.tensor(graph_mat['A']).float().t().to_sparse().coalesce())
A_g = gpu(torch.tensor(graph_mat['mask_gene']).float().to_sparse().coalesce())
pool_dim = graph_mat['pool_n'].tolist()
del graph_mat

############################################################################################################################################################################
# Placeholders kept from the training script (not read later in this file).
e_losses = list()
temp = 10 ** 10

# Thresholding probabilities: a single setting, indexed as 0.
sps_N = [0]

# Mini-batch size.
batch_size = 32

# Loss weights (lambda), in order:
#   gene reconstruction, N-back reconstruction, SDMT reconstruction,
#   classification, sparsity.
lambda_0 = [gpu_ts(weight) for weight in (0.00001, 0.003, 0.003, 1, 1)]

# Gumbel temperature parameter.
temperature = gpu_ts(0.1)

# Size passed to both Gene_ontology_network and joint_network (presumably the latent dimension).
l_dim = 50

# Warm-start iteration count.
iter_ths = 0

# Element-wise (unreduced) losses; callers weight and sum them explicitly.
criterion_class = nn.BCELoss(reduction='none')
criterion_recon = nn.MSELoss(reduction='none')

# Create folder to store models and results

# Make sure the results folder data_<m_n+1> and its 'setting' sub-folder
# exist. The parent is listed first, so it is created before the child.
for _required_dir in ('data_' + str(m_n + 1), 'data_' + str(m_n + 1) + '/setting'):
    if not os.path.exists(_required_dir):
        os.mkdir(_required_dir)
    
  
# Evaluate the checkpointed model for this single (run, fold):
#   1. compute KernelSHAP importances of the imaging features on the
#      validation set (genetic inputs are held fixed and not explained);
#   2. evaluate on the test set keeping only the top-K imaging features,
#      for K = 5, 10, ..., 245;
#   3. save the predictions, AUCs and importances to data_<m_n+1>/model<f>/vis.mat.
for f in range(id_fold, id_fold+1):  # runs exactly once, with f == id_fold
    # Start from clean per-fold output folders.
    if( os.path.exists('data_'+str(m_n+1)+'/model'+str(f)) ):
        shutil.rmtree('data_'+str(m_n+1)+'/model'+str(f))
        
    print('fold - '+str(f))
    os.mkdir('data_'+str(m_n+1)+'/model'+str(f)) # Create cross validation folder f.
    if  os.path.exists('data_'+str(m_n+1)+'/setting/model'+str(f)) :
        shutil.rmtree('data_'+str(m_n+1)+'/setting/model'+str(f))
        
    os.mkdir('data_'+str(m_n+1)+'/setting/model'+str(f)) # Create the settings folder for fold f.
    
    # Build the gene-ontology module.
    gene_ont = Gene_ontology_network(A_g, A, 2, 4, [5,5,5,5], pool_dim,l_dim)
   
    # Couple it with the imaging branches (246 N-back and 246 SDMT features), then
    # load the trained weights selected by `loc` from ../data_<m_n+1>/model<id_fold>/.
    net =  gpu(joint_network(gpu(gene_ont), 246, 50, 246, 50, l_dim))
    net.load_state_dict(torch.load('../data_'+str(m_n+1)+'/model'+str(id_fold)+'/model'+str(loc)+'.pth'))
    
    
    # ---- SHAP importance of the imaging features ----
    # One KernelExplainer pass per modality subset: all modalities, N-back only, SDMT only.
    
    # All modalities are present: imaging features are N-back columns followed by SDMT columns.
    
    I_train_j = torch.cat((I_train_Nback_j,I_train_SDMT_j),1) # Imaging training data for subjects who have all 3 data modalities.
    I_val_j = torch.cat((I_val_Nback_j,I_val_SDMT_j),1) # Imaging validation data for subjects who have all 3 data modalities.
    
    g_d = G_train_j.size()[1] # Number of genetic features (columns of G_train_j); not used below.
    
    def f_j(X):
        """Model wrapper explained by KernelSHAP for the all-modality subset.

        X holds imaging rows: columns [:246] are N-back and [246:] are SDMT
        features, in the same order as I_train_j. Returns the predictions as a
        numpy array.

        NOTE: `T` and `n_val_j` are module-level globals read at call time, so
        this is only valid while the explainer loop below has set them.
        NOTE(review): the genetic input is chosen by row count. A batch with as
        many rows as the background `T` is paired with G_train_j; any other
        batch gets the genetics of the single validation subject being explained
        (assumes `net` broadcasts that one row). A perturbation batch that
        happens to have len(T) rows would be misrouted -- TODO confirm. The
        temperature (0.1) and the mode code (0) are hard-coded here.
        """
        net.eval()
        n =2*246 # total imaging features: 246 N-back + 246 SDMT
        print(np.shape(X))
        num_s = np.shape(X)[0]
        if num_s != np.shape(T)[0]:
            surrogate_ig, y_hat, prob = net(G_val_j[n_val_j:n_val_j+1,:], gpu(torch.tensor(X[:,:n//2])), gpu(torch.tensor(X[:,n//2:])), 0.1, 0) # SHAP values are generated only for the imaging features; the genetic features are not passed through SHAP.
        else:
            surrogate_ig, y_hat, prob = net(G_train_j, gpu(torch.tensor(X[:,:n//2])), gpu(torch.tensor(X[:,n//2:])), 0.1, 0)
           
        return cpu(y_hat).data.numpy()


    T = cpu(I_train_j).data.numpy()
    T_val = cpu(I_val_j).data.numpy() 
    shap_values_j=[]
    # One explainer per validation subject, so that f_j can pair the imaging
    # samples with that subject's genetics (via n_val_j).
    for n_val_j in range(np.shape(T_val)[0]):
        e = shap.KernelExplainer(f_j, T) # Initialize KernelSHAP with the training data as background data.
        shap_values_j.append(e.shap_values(T_val[n_val_j:n_val_j+1,:])) # SHAP values of the imaging features for one validation subject with all 3 modalities.
    shap_values_j = np.concatenate(tuple(shap_values_j),0) # Stack the per-subject SHAP values along axis 0.
    
    
    ## Nback only
    def f_n(X):
        """Model wrapper explained by KernelSHAP for the N-back-only subset.

        X holds N-back imaging rows; X[:, :246] presumably covers all columns
        (T is I_train_Nback). The SDMT input is replaced by gpu_ts(0), and the
        mode code 1 is passed. The same global/row-count caveats as f_j apply,
        with `n_val_nback` as the validation-subject index.
        """
        net.eval()
        n =2*246 # kept from f_j; only the first n//2 = 246 columns are used
        print(np.shape(X))
        num_s = np.shape(X)[0]
        if num_s != np.shape(T)[0]:
            surrogate_ig, y_hat, prob = net(G_val_Nback[n_val_nback:n_val_nback+1,:], gpu(torch.tensor(X[:,:n//2])), gpu_ts(0), 0.1, 1)    
        else:
            surrogate_ig, y_hat, prob = net(G_train_Nback, gpu(torch.tensor(X[:,:n//2])), gpu_ts(0), 0.1, 1)
           
        return cpu(y_hat).data.numpy()


    T = cpu(I_train_Nback).data.numpy() # Imaging training data for subjects who have the Nback + SNP data modalities.
    T_val = cpu(I_val_Nback).data.numpy() # Imaging validation data for subjects who have the Nback + SNP data modalities.
    
    shap_values_nback = []
    for n_val_nback in range(np.shape(T_val)[0]):
        e = shap.KernelExplainer(f_n, T) # Initialize KernelSHAP with the training data as background data.
        shap_values_nback.append(e.shap_values(T_val[n_val_nback:n_val_nback+1,:]))  # SHAP values of the imaging features for one validation subject with Nback + SNP data.
    shap_values_nback = np.concatenate(tuple(shap_values_nback),0)    
    


    ## SDMT only
    def f_s(X):
        """Model wrapper explained by KernelSHAP for the SDMT-only subset.

        X holds SDMT imaging rows and is passed as the third (SDMT) input. The
        N-back input is replaced by gpu_ts(0), and the mode code 2 is passed.
        The same global/row-count caveats as f_j apply, with `n_val_sdmt` as the
        validation-subject index.
        """
        net.eval()
        n =2*246 # kept from f_j; only the first n//2 = 246 columns are used
        print(np.shape(X))
        
        num_s = np.shape(X)[0]
        if num_s != np.shape(T)[0]:
            surrogate_ig, y_hat, prob = net(G_val_SDMT[n_val_sdmt:n_val_sdmt+1,:], gpu_ts(0), gpu(torch.tensor(X[:,:n//2])), 0.1, 2)    
        else:
            surrogate_ig, y_hat, prob = net(G_train_SDMT, gpu_ts(0), gpu(torch.tensor(X[:,:n//2])), 0.1, 2)
            
        return cpu(y_hat).data.numpy()


    T = cpu(I_train_SDMT).data.numpy()
    T_val = cpu(I_val_SDMT).data.numpy()
    shap_values_sdmt = []
    
    for n_val_sdmt in range(np.shape(T_val)[0]):
        e = shap.KernelExplainer(f_s, T)
        shap_values_sdmt.append(e.shap_values(T_val[n_val_sdmt:n_val_sdmt+1,:]))  # SHAP values of the imaging features for one validation subject with SDMT + SNP data.
    shap_values_sdmt = np.concatenate(tuple(shap_values_sdmt),0)   
    
    
    
    # Combine the all-modality SHAP values (N-back = first 246 columns, SDMT = last 246)
    # with the single-modality ones, and rank features by |median SHAP value|.
    # NOTE(review): `shap_values_j[0]`, `shap_values_nback[0]` and `shap_values_sdmt[0]`
    # index the first entry of the stacked arrays. If shap returns a one-element
    # list per call (single model output), that entry is only the FIRST validation
    # subject, not all subjects as the comments below say. Verify against the
    # installed shap version's return shape.
    # NOTE: this is the absolute value of the median, not the median of absolute
    # values, so features whose SHAP values change sign across subjects may rank low.
    shap_nback = np.abs(np.median(np.concatenate((shap_values_j[0][:,:246], shap_values_nback[0]),0),0)) # Concatenate the SHAP values obtained over the validation data and take the median to identify the important Nback features.
    shap_sdmt  = np.abs(np.median(np.concatenate((shap_values_j[0][:,246:], shap_values_sdmt[0]),0),0)) # Concatenate the SHAP values obtained over the validation data and take the median to identify the important SDMT features.
    
    

    # ---- Test performance when keeping only the top-K imaging features ----
    auc = []  # test AUC for each K
    pred = np.zeros((len(range(5,246,5)),test_n))  # one row of test predictions per K (49 values of K)
    iterr=0
    for K in range(5,246,5):
        # Fresh result buffers for each K.
        check_losses = collections.defaultdict(dict)
        l = ['total_loss', 'gene_recon_loss', 'nback_recon_loss', 'sdmt_recon_loss', 'class_loss', 'control_class_loss', 'disease_class_loss', 'class_pred', 'class_true']

        check_losses['test']  = {}

        initialize_check_losses(check_losses['test'], l, len(sps_N), 1, test_n)
    
        metric_losses = collections.defaultdict(dict)
        l = ['Acc', 'AUC']

        metric_losses['test']    = {}


        initialize_check_losses(metric_losses['test'] , l, len(sps_N), 1, test_n)
                    
        # Evaluate on the test split only (epoch index fixed to 0).

        for sp_N in range(len(sps_N)):
            net.eval()
            with torch.no_grad():
                helper_test(net, G_test_j, I_test_Nback_j, I_test_SDMT_j, Y_test_j, G_test_Nback, I_test_Nback, Y_test_Nback, G_test_SDMT, I_test_SDMT, Y_test_SDMT, \
                            sps_N[sp_N], check_losses, 'test', lambda_0, sp_N, 0, metric_losses, temperature, K,shap_nback,shap_sdmt)
        pred[iterr,:] = np.squeeze(check_losses['test']['class_pred'][0,0,:])
        
        iterr +=1
        auc.append(metric_losses['test']['AUC'][0,0])
        
    # Labels taken from the last K iteration; they are the same test labels for every K.
    true_class = np.squeeze(check_losses['test']['class_true'][0,0,:])
    vis = {'pred':pred, 'true_class':true_class,'AUC':auc, 'shap_nback':shap_nback,'shap_sdmt':shap_sdmt}
    sio.savemat('data_'+str(m_n+1)+'/model'+str(f)+'/vis'+'.mat',vis) 
        
