
import sys
import torch
import torch.nn as nn
import torch.nn.functional as Fun
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import matplotlib
import math
import pandas as pd
import scipy.io as sio
from sklearn.metrics import accuracy_score
from scipy.io import loadmat
import shutil
import os

from net_train import net_train

from net_test import net_test

from model_random import Gene_ontology_network


from convert_to_gpu import gpu
from convert_to_gpu_and_tensor import gpu_t
from convert_to_gpu_scalar import gpu_ts
from convert_to_cpu import cpu
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from sklearn import metrics
plt.ioff()


# ---------------------------------------------------------------------------
# Command-line argument and hyper-parameters.
# ---------------------------------------------------------------------------
# The single positional argument selects which predefined random graph and
# gene mask are loaded (see the '..._random_<rand_id>.mat' paths below).
if len(sys.argv) < 2:
    sys.exit('usage: python ' + sys.argv[0] + ' RAND_ID  (index of the random graph to train on)')
rand_id = int(sys.argv[1])

l_dim = 50         # Latent dimension (passed to Gene_ontology_network)
num_epochs = 501   # max number of epochs
num_s = 0          # first epoch at which test metrics are computed and checkpoints saved
num_f = 5          # diagnostic plots are refreshed every num_f epochs
iter_n = 20        # forwarded to net_train; NOTE(review): semantics defined there - confirm
temperature = gpu_ts(0.1)    # forwarded to net_train / net_test
lambda0 = gpu_ts(0.00001)    # forwarded to net_train / net_test
prob_ref = [gpu_ts(0.01)]    # forwarded to net_train

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

######################################################################################################################################################################################
## Loading data.
# All train/test arrays live in the same .mat file: parse it once instead of
# re-reading the file for every variable.
data_mat = loadmat('../../data/train_test_data_gcn_5_layers_joint_model.mat')

gene_train_l1 = data_mat['gene_train']
gene_train    = gpu(torch.tensor(gene_train_l1).float())

# Load validation data.
gene_test_l1 = data_mat['gene_test']
gene_test    = gpu(torch.tensor(gene_test_l1).float())

# Load the random graph selected by rand_id. 'mask_con' is its adjacency
# matrix; it is transposed so that, following the author's convention,
# A[i, j] = 1 denotes a connection from j --> i. Stored as a coalesced sparse tensor.
A = loadmat('../../data/random/5_layer_random_connection_' + str(rand_id) + '.mat')['mask_con']
A = gpu(torch.tensor(A).float().t().to_sparse().coalesce())

# Load the binary matrix mapping gene scores to random pathways, A_g
# (also stored as a coalesced sparse tensor, without transposition).
A_g = loadmat('../../data/random/gene_pathway_layer_connection_and_input_data_5_random_' + str(rand_id) + '.mat')['mask_gene_encode']
A_g = gpu(torch.tensor(A_g).float().to_sparse().coalesce())

# Class labels and pooling dimensions, taken from the already-parsed file.
Y_train = data_mat['class_train']
Y_test = data_mat['class_test']
pool_dim = data_mat['pool_n'].tolist()

Y_train = gpu(torch.tensor(Y_train).float())
Y_test = gpu(torch.tensor(Y_test).float())

del data_mat  # drop the reference to the parsed file; the needed arrays are extracted

############################################################################################################################################################################################################################################################################################################################################################################

# Thresholding-probability settings. Within the visible script only
# len(sps_g) is used (one evaluation slot per entry).
# NOTE(review): confirm whether the value itself is consumed anywhere else.
sps_g = [0]

batch_size = 113  # mini-batch size passed to net_train

# Start from a clean output directory: wipe any previous run for this random
# graph, then create model<rand_id>/ and its 'setting' sub-folder (where the
# per-epoch model checkpoints are written).
out_root = 'model' + str(rand_id)
if os.path.exists(out_root):
    shutil.rmtree(out_root)
os.makedirs(out_root + '/setting')

id_fold = 0  # single fold: the training loop iterates range(id_fold, id_fold + 1)
  
