
import sys
import torch
import torch.nn as nn
import torch.nn.functional as Fun
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import matplotlib
import math
import pandas as pd
import scipy.io as sio
from sklearn.metrics import accuracy_score
from scipy.io import loadmat
import shutil
import os

from net_train import net_train

from net_test import net_test

from model_random import Gene_ontology_network


from convert_to_gpu import gpu
from convert_to_gpu_and_tensor import gpu_t
from convert_to_gpu_scalar import gpu_ts
from convert_to_cpu import cpu
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from sklearn import metrics
plt.ioff()


# ---------------------------------------------------------------------------
# Run configuration.
# ---------------------------------------------------------------------------

# Index of the fold handled by this run; it is used to name the output
# folders. To take it from the command line instead, use:
#     math.floor(int(sys.argv[1]) % 10)
id_fold = 0

# Per-epoch total training loss (re-initialised at the start of every fold).
e_losses = []

# Last argument handed to Gene_ontology_network.
l_dim = 50

# Number of training epochs per fold.
num_epochs = 501

# Metrics are computed and checkpoints written only for epochs >= num_s.
num_s = 0

# The diagnostic curves are re-drawn every num_f epochs.
num_f = 5

# Forwarded unchanged to net_train as its last argument.
iter_n = 20

# Scalars wrapped by gpu_ts; temperature and lambda0 are forwarded to both
# net_train and net_test, prob_ref to net_train only.
temperature = gpu_ts(0.1)
lambda0 = gpu_ts(1e-5)
prob_ref = [gpu_ts(0.01)]

# Device the network is moved to: CUDA when available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

######################################################################################################################################################################################
## Loading data.
# The train/test inputs and their class labels are all stored in the same
# .mat file, so it is parsed from disk once and indexed per variable instead
# of being re-read for every key.
train_test_mat = loadmat('../../data/train_test_data_gcn_5_layers_joint_model.mat')

# Training inputs, as a float tensor moved through the gpu() helper.
gene_train_l1 = train_test_mat['gene_train']
gene_train    = gpu(torch.tensor(gene_train_l1).float())

# Held-out inputs (stored under the key 'gene_test'); passed to net_test.
gene_test_l1 = train_test_mat['gene_test']
gene_test    = gpu(torch.tensor(gene_test_l1).float())

# load the adjacency matrix of gene ontology A_p
A = loadmat('../../data/ontology_adjacency.mat')['A']
# The transpose ensures that if A[i,j]=1, then there is a connection from j --> i.
# The matrix is then stored as a coalesced sparse tensor.
A = gpu(torch.tensor(A).float().t().to_sparse().coalesce())

# load the binary matrix of gene scores to pathways A_g (coalesced sparse tensor)
A_g = loadmat('../../data/gene_to_pathways.mat')['mask_gene']
A_g = gpu(torch.tensor(A_g).float().to_sparse().coalesce())

# Class labels for the training and held-out samples, as float tensors.
Y_train = gpu(torch.tensor(train_test_mat['class_train']).float())
Y_test = gpu(torch.tensor(train_test_mat['class_test']).float())

# This is an array which contains the number of nodes discarded after each layer of GCN.
# For example, pool_dim[0] = 1000 means after 1st GCN the first 0:999 nodes of the graph are discarded.
# It is converted to a plain (nested) Python list before being handed to the network.
pool_dim = loadmat('../../data/pool_layers.mat')['pool_n']
pool_dim = pool_dim.tolist()

############################################################################################################################################################################################################################################################################################################################################################################

# Thresholding probabilities. The training loop below runs one evaluation
# pass per entry and uses the entry's position as the row index of every
# metric-history array.
sps_g = [0]


# Batch size
batch_size = 113


# Create the folders that store models and results. makedirs creates both
# 'pretrained' and 'pretrained/setting' in one call, and exist_ok=True makes
# it a no-op when they are already there, which avoids the window between a
# separate "exists?" check and the mkdir.
os.makedirs('pretrained/setting', exist_ok=True)
  
# ---------------------------------------------------------------------------
# Per-fold training / evaluation loop.
#
# For every fold: reset the output folders, build the network, then for each
# epoch run net_train, step the LR scheduler, run net_test, record losses,
# AUC and accuracy, save a checkpoint plus a .mat summary, and every num_f
# epochs redraw the diagnostic curves as PDFs.
# ---------------------------------------------------------------------------

