import numpy as np
import copy
import matplotlib.pyplot as plt
import seaborn as sns
import h5py
import torch
import torch.nn as nn
import torch.nn.functional
import torch.utils.data as data
import numpy as np
import sys
import scnn.scnn
import scnn.chebyshev
import time
import pickle
import random
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
from sklearn.metrics import auc
from sklearn.model_selection import KFold
from sklearn.model_selection import StratifiedKFold
def save_variable(variable,filename):
  """Pickle ``variable`` to the file at ``filename``, overwriting it.

  The file is opened in binary write mode and is closed as soon as the dump
  finishes (or fails), so the handle is not leaked.
  """
  with open(filename, "wb") as handle:
    pickle.dump(variable, handle)
def load_variable(filename):
  """Return the object unpickled from the file at ``filename``.

  The file is opened in binary read mode and closed once loading finishes.

  NOTE: ``pickle.load`` can execute arbitrary code; only use this on files
  from a trusted source.
  """
  with open(filename, 'rb') as handle:
    return pickle.load(handle)
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  

dataset = 'planar' #planar/mesh
path = '/data/tp/'
# NOTE(review): no path separator is inserted between `datapath` and the file
# names below, so files resolve to e.g. '/data/tp/planaredge_attributes_planar'.
# Confirm this matches the on-disk layout.
datapath = path+dataset
X = load_variable(datapath+'edge_attributes_'+dataset)

# Random 80/20 train/test split over the first axis of X. Indices are drawn
# without replacement so that exactly n_ samples are marked for training;
# repeated random.randint draws can hit the same index several times and
# leave the training set much smaller than intended.
n_samples = X.shape[0]
n_ = int(0.8*n_samples)
train_mask = np.zeros((n_samples,))
train_mask[random.sample(range(n_samples), n_)] = 1
# The test set is the complement of the training set.
test_mask = (train_mask == 0).astype(train_mask.dtype)
X_tr = X[train_mask!=0]
X_test = X[test_mask!=0]

last_nodes = load_variable(datapath+'last_nodes_'+dataset)
target_nodes = load_variable(datapath+'target_nodes_'+dataset)
Y = load_variable(datapath+'targets_'+dataset)
B1 = load_variable(datapath+'B1_'+dataset)
B2 = load_variable(datapath+'B2_'+dataset)
G = load_variable(datapath+'G_'+dataset)

# Split the targets with the same masks, then keep only the first entry of
# each target row as the label (truncated to int for the training labels).
y_tr = np.array(Y)[train_mask!=0]
y_test = np.array(Y)[test_mask!=0]
y_tr = torch.squeeze(torch.Tensor(np.array([[int(row[0])] for row in y_tr]))) #planar
y_test = torch.squeeze(torch.Tensor(np.array([[row[0]] for row in y_test]))) #planar

#y_tr = torch.squeeze(torch.Tensor(np.array([[int(y_tr[i][0]-1)] for i in range(len(y_tr))]))) #mesh
#y_test = torch.squeeze(torch.Tensor(np.array([[y_test[i][0]-1] for i in range(len(y_test))]))) #mesh


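# Alternative data-loading path for the 'ocean' dataset. It is disabled by
# being wrapped in the string literal below; it is not executed.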
'''
dataset = 'ocean'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
datapath = path+dataset
X = np.load(datapath+'flows_in.npy')
train_mask = np.load(datapath+'train_mask.npy')
print(train_mask)
test_mask = np.load(datapath+'test_mask.npy')
X_tr = X[train_mask!=0]
X_test = X[test_mask!=0]
last_nodes = np.load(datapath+'last_nodes.npy')
target_nodes = np.load(datapath+'target_nodes.npy')
Y = np.load(datapath+'targets.npy')
y_tr = [list(np.squeeze(Y[np.where(train_mask!=0)])[i]) for i in range(len(Y[np.where(train_mask!=0)]))]
y_tr_ = len(y_tr)*[1] # 160 for buoy
print(y_tr)
print('y_tr',np.squeeze(y_tr[0]))
for i in range(len(y_tr)): 
  if y_tr[i] == [1.0, 0.0, 0.0, 0.0, 0.0, 0.0]: y_tr_[i] = [0]
  elif y_tr[i] == [0.0, 1.0, 0.0, 0.0, 0.0, 0.0]: y_tr_[i] = [1]
  elif y_tr[i] == [0.0, 0.0, 1.0, 0.0, 0.0, 0.0]: y_tr_[i] = [2]
  elif y_tr[i] == [0.0, 0.0, 0.0, 1.0, 0.0, 0.0]: y_tr_[i] = [3]
  elif y_tr[i] == [0.0, 0.0, 0.0, 0.0, 1.0, 0.0]: y_tr_[i] = [4]
  elif y_tr[i] == [0.0, 0.0, 0.0, 0.0, 0.0, 1.0]: y_tr_[i] = [5]
  
#print(y_tr_[0])  
y_tr = torch.squeeze(torch.Tensor(np.array([[int(y_tr_[i][0])] for i in range(len(y_tr_))])))
y_test = [list(np.squeeze(Y[test_mask!=0])[i]) for i in range(len(Y[test_mask!=0]))]
y_test_ = len(y_test)*[1] 
for i in range(len(y_test)): 
  if y_test[i] == [1.0, 0.0, 0.0, 0.0, 0.0, 0.0]: y_test_[i] = [0] #[1.0, 0.0, 0.0, 0.0, 0.0, 0.0] for buoy
  elif y_test[i] == [0.0, 1.0, 0.0, 0.0, 0.0, 0.0]: y_test_[i] = [1]
  elif y_test[i] == [0.0, 0.0, 1.0, 0.0, 0.0, 0.0]: y_test_[i] = [2]
  elif y_test[i] == [0.0, 0.0, 0.0, 1.0, 0.0, 0.0]: y_test_[i] = [3]
  elif y_test[i] == [0.0, 0.0, 0.0, 0.0, 1.0, 0.0]: y_test_[i] = [4]
  elif y_test[i] == [0.0, 0.0, 0.0, 0.0, 0.0, 1.0]: y_test_[i] = [5]
  
y_test = torch.squeeze(torch.Tensor(np.array([[y_test_[i][0]] for i in range(len(y_test_))])))
B1 = np.load(datapath+'B1.npy')
B2 = np.load(datapath+'B2.npy')
G = load_variable(datapath+'G_undir.pkl')
'''


# Sizes read directly from the shapes of B1 and B2. These equal the leading
# dimensions of B1@B1.T, B2@B2.T and B2.T@B2 without forming those products.
# (Presumably the node / edge / triangle counts of the complex -- confirm.)
N0 = B1.shape[0]
N1 = B2.shape[0]
N2 = B2.shape[1]

# The two operators applied to the input features; computed once and reused
# for both the training and the test features.
_upper_op = B2@B2.T
_lower_op = B1.T@B1

# Training features: the raw input plus one and two applications of the
# operators above.
x1_0 = np.squeeze(X_tr)
x1_1 = x1_0@_upper_op + x1_0@_lower_op
x1_2 = x1_1@_upper_op + x1_1@_lower_op

# Same construction for the test features.
x1_0_test = np.squeeze(X_test)
x1_1_test = x1_0_test@_upper_op + x1_0_test@_lower_op
x1_2_test = x1_1_test@_upper_op + x1_1_test@_lower_op

# One 0/1 row per sample: ones at the indices of the neighbours (in G) of
# that sample's entry in last_nodes, zeros elsewhere.
Z_ = np.zeros((len(last_nodes), B1.shape[0]))
for _row, _node in enumerate(last_nodes):
    Z_[_row, [int(j) for j in G.neighbors(_node)]] = 1
Z_tr_ = Z_[train_mask!=0]
Z_test = Z_[test_mask!=0]

# Module-level activation instances.
L_relu = nn.LeakyReLU()
sig = nn.Sigmoid()
relu = nn.ReLU(inplace=False)
tanh = nn.Tanh()
# Normalise over the last axis so that each row of a (batch, classes) input
# sums to one. dim=0 would normalise across the batch instead. For 1-D input
# the last axis is axis 0, so that case is unchanged.
softmax = nn.Softmax(dim=-1)

class SA_MLP(nn.Module):
    """Three parallel MLP branches combined by a final classification MLP.

    Each of the three inputs goes through its own MLP (d1 -> d2 -> d2 -> d2
    -> d3), is multiplied by ``B1.T``, passed through a per-branch layer
    (d4 -> d5), masked element-wise by ``Z_``, and the three results are
    concatenated and classified by ``D`` (3*d5 -> ... -> d6).

    Shape requirements implied by the operations in ``forward``:
    ``d3 == B1.shape[1]``, ``d4 == B1.shape[0]``, and ``Z_`` must broadcast
    against a ``(batch, d5)`` tensor.

    Args:
        d1: input feature size of each branch.
        d2: hidden size of each branch MLP.
        d3: output size of each branch MLP.
        d4: input size of the ``xi`` layers (size after multiplying by B1.T).
        d5: output size of the ``xi`` layers and hidden size of ``D``.
        d6: number of output classes.
    """

    def __init__(self,d1,d2,d3,d4,d5,d6):
        super(SA_MLP,self).__init__()

        def branch_mlp():
            # Same layer layout for every branch; each call creates
            # independent parameters.
            return nn.Sequential(
                nn.Linear(d1,d2), nn.Tanh(),
                nn.Linear(d2,d2), nn.Tanh(),
                nn.Linear(d2,d2), nn.Tanh(),
                nn.Linear(d2,d3), nn.Tanh(),
            )

        # Simplices of dimension 1.
        self.g1_0 = branch_mlp()
        self.g1_1 = branch_mlp()
        self.g1_2 = branch_mlp()

        # Per-branch layers applied after the multiplication by B1.T.
        # forward() used these without their ever being created; there is
        # one per branch so the third branch no longer reuses xi1.
        # NOTE(review): the intended depth of these layers is not recoverable
        # from the original code; a single Linear + Tanh is assumed.
        self.xi0 = nn.Sequential(nn.Linear(d4,d5), nn.Tanh())
        self.xi1 = nn.Sequential(nn.Linear(d4,d5), nn.Tanh())
        self.xi2 = nn.Sequential(nn.Linear(d4,d5), nn.Tanh())

        # Final classifier. Softmax runs over the class axis (last axis) so
        # every sample gets its own distribution.
        # NOTE(review): nn.CrossEntropyLoss applies log-softmax itself; if
        # that loss is used, consider returning raw scores instead.
        self.D = nn.Sequential(
            nn.Linear(3*d5,d5), nn.Tanh(),
            nn.Linear(d5,d5), nn.Tanh(),
            nn.Linear(d5,d5), nn.Tanh(),
            nn.Linear(d5,d6), nn.Softmax(dim=-1),
        )

    def forward(self, x1_0, x1_1, x1_2, B1, Z_):
        """Return per-sample class probabilities of shape ``(batch, d6)``.

        Args:
            x1_0, x1_1, x1_2: the three branch inputs, each ``(batch, d1)``.
            B1: 2-D tensor whose transpose right-multiplies the branch outputs.
            Z_: mask multiplied element-wise with each ``xi`` output.
        """
        out1_1 = self.g1_0(x1_0)
        out1_2 = self.g1_1(x1_1)
        out1_3 = self.g1_2(x1_2)

        B1_T = B1.T
        xi_out0 = self.xi0(out1_1@B1_T)
        xi_out1 = self.xi1(out1_2@B1_T)
        xi_out2 = self.xi2(out1_3@B1_T)

        # Put the mask on the same device and dtype as the activations so
        # the class does not depend on a module-level `device` variable.
        mask = Z_.to(device=xi_out0.device, dtype=xi_out0.dtype)
        xi_out = torch.cat((xi_out0*mask, xi_out1*mask, xi_out2*mask), 1)
        return self.D(xi_out)


# Seed NumPy's global RNG.
np.random.seed(1)

# One index per training sample, split into 5 stratified folds on y_tr.
indices_all = np.array(range(len(y_tr)))
kf = StratifiedKFold(n_splits=5)
kf.get_n_splits(indices_all,y_tr)

# Per-fold histories; each entry appended later is one fold's per-epoch list.
foldwise_training_loss, foldwise_val_loss, foldwise_test_loss = [], [], []
foldwise_training_acc, foldwise_test_acc, foldwise_val_acc = [], [], []
toprint = []
def evaluate(logits,labels):
    """Return the accuracy of the row-wise argmax of ``logits`` against ``labels``.

    Each element obtained by iterating over ``logits`` is reduced to the
    index of its largest entry; the resulting list is scored against
    ``labels`` with ``accuracy_score``.
    """
    predicted = []
    for row in logits:
        predicted.append(torch.argmax(row).item())
    return accuracy_score(labels, predicted)

def _as_device_tensor(array):
  """Convert an array-like to a default-dtype tensor on the selected device."""
  return torch.Tensor(array).to(device)


# Train one fresh network per fold. The held-out part of each fold is used
# as the validation set; the separate test split is evaluated every epoch.
for train_index, test_index in kf.split(indices_all,y_tr):
  test_auc_pr_kfold = []
  network = SA_MLP(d1=(X_tr.shape)[1],d2=(X_tr.shape)[1],d3=(X_tr.shape)[1],d4=(B1.shape)[0],d5=(B1.shape)[0],d6=17).to(device) #d6 = planar-17/mesh-7/ocean-6/syn-13
  learning_rate = 1e-5
  optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate, weight_decay=1e-4)
  criterion = nn.CrossEntropyLoss()
  start = time.time()
  epochwise_training_loss, epochwise_test_loss, epochwise_training_acc  = [],[],[]
  epochwise_test_acc = []
  epochwise_val_acc = []
  epochwise_val_loss = []

  # Tensors that do not change within a fold are built once here instead of
  # on every mini-batch.
  B1_dev = _as_device_tensor(B1)
  indices_val = test_index
  Z_val = Z_tr_[indices_val]
  x1_0_val = x1_0[indices_val]
  x1_1_val = x1_1[indices_val]
  x1_2_val = x1_2[indices_val]
  val_inputs = (_as_device_tensor(x1_0_val), _as_device_tensor(x1_1_val),
                _as_device_tensor(x1_2_val), B1_dev, _as_device_tensor(Z_val))
  y_val = y_tr[indices_val]
  test_inputs = (_as_device_tensor(x1_0_test), _as_device_tensor(x1_1_test),
                 _as_device_tensor(x1_2_test), B1_dev, _as_device_tensor(Z_test))

  for i in range(0, 1500):
    # eval() is called at the end of every epoch; switch back before training.
    network.train()
    for j in range(0,len(indices_all)//8):
      optimizer.zero_grad()
      # Mini-batch of 8 distinct samples drawn from this fold's training part.
      indices = np.random.choice(train_index,8,replace=False)
      Z_tr = Z_tr_[indices]
      x1_0_tr = x1_0[indices]
      x1_1_tr = x1_1[indices]
      x1_2_tr = x1_2[indices]
      batch_labels = y_tr[indices]
      ys = network(_as_device_tensor(x1_0_tr), _as_device_tensor(x1_1_tr),
                   _as_device_tensor(x1_2_tr), B1_dev, _as_device_tensor(Z_tr))
      acc_tr = evaluate(ys.detach().cpu(), batch_labels)
      loss = criterion(ys, batch_labels.long().to(device))
      loss.backward()
      optimizer.step()

    # Only the last mini-batch of the epoch is recorded, as before.
    epochwise_training_loss.append(loss.item())
    epochwise_training_acc.append(acc_tr)
    print ("-----------epoch = %d | training_loss = %f |"%(i,loss.item()))
    # float() accepts both a NumPy scalar and a plain Python float.
    print ("--------------------- | acc-tr =%f |"%(float(acc_tr)))
    network.eval()

    # No gradients are needed for validation/test.
    with torch.no_grad():
      #validation
      ys_val = network(*val_inputs)
      acc_val = evaluate(ys_val.cpu(), y_val)
      l_val = criterion(ys_val, y_val.long().to(device))
      epochwise_val_loss.append(l_val.item())
      epochwise_val_acc.append(acc_val)
      print ("--------------------- | acc-val = %f |"%(float(acc_val)))

      ys_test = network(*test_inputs)
      acc_test = evaluate(ys_test.cpu(), y_test)
      l = criterion(ys_test, y_test.long().to(device))
      epochwise_test_loss.append(l.item())
      epochwise_test_acc.append(acc_test)
      print ("--------------------- | acc_test = %f |"%(float(acc_test)))

  # The original called an undefined `timeit`; report the fold's wall-clock
  # time from the `start` timestamp taken above instead.
  print ("process: %f s"%(time.time()-start))

  foldwise_training_loss.append(epochwise_training_loss)
  foldwise_training_acc.append(epochwise_training_acc)
  foldwise_val_loss.append(epochwise_val_loss)
  foldwise_val_acc.append(epochwise_val_acc)
  foldwise_test_loss.append(epochwise_test_loss)
  foldwise_test_acc.append(epochwise_test_acc)