# Train one model on random graph `rand_id`. After every `interv` epochs the
# model is evaluated on the test split, a checkpoint and a .mat file of metrics
# are written, and every `num_f` epochs the diagnostic curves are re-plotted.


def _save_curve(values, name):
    """Plot a 1-D sequence of values and save it as model<rand_id>/<name>.pdf."""
    plt.plot(values)
    plt.savefig('model' + str(rand_id) + '/' + name + '.pdf')
    plt.close()


for f in range(id_fold, id_fold + 1):

    print('model - ' + str(rand_id))

    # network dimensions
    net = Gene_ontology_network(A_g, A, 2, 4, [5, 5, 5, 5], pool_dim, l_dim)
    net = net.to(device)

    # Initialize optimizer and learning-rate schedule.
    opt = optim.Adam(net.parameters(), lr=0.0001, betas=(0.9, 0.999), weight_decay=0)
    # Multiplies the learning rate by gamma=0.7 every 50 scheduler steps
    # (scheduler.step() is called once per epoch below).
    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=50, gamma=0.7)

    # Loss functions; reduction='none' returns per-element losses, which are
    # reduced inside net_train / net_test.
    criterion_class = nn.BCELoss(reduction='none')
    criterion_recon = nn.MSELoss(reduction='none')

    # Training losses, one entry appended per epoch.
    e_losses = []
    e_losses_c = []
    e_losses_d = []
    e_losses_reg = []
    e_losses_class = []

    # Test losses, indexed [setting, evaluation step tt].
    e_losses_test = np.zeros((len(sps_g), num_epochs))
    e_recon_test = np.zeros((len(sps_g), num_epochs))
    e_class_test = np.zeros((len(sps_g), num_epochs))
    e_losses_test_con = np.zeros((len(sps_g), num_epochs))
    e_losses_test_dis = np.zeros((len(sps_g), num_epochs))
    e_recon_test_con = np.zeros((len(sps_g), num_epochs))
    e_recon_test_dis = np.zeros((len(sps_g), num_epochs))

    # Latest evaluated metrics per setting (overwritten at every evaluation).
    acc_s = np.zeros(len(sps_g))
    auc_s = np.zeros(len(sps_g))
    it_model = np.zeros(len(sps_g))
    # Metric histories, indexed [setting, evaluation step tt].
    # NOTE(review): accuracy_cv / auc_cv are never filled in this script and stay 0.
    accuracy_test = np.zeros((len(sps_g), 3000))
    accuracy_cv = np.zeros((len(sps_g), 3000))
    auc_test = np.zeros((len(sps_g), 3000))
    auc_cv = np.zeros((len(sps_g), 3000))
    auc_train = np.zeros((len(sps_g), 3000))
    accuracy_train = np.zeros((len(sps_g), 3000))
    tt = 0       # index of the next free evaluation slot in the arrays above
    interv = 1   # evaluate on the test split every `interv` epochs
    for e in range(num_epochs):
        print(e)
        ll, l_class, l_reg, ll_c, ll_d, yy_train, Y_train_t = net_train(
            net, gene_train, Y_train, opt, batch_size, temperature, lambda0,
            prob_ref, criterion_class, criterion_recon, e, iter_n)

        e_losses_c.append(ll_c)
        e_losses_d.append(ll_d)
        e_losses.append(ll)
        e_losses_reg.append(l_reg)
        e_losses_class.append(l_class)

        scheduler.step()
        if e % interv != 0:
            continue

        for sp_g in range(len(sps_g)):
            (ll_test, y_test, imp_map, ll_class, ll_test_con, ll_test_dis,
             ll_recon, ll_recon_con, ll_recon_dis) = net_test(
                net, gene_test, Y_test, temperature, criterion_class,
                criterion_recon, lambda0)

            e_losses_test[sp_g, tt] = ll_test
            e_recon_test[sp_g, tt] = ll_recon
            e_class_test[sp_g, tt] = ll_class

            e_losses_test_con[sp_g, tt] = ll_test_con
            e_losses_test_dis[sp_g, tt] = ll_test_dis
            e_recon_test_con[sp_g, tt] = ll_recon_con
            e_recon_test_dis[sp_g, tt] = ll_recon_dis

            if e < num_s:
                continue

            # Convert labels / predictions to numpy once; reused for every metric below.
            class_test = cpu(Y_test).data.numpy()
            class_test_pred = cpu(y_test).data.numpy()
            class_train = cpu(Y_train_t).data.numpy()
            class_train_pred = cpu(yy_train).data.numpy()

            # AUC Calculation
            fpr_test, tpr_test, _ = metrics.roc_curve(class_test, class_test_pred, pos_label=1)
            fpr_train, tpr_train, _ = metrics.roc_curve(class_train, class_train_pred, pos_label=1)
            auc_test[sp_g, tt] = metrics.auc(fpr_test, tpr_test)
            auc_train[sp_g, tt] = metrics.auc(fpr_train, tpr_train)

            # Accuracy at a fixed 0.5 decision threshold.
            accuracy_test[sp_g, tt] = accuracy_score(
                class_test.ravel(), (class_test_pred > 0.5).astype(int).ravel())
            accuracy_train[sp_g, tt] = accuracy_score(
                class_train.ravel(), (class_train_pred > 0.5).astype(int).ravel())

            acc_test = accuracy_test[sp_g, tt]

            # Every evaluated epoch is recorded and checkpointed (no selection criterion).
            auc_s[sp_g] = auc_test[sp_g, tt]
            acc_s[sp_g] = acc_test
            it_model[sp_g] = e
            # detach().cpu() makes the conversion work for CPU and CUDA tensors alike.
            mapp = imp_map[0].detach().cpu().numpy()

            ACC_test = accuracy_test[sp_g, tt]
            ACC_cv = accuracy_cv[sp_g, tt]
            ACC_train = accuracy_train[sp_g, tt]
            AUC_test = auc_test[sp_g, tt]
            AUC_train = auc_train[sp_g, tt]
            AUC_cv = auc_cv[sp_g, tt]

            # Save model parameters and this epoch's evaluation results.
            torch.save(net.state_dict(), 'model' + str(rand_id) + '/setting/model' + str(e) + '.pth')
            vis = {
                'gen_pred': mapp,
                'acc_train': ACC_train,
                'acc_test': ACC_test,
                'auc_test': AUC_test,
                'auc_train': AUC_train,
                'class_test_pred': class_test_pred,
                'class_test': class_test,
                'class_train': class_train,
                'class_train_pred': class_train_pred,
                'loss': ll_test.detach().cpu().numpy(),
            }
            sio.savemat('model' + str(rand_id) + '/vis' + str(e) + '.mat', vis)

            if e % num_f == 0:
                # tt + 1 so the curves include the evaluation just written at index tt.
                n_eval = tt + 1
                curves = (
                    (e_losses_test[sp_g, :n_eval], 'cost_test'),
                    (e_recon_test[sp_g, :n_eval], 'recon_test'),
                    (e_class_test[sp_g, :n_eval], 'class_test'),
                    (e_losses_test_con[sp_g, :n_eval], 'class_test_con'),
                    (e_losses_test_dis[sp_g, :n_eval], 'class_test_dis'),
                    (e_recon_test_con[sp_g, :n_eval], 'recon_test_con'),
                    (e_recon_test_dis[sp_g, :n_eval], 'recon_test_dis'),
                    (e_losses, 'cost_train'),
                    (e_losses_c, 'class_train_control'),
                    (e_losses_d, 'class_train_disease'),
                    (e_losses_class, 'class_train'),
                    (e_losses_reg, 'recon_train'),
                    (accuracy_test[sp_g, :n_eval], 'accuracy_test'),
                    (accuracy_train[sp_g, :n_eval], 'accuracy_train'),
                    (auc_train[sp_g, :n_eval], 'auc_train'),
                    (auc_test[sp_g, :n_eval], 'auc_test'),
                )
                for values, name in curves:
                    _save_curve(values, name)

        tt += 1

