import numpy as np
import copy
import matplotlib.pyplot as plt
import seaborn as sns
import h5py
import torch
import torch.nn as nn
import torch.nn.functional
import torch.utils.data as data
import numpy as np
import sys
import scnn.scnn
import scnn.chebyshev
import time
import dgl
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
from sklearn.metrics import auc
from sklearn.model_selection import KFold
from sklearn.model_selection import StratifiedKFold
import itertools
from sklearn.metrics import accuracy_score
import pickle
import json 
from sklearn.preprocessing import MinMaxScaler
# Module-level MinMaxScaler instance.
# NOTE(review): `scaler` is not referenced anywhere else in this file — confirm
# it is still needed (e.g. by code importing this module) before removing it.
scaler = MinMaxScaler()
def save_variable(variable, filename):
  """Pickle `variable` into `filename`, overwriting any existing file.

  The file is opened in a `with` block so the handle is flushed and closed
  deterministically (previously it was left open for the garbage collector).
  Returns None.
  """
  with open(filename, "wb") as f:
    pickle.dump(variable, f)
def load_variable(filename):
  """Unpickle and return the object stored in `filename`.

  The file handle is closed deterministically via `with`.
  Security: unpickling can execute arbitrary code — only load files this
  project wrote itself (e.g. with save_variable), never untrusted input.
  """
  with open(filename, 'rb') as f:
    return pickle.load(f)
# Make the current working directory importable.
# NOTE(review): this runs after the project imports at the top of the file
# (scnn.*), so it does not affect them; presumably it serves later imports
# (e.g. modules needed while unpickling the arrays below) — confirm.
sys.path.append('.')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device being used:', device)

dataset = 'proteins' #proteins/imdb_b/imdb_m/nci1/reddit_b/reddit_m
path = '/data/gc/'


def _load_concat(name):
    """Load ``<path><dataset>/<name>_concat_.npy`` with ``allow_pickle=True``.

    Every preprocessed array used by this script follows that naming scheme.
    allow_pickle is kept from the original loads (the arrays are presumably
    object arrays holding per-fold data — confirm); since unpickling can run
    arbitrary code, `path` must only point at files this project produced.
    """
    return np.load(path + dataset + '/' + name + '_concat_.npy', allow_pickle=True)


training_graphs = _load_concat('training_graphs')
# Previously built as `path++dataset`, i.e. `path + (+dataset)`: unary `+` on
# a str raises TypeError, so the script could never get past this line.
training_labels = _load_concat('training_labels')

val_boundaries = _load_concat('val_boundaries')
val_graphs = _load_concat('val_graphs')
val_labels = _load_concat('val_labels')

testing_boundaries = _load_concat('testing_boundaries')
testing_graphs = _load_concat('testing_graphs')
testing_labels = _load_concat('testing_labels')

# Simplex features x{k}_{j}: k is the simplex dimension (fed to SA_MLP.g{k}_{j});
# the meaning of feature group j is not visible in this file.
x0_0_tr = _load_concat('x0_0_tr')
x0_1_tr = _load_concat('x0_1_tr')
x0_2_tr = _load_concat('x0_2_tr')
x1_0_tr = _load_concat('x1_0_tr')
x1_1_tr = _load_concat('x1_1_tr')
x1_2_tr = _load_concat('x1_2_tr')
x2_0_tr = _load_concat('x2_0_tr')
x2_1_tr = _load_concat('x2_1_tr')
x2_2_tr = _load_concat('x2_2_tr')

x0_0_val = _load_concat('x0_0_val')
x0_1_val = _load_concat('x0_1_val')
x0_2_val = _load_concat('x0_2_val')
x1_0_val = _load_concat('x1_0_val')
x1_1_val = _load_concat('x1_1_val')
x1_2_val = _load_concat('x1_2_val')
x2_0_val = _load_concat('x2_0_val')
x2_1_val = _load_concat('x2_1_val')
x2_2_val = _load_concat('x2_2_val')

x0_0_test = _load_concat('x0_0_test')
x0_1_test = _load_concat('x0_1_test')
x0_2_test = _load_concat('x0_2_test')
x1_0_test = _load_concat('x1_0_test')
x1_1_test = _load_concat('x1_1_test')
x1_2_test = _load_concat('x1_2_test')
x2_0_test = _load_concat('x2_0_test')
x2_1_test = _load_concat('x2_1_test')
x2_2_test = _load_concat('x2_2_test')

# Stateless activation modules. L_relu and sig are shared by SA_MLP's
# nn.Sequential stacks; relu and tanh are not used elsewhere in this file.
L_relu = nn.LeakyReLU()
sig = nn.Sigmoid()
relu = nn.ReLU(inplace=False)
tanh = nn.Tanh()

