
import sys
import torch
import torch.nn as nn
import torch.nn.functional as Fun
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import matplotlib
import math
import pandas as pd
import scipy.io as sio
from sklearn.metrics import accuracy_score
from scipy.io import loadmat
import shutil
import os

from net_train import net_train

from net_test import net_test

from model_random import Gene_ontology_network


from convert_to_gpu import gpu
from convert_to_gpu_and_tensor import gpu_t
from convert_to_gpu_scalar import gpu_ts
from convert_to_cpu import cpu
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from sklearn import metrics
plt.ioff()


# Fixed fold index. A command-line variant, math.floor(int(sys.argv[1]) % 10),
# existed here but is disabled.
id_fold = 0

# Bookkeeping and hyper-parameters. Of these, only l_dim is referenced in the
# code visible here (last argument to Gene_ontology_network).
e_losses = []
l_dim = 50
num_epochs = 501
num_s = 0
num_f = 5
iter_n = 20

# Scalar settings wrapped by the project's gpu_ts helper; temperature and
# lambda0 are passed to net_test below.
temperature = gpu_ts(0.1)
lambda0 = gpu_ts(0.00001)
prob_ref = [gpu_ts(0.01)]

# Preferred torch device: CUDA when available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

######################################################################################################################################################################################

# Load the training split used for this evaluation. The gene inputs, class
# labels and pooling sizes all live in the same .mat file, so read it once
# instead of three times.
TRAIN_MAT_PATH = '/mnt/sdb1/sayan/pathway_analysis/data/train_test_data_gcn_5_layers_joint_model.mat'
train_mat = loadmat(TRAIN_MAT_PATH)

# Gene inputs as a float32 tensor, moved via the project's gpu() helper.
gene_train_l1 = train_mat['gene_train']
gene_train    = gpu(torch.tensor(gene_train_l1).float())

# Class labels as a float32 tensor, moved via the project's gpu() helper.
Y_train = gpu(torch.tensor(train_mat['class_train']).float())

# Pooling sizes passed to Gene_ontology_network, as a Python list (ndarray.tolist()).
pool_dim = train_mat['pool_n'].tolist()

# Only the converted arrays above are needed from here on.
del train_mat

############################################################################################################################################################################################################################################################################################################################################################################

# Evaluate the 10 randomly-connected networks on the training split. For each
# run: build the network from that run's random masks, restore the checkpoint
# index stored in checkpoint_random_<run>.mat, predict in mini-batches, then
# save the ROC curve and record the AUC. All AUCs are saved at the end.
RANDOM_DATA_DIR = '/mnt/sdb1/sayan/pathway_analysis/data/random/'

auc_train = []

# Loop-invariant setup, hoisted out of the per-run loop.
criterion_class = nn.BCELoss(reduction='none')
criterion_recon = nn.MSELoss(reduction='none')
batch_size = 257
y_true = cpu(Y_train).data.numpy()  # ground-truth labels, copied to host once

for run_id in range(1, 11):  # run_id avoids shadowing the builtin id()
    # Per-run random connectivity masks, converted to coalesced sparse tensors.
    A = loadmat(f'{RANDOM_DATA_DIR}5_layer_random_connection_{run_id}.mat')['mask_con']
    A = gpu(torch.tensor(A).float().t().to_sparse().coalesce())

    A_g = loadmat(f'{RANDOM_DATA_DIR}gene_pathway_layer_connection_and_input_data_5_random_{run_id}.mat')['mask_gene_encode']
    A_g = gpu(torch.tensor(A_g).float().to_sparse().coalesce())

    print(run_id)

    # loadmat returns at least 2-D arrays, so 'ind' is an ndim>0 array;
    # .item() extracts the scalar (int() on such arrays is deprecated since NumPy 1.25).
    checkpoint = int(loadmat(f'checkpoint_random_{run_id}.mat')['ind'].item())
    net = Gene_ontology_network(A_g, A, 2, 4, [5, 5, 5, 5], pool_dim, l_dim)
    # map_location='cpu' lets checkpoints saved on a GPU load on CPU-only hosts;
    # load_state_dict copies into the parameters and gpu() moves the net afterwards.
    state = torch.load(f'model{run_id}/setting/model{checkpoint}.pth', map_location='cpu')
    net.load_state_dict(state)
    net = gpu(net)

    # NOTE(review): no torch.no_grad()/net.eval() is set here; confirm net_test
    # handles inference mode itself (imp_map may rely on gradients).
    pred = []
    for beg_i in range(0, gene_train.size(0), batch_size):
        rows = slice(beg_i, beg_i + batch_size)
        # net_test returns 9 values; only the 2nd (per-sample scores) is used here.
        _, y_tr, *_ = net_test(net, gene_train[rows, :], Y_train[rows, :], temperature,
                               criterion_class, criterion_recon, lambda0)
        pred.append(y_tr)

    y_train = np.concatenate(pred, 0)

    fpr_train, tpr_train, _ = metrics.roc_curve(y_true, y_train, pos_label=1)
    sio.savemat(f'fpr_tpr_training{run_id}.mat', {'fpr': fpr_train, 'tpr': tpr_train})
    auc_train.append(metrics.auc(fpr_train, tpr_train))

    # Drop per-run references so cached GPU memory can be released before the next run.
    del A, A_g, net, state
    torch.cuda.empty_cache()

sio.savemat('random_aucs_train.mat', {'AUC': np.array(auc_train)})
                    
        
