"""

"""
import sys
import torch
import torch.nn as nn
import torch.nn.functional as Fun
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import matplotlib
import math
import pandas as pd
import scipy.io as sio
from sklearn.metrics import accuracy_score
from scipy.io import loadmat
import shutil
import os

from net_train import net_train

from net_test import net_test

from model_random import Gene_ontology_network


from convert_to_gpu import gpu
from convert_to_gpu_and_tensor import gpu_t
from convert_to_gpu_scalar import gpu_ts
from convert_to_cpu import cpu
# Select the non-interactive Agg backend before pyplot is imported, so figures
# can be rendered straight to files without a display.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from sklearn import metrics
# Disable pyplot interactive mode; figures are saved to files and then closed.
plt.ioff()


# Fold index in [0, 9] taken from the first command-line argument; it selects
# one of the random sub-samples of the 2056 subjects who have genetics data.
id_fold = int(sys.argv[1]) % 10

# Training schedule.
num_epochs = 501   # number of training epochs
num_s      = 0     # first epoch at which test metrics / checkpoints are recorded
num_f      = 5     # loss/metric plots are refreshed every num_f epochs
iter_n     = 20    # forwarded unchanged to net_train
l_dim      = 50    # forwarded as the last argument of Gene_ontology_network
e_losses   = []

# Scalar hyper-parameters forwarded to net_train / net_test, wrapped by gpu_ts.
temperature = gpu_ts(0.1)
lambda0     = gpu_ts(0.00001)
prob_ref    = [gpu_ts(0.01)]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

######################################################################################################################################################################################
## Load one sub-sample (fold) of the data. Every array lives in the same .mat
## file, so the file is parsed from disk once instead of once per variable.
mat_path = '../../data/multiple_pretrain/train_test_data_gcn_5_layers_joint_model'+str(id_fold+1)+'.mat'
fold_data = loadmat(mat_path)

# Training and validation gene inputs (float tensors moved through gpu()).
gene_train_l1 = fold_data['gene_train']
gene_train    = gpu(torch.tensor(gene_train_l1).float())

gene_test_l1 = fold_data['gene_test']
gene_test    = gpu(torch.tensor(gene_test_l1).float())

# Adjacency of the ontology graph: transposed, then stored as a coalesced sparse tensor.
A = gpu(torch.tensor(fold_data['A']).float().t().to_sparse().coalesce())

# Genescore-to-pathway-embedding binary mask, as a coalesced sparse tensor.
A_g = gpu(torch.tensor(fold_data['mask_gene']).float().to_sparse().coalesce())

# Class labels for the training and validation splits.
Y_train = gpu(torch.tensor(fold_data['class_train']).float())
Y_test  = gpu(torch.tensor(fold_data['class_test']).float())

# pool_n contains the number of nodes in each layer of the ontology (as nested lists).
pool_dim = fold_data['pool_n'].tolist()

# The tensors above hold their own copies; release the parsed file contents.
del fold_data

############################################################################################################################################################################################################################################################################################################################################################################

# Thresholding probabilities. Only len(sps_g) is used: it sets the number of
# rows in the per-epoch test-metric arrays (a single setting here).
sps_g = [0]

# Batch Size
batch_size = 113

# Create pretrained<fold+1>/setting (and its parent) if missing. makedirs with
# exist_ok=True replaces the racy exists()-then-mkdir() pairs with one call.
os.makedirs('pretrained'+str(id_fold+1)+'/setting', exist_ok=True)
  
# Create folders to store the results and the models.  
def _save_curve(values, out_dir, name):
    """Plot ``values`` on a fresh figure and save it as ``<out_dir>/<name>.pdf``.

    The figure is closed after saving so consecutive curves never overlay.
    """
    plt.plot(values)
    plt.savefig(out_dir+'/'+name+'.pdf')
    plt.close()