class SA_MLP(nn.Module):
    """Encode nine simplex feature matrices, sum-pool each, and decode.

    There is one encoder ``g{k}_{j}`` per (simplex dimension k, feature group j).
    Each encoder is four Linear layers, each followed by the shared module-level
    LeakyReLU, and maps every row of its input to a ``d3``-dim embedding.
    The embeddings are summed over dim 0 and concatenated in the order
    g0_0, g0_1, g0_2, g1_0, ..., g2_2, giving ``9 * d3`` values. That vector
    goes through decoder ``D``, which ends in the module-level Sigmoid.

    Only d1, d2, d3, d8 and n_c are used; d4-d7 are accepted but ignored.
    """

    def __init__(self, d1, d2, d3, d4, d5, d6, d7, d8, n_c):
        super().__init__()

        def make_encoder(in_features):
            # Linear(in->d2), Linear(d2->d3), then two Linear(d3->d3), each
            # followed by the shared L_relu module.
            return nn.Sequential(
                nn.Linear(in_features, d2), L_relu,
                nn.Linear(d2, d3), L_relu,
                nn.Linear(d3, d3), L_relu,
                nn.Linear(d3, d3), L_relu,
            )

        # Encoders for simplices of dimension 0.
        self.g0_0 = make_encoder(d1)
        self.g0_1 = make_encoder(6)
        self.g0_2 = make_encoder(18)

        # Encoders for simplices of dimension 1.
        self.g1_0 = make_encoder(d1)
        self.g1_1 = make_encoder(12)
        self.g1_2 = make_encoder(39)

        # Encoders for simplices of dimension 2.
        self.g2_0 = make_encoder(d1)
        self.g2_1 = make_encoder(9)
        self.g2_2 = make_encoder(30)

        # Decoder over the nine pooled d3-dim embeddings. The sigmoid output
        # suits binary labels; nn.Softmax(dim=0) was suggested for multi-class.
        self.D = nn.Sequential(
            nn.Linear(3 * 3 * d3, d8), L_relu,
            nn.Linear(d8, d8), L_relu,
            nn.Linear(d8, d8), L_relu,
            nn.Linear(d8, n_c), sig,
        )
        # p=0.0 and never applied in forward(); kept so the attribute still exists.
        self.dropout = nn.Dropout(0.00)

    def forward(self, x0_0, x0_1, x0_2, x1_0, x1_1, x1_2, x2_0, x2_1, x2_2):
        encoders = (
            self.g0_0, self.g0_1, self.g0_2,
            self.g1_0, self.g1_1, self.g1_2,
            self.g2_0, self.g2_1, self.g2_2,
        )
        inputs = (x0_0, x0_1, x0_2, x1_0, x1_1, x1_2, x2_0, x2_1, x2_2)
        # Sum-pool each encoder's output over dim 0, keeping encoder order.
        pooled = [encoder(x).sum(0) for encoder, x in zip(encoders, inputs)]
        return self.D(torch.cat(pooled, 0))


# ---------------------------------------------------------------------------
# k-fold training / evaluation.
#
# For each fold z the model is re-initialised N_RESTARTS times (to reduce
# initialisation bias) and trained for N_EPOCHS epochs. After every training
# step, the full validation and test splits of that fold are evaluated.
# Results are nested as foldwise_*[fold] -> restarts -> epochs -> per-step
# values. Every level is appended wrapped in an extra one-element list
# (`append([values])`), exactly as before, so downstream consumers see the
# same structure.
# ---------------------------------------------------------------------------
N_FOLDS = 10     # folds (default = 10)
N_RESTARTS = 3   # re-initialisations per fold, to remove initialization bias (default = 3)
N_EPOCHS = 200
BATCH_SIZE = 64

# Feature arrays per split, in the positional order of SA_MLP.forward.
_TRAIN_FEATURES = (x0_0_tr, x0_1_tr, x0_2_tr, x1_0_tr, x1_1_tr, x1_2_tr, x2_0_tr, x2_1_tr, x2_2_tr)
_VAL_FEATURES = (x0_0_val, x0_1_val, x0_2_val, x1_0_val, x1_1_val, x1_2_val, x2_0_val, x2_1_val, x2_2_val)
_TEST_FEATURES = (x0_0_test, x0_1_test, x0_2_test, x1_0_test, x1_1_test, x1_2_test, x2_0_test, x2_1_test, x2_2_test)