def _save_curve(values, pdf_path):
    """Plot ``values`` as a single line and write the figure to ``pdf_path``."""
    plt.plot(values)
    plt.savefig(pdf_path)
    plt.close()


def _reset_dir(path):
    """Delete ``path`` (recursively) if it exists, then create it empty."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.mkdir(path)


for f in range(id_fold, id_fold+1):
    print('fold - '+str(f))

    # Output locations: curves go to fold_dir, per-epoch checkpoints to
    # 'pretrained/models', per-epoch .mat summaries to setting_dir.
    # All three are wiped at the start of the fold.
    fold_dir = 'pretrained/model'+str(f)
    setting_dir = 'pretrained/setting/model'+str(f)
    _reset_dir(fold_dir)
    _reset_dir('pretrained/models')
    _reset_dir(setting_dir)

    # network dimensions
    net = Gene_ontology_network(A_g, A, 2, 4, [5,5,5,5], pool_dim,l_dim)
    net = net.to(device)

    # Initialize the optimizer and the learning-rate schedule.
    opt = optim.Adam(net.parameters(), lr=0.0001, betas=(0.9, 0.999),weight_decay=0)
    # StepLR multiplies the learning rate by gamma (0.7) every 50 scheduler
    # steps; the scheduler is stepped once per epoch below.
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=50, gamma = 0.7)

    # Define the loss functions (element-wise, no reduction).
    criterion_class = nn.BCELoss(reduction='none')
    criterion_recon = nn.MSELoss(reduction='none')

    # Training-side histories, one entry appended per epoch.
    e_losses = [] # total loss
    e_losses_c = [] # control loss
    e_losses_d = [] # disease loss
    e_losses_reg = [] # regularization loss
    e_losses_class =[] # class loss

    # Evaluate on the held-out set every `interv` epochs.
    interv = 1
    # Number of evaluations the epoch loop can perform.
    n_evals = (num_epochs + interv - 1) // interv
    # Metric histories used to be fixed at 3000 columns; keep at least that
    # many, but grow when the run is long enough to need more.
    hist_len = max(3000, n_evals)

    # Test-side loss histories: one row per sps_g entry, one column per evaluation.
    e_losses_test = np.zeros((len(sps_g),num_epochs))
    e_recon_test = np.zeros((len(sps_g),num_epochs))
    e_class_test = np.zeros((len(sps_g),num_epochs))
    e_losses_test_con = np.zeros((len(sps_g),num_epochs))
    e_losses_test_dis = np.zeros((len(sps_g),num_epochs))
    e_recon_test_con = np.zeros((len(sps_g),num_epochs))
    e_recon_test_dis = np.zeros((len(sps_g),num_epochs))

    # Metrics of the most recently saved model, per sps_g entry.
    acc_s     = 0*np.ones(len(sps_g))
    auc_s     = 0*np.ones(len(sps_g))
    it_model     = np.zeros(len(sps_g))

    # Accuracy / AUC histories. accuracy_cv and auc_cv are allocated but
    # never written by this script.
    accuracy_test  = np.zeros((len(sps_g),hist_len))
    accuracy_cv  = np.zeros((len(sps_g),hist_len))
    auc_test  = np.zeros((len(sps_g),hist_len))
    auc_cv  = np.zeros((len(sps_g),hist_len))
    auc_train  = np.zeros((len(sps_g),hist_len))
    accuracy_train  = np.zeros((len(sps_g),hist_len))

    # Column index of the current evaluation in the history arrays.
    tt = 0

    for e in range(num_epochs):
        print(e)
        ll, l_class, l_reg, ll_c, ll_d, yy_train, Y_train_t = net_train(net, gene_train, Y_train, opt, batch_size, temperature, lambda0, prob_ref, criterion_class,criterion_recon, e,iter_n)

        e_losses_c.append(ll_c)
        e_losses_d.append(ll_d)
        e_losses.append(ll)
        e_losses_reg.append(l_reg)
        e_losses_class.append(l_class)

        scheduler.step()

        # Nothing else happens on epochs that are not evaluation epochs.
        if e % interv != 0:
            continue

        for sp_g in range(len(sps_g)):

            ll_test, y_test, imp_map, ll_class, ll_test_con,ll_test_dis, ll_recon, ll_recon_con, ll_recon_dis = net_test(net, gene_test, Y_test, temperature, criterion_class, criterion_recon, lambda0)

            e_losses_test[sp_g,tt] = ll_test
            e_recon_test[sp_g,tt] = ll_recon
            e_class_test[sp_g,tt] = ll_class

            e_losses_test_con[sp_g,tt] = ll_test_con
            e_losses_test_dis[sp_g,tt] = ll_test_dis
            e_recon_test_con[sp_g,tt] = ll_recon_con
            e_recon_test_dis[sp_g,tt] = ll_recon_dis

            # Metrics, checkpoints and plots only start at epoch num_s.
            if e < num_s:
                continue

            # Labels and predicted scores as NumPy arrays; they feed the
            # metrics below and are stored in the .mat summary.
            class_test_pred = cpu(y_test).data.numpy()
            class_test = cpu(Y_test).data.numpy()
            class_train_pred = cpu(yy_train).data.numpy()
            class_train = cpu(Y_train_t).data.numpy()

            # AUC calculation
            fpr_test, tpr_test, _ = metrics.roc_curve(class_test, class_test_pred, pos_label=1)
            fpr_train, tpr_train, _ = metrics.roc_curve(class_train, class_train_pred, pos_label=1)
            auc_test[sp_g,tt]  = metrics.auc(fpr_test, tpr_test)
            auc_train[sp_g,tt] = metrics.auc(fpr_train, tpr_train)

            # Accuracy calculation: scores above 0.5 count as class 1.
            accuracy_test[sp_g,tt] = accuracy_score(np.ravel(class_test), (np.ravel(class_test_pred) > 0.5).astype(int))
            accuracy_train[sp_g,tt] = accuracy_score(np.ravel(class_train), (np.ravel(class_train_pred) > 0.5).astype(int))

            # Remember the metrics and the epoch of the model being saved.
            auc_s[sp_g] = auc_test[sp_g,tt]
            acc_s[sp_g] = accuracy_test[sp_g,tt]
            it_model[sp_g] = e

            # detach().cpu() makes the NumPy conversion valid whether or not
            # net_test returns these tensors on the GPU.
            mapp = imp_map[0].detach().cpu().numpy()

            # Save model parameters
            torch.save(net.state_dict(), 'pretrained/models/model'+str(e)+'.pth')
            vis = {'gen_pred':mapp,'acc_train':accuracy_train[sp_g,tt],'acc_test': accuracy_test[sp_g,tt],'auc_test':auc_test[sp_g,tt],'auc_train':auc_train[sp_g,tt],'class_test_pred':class_test_pred, \
                   'class_test':class_test, 'class_train': class_train, 'class_train_pred':class_train_pred,'loss':ll_test.detach().cpu().numpy()}
            sio.savemat(setting_dir+'/vis'+str(e)+'.mat',vis)

            # Plot loss characteristics.
            if e % num_f == 0:
                # tt is only incremented after this loop, so the slice has to
                # run to tt + 1 to include the value recorded for this epoch.
                seen = slice(0, tt + 1)
                curves = [
                    ('cost_test', e_losses_test[sp_g, seen]),
                    ('recon_test', e_recon_test[sp_g, seen]),
                    ('class_test', e_class_test[sp_g, seen]),
                    ('class_test_con', e_losses_test_con[sp_g, seen]),
                    ('class_test_dis', e_losses_test_dis[sp_g, seen]),
                    ('recon_test_con', e_recon_test_con[sp_g, seen]),
                    ('recon_test_dis', e_recon_test_dis[sp_g, seen]),
                    ('cost_train', e_losses),
                    ('class_train_control', e_losses_c),
                    ('class_train_disease', e_losses_d),
                    ('class_train', e_losses_class),
                    ('recon_train', e_losses_reg),
                    ('accuracy_test', accuracy_test[sp_g, seen]),
                    ('accuracy_train', accuracy_train[sp_g, seen]),
                    ('auc_train', auc_train[sp_g, seen]),
                    ('auc_test', auc_test[sp_g, seen]),
                ]
                for curve_name, curve_values in curves:
                    _save_curve(curve_values, fold_dir+'/'+curve_name+'.pdf')

        tt += 1
        