for f in range(0, 1):
    run_dir = 'pretrained'+str(id_fold+1)
    plot_dir = run_dir+'/model'+str(f)             # loss / metric curves (.pdf)
    ckpt_dir = run_dir+'/models'                   # one state_dict per evaluated epoch
    setting_dir = run_dir+'/setting/model'+str(f)  # one vis<epoch>.mat per evaluated epoch

    # Start every run from empty output folders; makedirs also creates any
    # missing parent directory.
    for out_dir in (plot_dir, ckpt_dir, setting_dir):
        if os.path.exists(out_dir):
            shutil.rmtree(out_dir)
        os.makedirs(out_dir)

    print('fold - '+str(f))

    # network dimensions
    net = Gene_ontology_network(A_g, A, 2, 4, [5,5,5,5], pool_dim,l_dim)
    net = net.to(device)

    # Initialize optimizer (Adam, no weight decay).
    opt = optim.Adam(net.parameters(), lr=0.0001, betas=(0.9, 0.999),weight_decay=0)

    # Multiply the learning rate by gamma=0.7 every 50 scheduler steps
    # (scheduler.step() is called once per epoch below).
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=50, gamma = 0.7)

    # Define the loss functions. reduction='none' keeps per-element losses;
    # presumably net_train / net_test reduce them — confirm there.
    criterion_class = nn.BCELoss(reduction='none')
    criterion_recon = nn.MSELoss(reduction='none')

    # Training curves: one entry appended per epoch from net_train's outputs.
    e_losses = []
    e_losses_c = []
    e_losses_d = []
    e_losses_reg = []
    e_losses_class = []

    # Test curves: row = index into sps_g, column = evaluation counter tt.
    n_thr = len(sps_g)
    e_losses_test = np.zeros((n_thr,num_epochs))
    e_recon_test = np.zeros((n_thr,num_epochs))
    e_class_test = np.zeros((n_thr,num_epochs))
    e_losses_test_con = np.zeros((n_thr,num_epochs))
    e_losses_test_dis = np.zeros((n_thr,num_epochs))
    e_recon_test_con = np.zeros((n_thr,num_epochs))
    e_recon_test_dis = np.zeros((n_thr,num_epochs))

    # Scores of the most recently recorded model per sps_g entry.
    acc_s = np.zeros(n_thr)
    auc_s = np.zeros(n_thr)
    it_model = np.zeros(n_thr)
    # Per-evaluation metric histories (capacity 3000 evaluations).
    accuracy_test = np.zeros((n_thr,3000))
    accuracy_cv = np.zeros((n_thr,3000))   # never filled in this script
    auc_test = np.zeros((n_thr,3000))
    auc_cv = np.zeros((n_thr,3000))        # never filled in this script
    auc_train = np.zeros((n_thr,3000))
    accuracy_train = np.zeros((n_thr,3000))

    tt = 0      # column of the next evaluation in the test-curve arrays
    interv = 1  # evaluate on the test split every `interv` epochs
    for e in range(num_epochs):
        print(e)
        ll, l_class, l_reg, ll_c, ll_d, yy_train, Y_train_t = net_train(net, gene_train, Y_train, opt, batch_size, temperature, lambda0, prob_ref, criterion_class,criterion_recon, e,iter_n)

        e_losses_c.append(ll_c)
        e_losses_d.append(ll_d)
        e_losses.append(ll)
        e_losses_reg.append(l_reg)
        e_losses_class.append(l_class)

        scheduler.step()
        if e % interv != 0:
            continue

        for sp_g in range(n_thr):
            ll_test, y_test, imp_map, ll_class, ll_test_con,ll_test_dis, ll_recon, ll_recon_con, ll_recon_dis = net_test(net, gene_test, Y_test, temperature, criterion_class, criterion_recon, lambda0)

            e_losses_test[sp_g,tt] = ll_test
            e_recon_test[sp_g,tt] = ll_recon
            e_class_test[sp_g,tt] = ll_class
            e_losses_test_con[sp_g,tt] = ll_test_con
            e_losses_test_dis[sp_g,tt] = ll_test_dis
            e_recon_test_con[sp_g,tt] = ll_recon_con
            e_recon_test_dis[sp_g,tt] = ll_recon_dis

            if e < num_s:
                continue

            # Convert predictions / labels to numpy once; every conversion goes
            # through cpu() so it also works when tensors live on the GPU.
            y_test_np = cpu(y_test).data.numpy()
            Y_test_np = cpu(Y_test).data.numpy()
            y_train_np = cpu(yy_train).data.numpy()
            Y_train_np = cpu(Y_train_t).data.numpy()

            # ROC-AUC on the test split and on the training outputs of this epoch.
            fpr_test, tpr_test, thresholds_test = metrics.roc_curve(Y_test_np, y_test_np, pos_label=1)
            fpr_train, tpr_train, thresholds_train = metrics.roc_curve(Y_train_np, y_train_np, pos_label=1)
            auc_test[sp_g,tt] = metrics.auc(fpr_test, tpr_test)
            auc_train[sp_g,tt] = metrics.auc(fpr_train, tpr_train)

            # Accuracy of predictions thresholded at 0.5 (y_true first, y_pred second).
            accuracy_test[sp_g,tt] = accuracy_score(Y_test_np.squeeze(), (y_test_np > 0.5).astype(int))
            accuracy_train[sp_g,tt] = accuracy_score(Y_train_np.squeeze(), (y_train_np > 0.5).astype(int))

            # Every evaluated epoch is recorded and checkpointed; choosing a
            # model is left to whatever reads the saved files.
            auc_s[sp_g] = auc_test[sp_g,tt]
            acc_s[sp_g] = accuracy_test[sp_g,tt]
            it_model[sp_g] = e

            mapp = cpu(imp_map[0]).data.numpy()

            ACC_test = accuracy_test[sp_g,tt]
            ACC_cv = accuracy_cv[sp_g,tt]
            ACC_train = accuracy_train[sp_g,tt]
            AUC_test = auc_test[sp_g,tt]
            AUC_train = auc_train[sp_g,tt]
            AUC_cv = auc_cv[sp_g,tt]

            class_test_pred = y_test_np
            class_test = Y_test_np
            class_train_pred = y_train_np
            class_train = Y_train_np

            torch.save(net.state_dict(), ckpt_dir+'/model'+str(e)+'.pth')
            vis = {'gen_pred':mapp,'acc_train':ACC_train,'acc_test': ACC_test,'auc_test':AUC_test,'auc_train':AUC_train,'class_test_pred':class_test_pred,
                   'class_test':class_test, 'class_train': class_train, 'class_train_pred':class_train_pred,'loss':cpu(ll_test).data.numpy()}
            sio.savemat(setting_dir+'/vis'+str(e)+'.mat',vis)

            if e % num_f == 0:
                # Test curves are sliced up to and including column tt so they
                # show the evaluation just written, like the training curves
                # (which already contain epoch e).
                upto = tt+1
                curves = (
                    ('cost_test', e_losses_test[sp_g,:upto].tolist()),
                    ('recon_test', e_recon_test[sp_g,:upto].tolist()),
                    ('class_test', e_class_test[sp_g,:upto].tolist()),
                    ('class_test_con', e_losses_test_con[sp_g,:upto].tolist()),
                    ('class_test_dis', e_losses_test_dis[sp_g,:upto].tolist()),
                    ('recon_test_con', e_recon_test_con[sp_g,:upto].tolist()),
                    ('recon_test_dis', e_recon_test_dis[sp_g,:upto].tolist()),
                    ('cost_train', e_losses),
                    ('class_train_control', e_losses_c),
                    ('class_train_disease', e_losses_d),
                    ('class_train', e_losses_class),
                    ('recon_train', e_losses_reg),
                    ('accuracy_test', accuracy_test[sp_g,:upto].tolist()),
                    ('accuracy_train', accuracy_train[sp_g,:upto].tolist()),
                    ('auc_train', auc_train[sp_g,:upto].tolist()),
                    ('auc_test', auc_test[sp_g,:upto].tolist()),
                )
                for curve_name, curve_values in curves:
                    _save_curve(curve_values, plot_dir, curve_name)
        tt += 1
        
