import sys
import torch
import torch.nn as nn
import torch.nn.functional as Fun
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import matplotlib
import math
import pandas as pd
import scipy.io as sio
from sklearn.metrics import accuracy_score
from scipy.io import loadmat
import shutil
import os
from net_train import net_train
from net_test import net_test
from model_random import Gene_ontology_network
from convert_to_gpu import gpu
from convert_to_gpu_and_tensor import gpu_t
from convert_to_gpu_scalar import gpu_ts
from convert_to_cpu import cpu
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from sklearn import metrics
plt.ioff()


# ---------------------------------------------------------------------------
# Run configuration.
# ---------------------------------------------------------------------------
# Fold index. An earlier version derived it from the command line as
# math.floor(int(sys.argv[1]) % 10); it is now pinned to fold 0.
id_fold = 0
e_losses = []

# Sizes and schedule counters. l_dim is passed to Gene_ontology_network below;
# the remaining counters are not referenced later in the visible script.
l_dim, num_epochs = 50, 501
num_s, num_f, iter_n = 0, 5, 20

# Scalar hyper-parameters wrapped by the project helper gpu_ts
# (temperature and lambda0 are forwarded to net_test).
temperature = gpu_ts(0.1)
lambda0 = gpu_ts(1e-5)
prob_ref = [gpu_ts(0.01)]

# Preferred compute device (not referenced later in the visible script).
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

######################################################################################################################################################################################
# Load the 208 imaging-genetics data. The .mat file is parsed once and every
# field is read from the same dict, rather than re-parsing it per variable.
_joint_data = loadmat('../../data/train_test_data_gcn_5_layers_joint_model.mat')

# Test-fold gene inputs: raw array plus a float tensor passed through gpu().
gene_test_l1 = _joint_data['gene_cv']
gene_test    = gpu(torch.tensor(gene_test_l1).float())

# Class labels. Y_train is not referenced later in the visible script.
Y_train = _joint_data['class_train']
Y_test  = gpu(torch.tensor(_joint_data['class_cv']).float())

# Pooling sizes forwarded to Gene_ontology_network, converted to nested lists.
pool_dim = _joint_data['pool_n'].tolist()

# The dict is no longer needed; the arrays above keep their own references.
del _joint_data

############################################################################################################################################################################################################################################################################################################################################################################

auc_test=[]
acc_test=[]

for id in range(1,11):
    
    #load adjacency for random graphs.
    A = loadmat('../../data/random/5_layer_random_connection_'+ str(id) +'.mat')['mask_con']
    A = gpu(torch.tensor(A).float().t().to_sparse().coalesce() )

    A_g = loadmat('../../data/random/gene_pathway_layer_connection_and_input_data_5_random_' + str(id) + '.mat')['mask_gene_encode']
    A_g = gpu(torch.tensor(A_g).float().to_sparse().coalesce() )
    
    print(id)
    # Load Network
    checkpoint = int(loadmat('checkpoint_random_'+str(id)+'.mat')['ind'])
    net = Gene_ontology_network(A_g, A, 2, 4, [5,5,5,5], pool_dim,l_dim)
    net.load_state_dict(torch.load('model'+str(id)+'/setting/model'+str(checkpoint)+'.pth'))
    net = gpu(net)

    
    # Define the loss functions
    criterion_class = nn.BCELoss(reduction='none')
    criterion_recon = nn.MSELoss(reduction='none')

    ll_test, y_test, imp_map, ll_class, ll_test_con,ll_test_dis, ll_recon, ll_recon_con, ll_recon_dis = net_test(net, gene_test, Y_test, temperature, criterion_class, criterion_recon, lambda0)



    # For each random  we obtain a ROC for the test data.              
    fpr_test, tpr_test, thresholds_test    = metrics.roc_curve( cpu(Y_test).data.numpy()  , cpu(y_test).data.numpy(), pos_label=1)  
    
    v = {'fpr':fpr_test,'tpr':tpr_test}
    sio.savemat('fpr_tpr'+str(id)+'.mat',v)
         
    auc_test.append(metrics.auc(fpr_test, tpr_test))
                    

                
    y_te = [x>0.5 for x in cpu(y_test)]        
    accuracy_test = accuracy_score(np.array(y_te).astype(int),np.array(cpu(Y_test).squeeze().tolist()))
                          
    acc_test.append(accuracy_test)
    
    del A
    del A_g
    del net
    torch.cuda.empty_cache()

# Saving the results.
vis = {'AUC':np.array(auc_test),'ACC':acc_test}
sio.savemat('random_aucs.mat',vis)
                    
        