def _graph_inputs(features, fold, k):
    """Return the nine float32 feature tensors (on `device`) of graph `k` in fold `fold`."""
    return tuple(torch.as_tensor(f[fold][k], dtype=torch.float32).to(device) for f in features)


def _predict(model, inputs):
    """Run `model` on each 9-tuple in `inputs`; return all outputs concatenated as a 1-D tensor."""
    if not inputs:
        return torch.empty(0, device=device)
    return torch.cat([model(*x) for x in inputs], 0).view(-1)


def _binary_metrics(y_true, scores):
    """Return (ROC-AUC, accuracy at threshold 0.5, PR-AUC) as Python floats.

    `y_true` is a 1-D array of 0/1 labels; `scores` is a 1-D tensor of sigmoid
    outputs on any device (gradients are detached here). Results are wrapped in
    float() because recent scikit-learn returns plain floats, which have no
    `.item()` (the previous code called `.item()` on them).
    """
    s = scores.detach().cpu().numpy()
    roc = float(roc_auc_score(y_true, s))
    acc = float(accuracy_score(y_true, (s > 0.5).astype(int)))
    precision, recall, _ = precision_recall_curve(y_true, s)
    return roc, acc, float(auc(recall, precision))


foldwise_training_loss, foldwise_test_loss, foldwise_training_auc_score, foldwise_val_auc_score, foldwise_test_auc_score = [], [], [], [], []
foldwise_training_auc_pr = []
foldwise_test_auc_pr = []
foldwise_val_auc_pr = []
foldwise_val_loss = []
foldwise_training_acc = []
foldwise_test_acc = []
foldwise_val_acc = []
for z in range(N_FOLDS):
    epochwise_training_loss_3, epochwise_test_loss_3, epochwise_training_auc_score_3, epochwise_val_auc_score_3, epochwise_test_auc_score_3 = [], [], [], [], []
    epochwise_training_auc_pr_3 = []
    epochwise_test_auc_pr_3 = []
    epochwise_val_auc_pr_3 = []
    epochwise_val_loss_3 = []
    epochwise_training_acc_3 = []
    epochwise_test_acc_3 = []
    epochwise_val_acc_3 = []

    # Fold-invariant data, converted once per fold instead of once per step.
    n_train = len(training_labels[z])
    y_train = np.asarray(training_labels[z], dtype=np.float32).reshape(-1)
    y_val = np.asarray(val_labels[z], dtype=np.float32).reshape(-1)
    y_test = np.asarray(testing_labels[z][0], dtype=np.float32).reshape(-1)
    y_val_t = torch.as_tensor(y_val, device=device)
    y_test_t = torch.as_tensor(y_test, device=device)
    val_inputs = [_graph_inputs(_VAL_FEATURES, z, m) for m in range(len(val_boundaries[z]))]
    test_inputs = [_graph_inputs(_TEST_FEATURES, z, m) for m in range(len(testing_boundaries[z][0]))]

    for w in range(N_RESTARTS):
        network = SA_MLP(d1=3, d2=2*32, d3=2*32, d4=2*32, d5=2*32, d6=2*32, d7=2*32, d8=2*32, n_c=1).to(device)
        # Tune hyperparameters, n_c = 1/1/3/1/1/5 (same order as the dataset list).
        learning_rate = 0.001
        optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate, weight_decay=1e-4)
        criterion = nn.BCELoss()  # binary
        # criterion = nn.CrossEntropyLoss()  # multi-class
        start = time.time()
        epochwise_training_loss, epochwise_test_loss, epochwise_training_auc_score, epochwise_val_auc_score, epochwise_test_auc_score = [], [], [], [], []
        epochwise_training_auc_pr = []
        epochwise_test_auc_pr = []
        epochwise_val_auc_pr = []
        epochwise_val_loss = []
        epochwise_training_acc = []
        epochwise_test_acc = []
        epochwise_val_acc = []
        for i in range(N_EPOCHS):
            batchwise_training_loss, batchwise_test_loss, batchwise_training_auc_score, batchwise_val_auc_score, batchwise_test_auc_score = [], [], [], [], []
            batchwise_training_auc_pr = []
            batchwise_test_auc_pr = []
            batchwise_val_auc_pr = []
            batchwise_val_loss = []
            batchwise_training_acc = []
            batchwise_test_acc = []
            batchwise_val_acc = []
            for j in range(n_train // BATCH_SIZE):
                # --- one optimisation step on a random mini-batch ---
                network.train()
                # Each step draws BATCH_SIZE distinct graphs independently of
                # earlier steps, so an "epoch" is n_train // BATCH_SIZE draws,
                # not one pass over a permutation (protocol kept unchanged).
                indices = np.random.choice(n_train, BATCH_SIZE, replace=False)
                # Reset gradients once per step; this used to run once per sample.
                optimizer.zero_grad()
                ys = _predict(network, [_graph_inputs(_TRAIN_FEATURES, z, k) for k in indices])
                y_batch = y_train[indices]
                loss = criterion(ys, torch.as_tensor(y_batch, device=device))
                roc_auc_score_tr, acc_tr, auc_precision_recall_tr = _binary_metrics(y_batch, ys)
                batchwise_training_loss.append(loss.item())
                batchwise_training_auc_score.append(roc_auc_score_tr)
                batchwise_training_auc_pr.append(auc_precision_recall_tr)
                batchwise_training_acc.append(acc_tr)
                loss.backward()
                optimizer.step()

                # --- evaluate on the full validation and test splits ---
                network.eval()
                with torch.no_grad():  # evaluation needs no autograd graph
                    ys_val = _predict(network, val_inputs)
                    ys_test = _predict(network, test_inputs)
                    l_val = criterion(ys_val, y_val_t)
                    l_test = criterion(ys_test, y_test_t)

                roc_auc_score_val, acc_val, auc_precision_recall_val = _binary_metrics(y_val, ys_val)
                batchwise_val_loss.append(l_val.item())
                batchwise_val_auc_pr.append(auc_precision_recall_val)
                batchwise_val_auc_score.append(roc_auc_score_val)
                batchwise_val_acc.append(acc_val)

                roc_auc_score_test, acc_test, auc_precision_recall_test = _binary_metrics(y_test, ys_test)
                batchwise_test_loss.append(l_test.item())
                batchwise_test_auc_score.append(roc_auc_score_test)
                batchwise_test_auc_pr.append(auc_precision_recall_test)
                batchwise_test_acc.append(acc_test)
                print("--------------------- | test_acc =%f |" % acc_test)
            epochwise_training_loss.append([batchwise_training_loss])
            epochwise_training_auc_score.append([batchwise_training_auc_score])
            epochwise_training_auc_pr.append([batchwise_training_auc_pr])
            epochwise_training_acc.append([batchwise_training_acc])
            epochwise_val_loss.append([batchwise_val_loss])
            epochwise_val_auc_pr.append([batchwise_val_auc_pr])
            epochwise_val_auc_score.append([batchwise_val_auc_score])
            epochwise_val_acc.append([batchwise_val_acc])
            epochwise_test_loss.append([batchwise_test_loss])
            epochwise_test_auc_score.append([batchwise_test_auc_score])
            epochwise_test_auc_pr.append([batchwise_test_auc_pr])
            epochwise_test_acc.append([batchwise_test_acc])
            # Replaces `timeit('process')`: `timeit` was never defined or
            # imported, so the first completed epoch raised NameError.
            print('epoch %d done, %.1f s since restart began' % (i, time.time() - start))
        epochwise_training_loss_3.append([epochwise_training_loss])
        epochwise_training_auc_score_3.append([epochwise_training_auc_score])
        epochwise_training_auc_pr_3.append([epochwise_training_auc_pr])
        epochwise_training_acc_3.append([epochwise_training_acc])
        epochwise_val_loss_3.append([epochwise_val_loss])
        epochwise_val_auc_pr_3.append([epochwise_val_auc_pr])
        epochwise_val_auc_score_3.append([epochwise_val_auc_score])
        epochwise_val_acc_3.append([epochwise_val_acc])
        epochwise_test_loss_3.append([epochwise_test_loss])
        epochwise_test_auc_score_3.append([epochwise_test_auc_score])
        epochwise_test_auc_pr_3.append([epochwise_test_auc_pr])
        epochwise_test_acc_3.append([epochwise_test_acc])
    foldwise_training_loss.append([epochwise_training_loss_3])
    foldwise_training_auc_score.append([epochwise_training_auc_score_3])
    foldwise_training_auc_pr.append([epochwise_training_auc_pr_3])
    foldwise_training_acc.append([epochwise_training_acc_3])
    foldwise_val_loss.append([epochwise_val_loss_3])
    foldwise_val_auc_pr.append([epochwise_val_auc_pr_3])
    foldwise_val_auc_score.append([epochwise_val_auc_score_3])
    foldwise_val_acc.append([epochwise_val_acc_3])
    foldwise_test_loss.append([epochwise_test_loss_3])
    foldwise_test_auc_score.append([epochwise_test_auc_score_3])
    foldwise_test_auc_pr.append([epochwise_test_auc_pr_3])
    foldwise_test_acc.append([epochwise_test_acc_3])
