import numpy as np
import copy
import matplotlib.pyplot as plt
import seaborn as sns
import h5py
import torch
import torch.nn as nn
import torch.nn.functional
import torch.utils.data as data
import numpy as np
import sys
import scnn.scnn
import scnn.chebyshev
import time
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
from sklearn.metrics import auc
from sklearn.model_selection import KFold
#from random import sample
#from random import choices
# NOTE(review): this runs after the scnn imports above, so it does not affect them.
sys.path.append('.')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
path = '/data/scp/'
data_tri = 'email_enron' #email_enron/contact_school/school_primary/ndc_classes/tags_math
# NOTE(review): this rebinds ``data``, shadowing the ``torch.utils.data`` alias
# imported above. Later code uses ``data`` as this dataset-name string.
data = 'email_enron' #email_enron/contact-high-school/contact-primary-school/NDC-classes/tags-math-sx

# The previous version of this line was missing the opening quote of the path
# string (a SyntaxError).
# NOTE(review): the "triangles/" directory under ``path`` is a reconstruction.
# Confirm where the *_80_100.jld files actually live.
g = h5py.File(path+"triangles/"+data_tri+"/"+data_tri+"_80_100.jld", "r")

# Read the triangle table once. Columns 0-2 go to training_triangles and
# column 3 to y_train_true (compared against 0/1 below).
triangle_table = np.array(list(g[data_tri+'_80_100']))
training_triangles = triangle_table[:,0:3]
training_triangles_ = training_triangles-1  # presumably 1-based -> 0-based ids; verify

# NOTE(review): the "_train" arrays are loaded from the "*_test_*" files.
# Confirm this is intended.
one_simplicies_order_train = np.load(path+'one_simplices_order_test_'+data+'.npy')
zero_simplicies_order_train = np.load(path+'zero_simplices_order_test_'+data+'.npy')

y_train_true = triangle_table[:,3]
np.random.seed(1)
# NOTE(review): indices_1 and indices_0 are not used by the visible code.
# indices_1 draws len(y_train_true) positives *with* replacement.
indices_1 = np.random.choice(np.squeeze(np.where(y_train_true==1)),len((y_train_true==1)))
indices_0 = np.where(y_train_true == 0)

# Activation modules. Only L_relu is used (by SA_MLP) in the visible code.
L_relu = nn.LeakyReLU()
sig = nn.Sigmoid()
relu = nn.ReLU(inplace=False)
tanh = nn.Tanh()

class SA_MLP(nn.Module):
    """Score candidate triangles from per-node and per-edge features.

    Four encoders map simplex features to ``d2``-wide embeddings:
    ``g0_0``/``g0_1`` for 0-simplices and ``g1_0``/``g1_1`` for 1-simplices.
    ``forward`` concatenates the three node rows and the three edge rows that
    belong to each triangle, then passes the result through the decoder ``D``.
    Every sub-network is moved to the module-level ``device``.

    Args:
        d1: Input width unit. ``x0_0``/``x1_0`` have ``d1`` columns, ``x0_1``
            has ``2*d1`` and ``x1_1`` has ``4*d1``.
        d2: Hidden width of every encoder.
        d3: Unused. Kept so existing call sites keep working.
        d4: Hidden width of the decoder.
        d5: Output width of the decoder (1 logit per triangle in this script).
    """

    @staticmethod
    def _mlp_layers(in_dim, width, n_blocks):
        """Return ``n_blocks`` repetitions of Linear -> BatchNorm1d -> L_relu.

        The first Linear maps ``in_dim`` to ``width``. The remaining ones map
        ``width`` to ``width``. Layers are created in the same order as the
        hand-written Sequentials they replace, so parameter initialisation and
        ``state_dict`` keys are unchanged.
        """
        layers = []
        for k in range(n_blocks):
            layers += [nn.Linear(in_dim if k == 0 else width, width),
                       nn.BatchNorm1d(width),
                       L_relu]
        return layers

    def __init__(self, d1, d2, d3, d4, d5):
        super(SA_MLP, self).__init__()
        mlp = SA_MLP._mlp_layers

        # Simplices of dimension 0.
        self.g0_0 = nn.Sequential(*mlp(d1, d2, 4)).to(device)
        self.g0_1 = nn.Sequential(*mlp(2 * d1, d2, 4)).to(device)
        # NOTE(review): g0_2 and g1_2 are never used in forward(). They are kept
        # so the parameter set and checkpoint keys stay the same.
        self.g0_2 = nn.Sequential(*mlp(6 * d1, d2, 2)).to(device)

        # Simplices of dimension 1.
        self.g1_0 = nn.Sequential(*mlp(d1, d2, 4)).to(device)
        self.g1_1 = nn.Sequential(*mlp(4 * d1, d2, 4)).to(device)
        self.g1_2 = nn.Sequential(*mlp(14 * d1, d2, 4)).to(device)

        # Decoder. Its input is 3 node rows + 3 edge rows, each the
        # concatenation of two d2-wide embeddings: 2*2*3*d2 features.
        self.D = nn.Sequential(*mlp(2 * 2 * 3 * d2, d4, 3),
                               nn.Linear(d4, d5),
                               nn.BatchNorm1d(d5)).to(device)

    def forward(self, x0_0, x0_1, x1_0, x1_1, T_0, T_1):
        """Return the decoder output, one row per triangle.

        Rows ``3t, 3t+1, 3t+2`` of the node inputs (``x0_0``, ``x0_1``) and of
        the edge inputs (``x1_0``, ``x1_1``) belong to triangle ``t``.
        The number of triangles is ``len(T_0) // 3``. The contents of ``T_0``
        and ``T_1`` are not used.

        Returns:
            Tensor of shape ``(len(T_0) // 3, d5)``.
        """
        n_tri = len(T_0) // 3

        out0_1 = self.g0_0(x0_0)
        out0_2 = self.g0_1(x0_1)
        out1_1 = self.g1_0(x1_0)
        out1_2 = self.g1_1(x1_1)

        def per_triangle(emb):
            # A row-major reshape concatenates the three consecutive rows of
            # each triangle, in the same order a per-row torch.cat would.
            # This covers *every* triangle: the previous loop over
            # range(0, len(T_0)-3, 3) skipped the last one and left its row zero.
            return emb[: 3 * n_tri].reshape(n_tri, 3 * emb.shape[1])

        xi_in0 = torch.cat((per_triangle(out0_1), per_triangle(out0_2)), 1)
        xi_in1 = torch.cat((per_triangle(out1_1), per_triangle(out1_2)), 1)

        phi_in = torch.cat((xi_in0, xi_in1), 1)
        return self.D(phi_in)

topdim = 3  # boundary matrices 0..topdim (four of them) are used below
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects.
# Only load this file from a trusted source.
boundaries = np.load(path+'boundary_matrices_tests_'+data+'.npy',allow_pickle=True)

cBs = [boundaries[k].transpose() for k in range(topdim + 1)]
Bs = [boundaries[k] for k in range(topdim + 1)]
ups = [boundaries[k] @ boundaries[k].transpose() for k in range(topdim + 1)]
downs = [boundaries[k].transpose() @ boundaries[k] for k in range(topdim + 1)]

# Operator sizes. Reading .shape avoids materialising each sparse matrix as a
# dense one just to count rows: len(M.todense()) == M.shape[0].
N0 = ups[0].shape[0]
N1 = ups[1].shape[0]
N2 = ups[2].shape[0]
N3 = ups[3].shape[0]
N4 = boundaries[3].shape[1]  # == len(boundaries[3].todense().T)

# First-order features: |M| @ ones gives the row sums of |M|.
x0_1 = abs(ups[0])@np.ones((N0,1))
x1_1 = abs(ups[1])@np.ones((N1,1)) + abs(downs[0])@np.ones((N1,1))
x2_1 = abs(ups[2])@np.ones((N2,1)) + abs(downs[1])@np.ones((N2,1))
x3_1 = abs(ups[3])@np.ones((N3,1)) + abs(downs[2])@np.ones((N3,1))
# Second-order features: the first-order features propagated once more
# through the same operators, stacked column-wise.
x0_2 = np.concatenate((abs(ups[0])@x0_1 , abs(Bs[0])@x1_1),axis=1)
x1_2 = np.concatenate((abs(ups[1])@x1_1 , abs(downs[0])@x1_1 , abs(Bs[1])@x2_1 , abs(cBs[0])@x0_1),axis=1)
x2_2 = np.concatenate((abs(ups[2])@x2_1 , abs(downs[1])@x2_1 , abs(Bs[2])@x3_1 , abs(cBs[1])@x1_1),axis=1)
#x0_3 = np.concatenate(((abs(ups[0])@x0_2 , abs(Bs[0])@x1_2)),axis=1)
#x1_3 = np.concatenate((abs(ups[1])@x1_2 , abs(downs[0])@x1_2 , abs(Bs[1])@x2_2 , abs(cBs[0])@x0_2),axis=1)

# Per-batch metric histories. The training loop appends to the loss and
# auc-pr lists. The *_auc_score and *_average_precision lists are not
# appended to by the visible code.
batchwise_training_loss, batchwise_test_loss, batchwise_training_auc_score, batchwise_test_auc_score = [], [], [], []
batchwise_training_auc_pr = []
batchwise_test_auc_pr = []
batchwise_val_auc_pr = []
batchwise_val_loss = []
batchwise_training_average_precision, batchwise_test_average_precision = [], []


def _split_60_20_20(pool):
    """Split ``pool`` into disjoint train (60%), test (20%) and val (rest) parts.

    NumPy's global RNG is reseeded with 1 before every draw, so the split is
    reproducible. Returns ``(train, test, val, rest)``, where ``rest`` is the
    list that val was drawn from (``pool`` minus train and test).
    """
    n = len(pool)
    n_train = int(0.6 * n)
    n_test = int(0.8 * n) - n_train
    n_val = n - int(0.8 * n)

    np.random.seed(1)
    train = np.random.choice(pool, n_train, replace=False)
    # Set membership keeps each filter linear. ``i not in ndarray`` rescanned
    # the whole sample for every element.
    taken = set(train)
    rest = [i for i in pool if i not in taken]

    np.random.seed(1)
    test = np.random.choice(rest, n_test, replace=False)
    taken = set(test)
    rest = [i for i in rest if i not in taken]

    np.random.seed(1)
    val = np.random.choice(rest, n_val, replace=False)
    return train, test, val, rest


y0 = np.squeeze(np.where(y_train_true==0))
y1 = np.squeeze(np.where(y_train_true==1))
indices_0_tr, indices_0_test, indices_0_val, y0 = _split_60_20_20(y0)
indices_1_tr, indices_1_test, indices_1_val, y1 = _split_60_20_20(y1)
# Fixes the RNG state used for batch sampling in the training loop.
np.random.seed(1)
indices_test = np.append(indices_0_test,indices_1_test)
indices_val = np.append(indices_0_val,indices_1_val)
indices_0_train = indices_0_tr
indices_1_train = indices_1_tr

def _simplex_indices(triangle_ids):
    """Collect node and edge row indices for a sequence of triangle ids.

    For every id ``t``, the first three entries of
    ``zero_simplicies_order_train[t]`` and of ``one_simplicies_order_train[t]``
    are appended as ints. Consecutive triples in each returned list therefore
    belong to the same triangle, which is the layout ``SA_MLP.forward`` expects.
    """
    zero_ind, one_ind = [], []
    for t in triangle_ids:
        zero_row = zero_simplicies_order_train[t]
        one_row = one_simplicies_order_train[t]
        zero_ind.extend(int(zero_row[k]) for k in range(3))
        one_ind.extend(int(one_row[k]) for k in range(3))
    return zero_ind, one_ind


def _predict(model, triangle_ids):
    """Run ``model`` on ``triangle_ids``; return squeezed float32 logits on the CPU."""
    zero_ind, one_ind = _simplex_indices(triangle_ids)
    ys = model(torch.Tensor(x0_1[zero_ind, :]).to(device),
               torch.Tensor(x0_2[zero_ind, :]).to(device),
               torch.Tensor(x1_1[one_ind, :]).to(device),
               torch.Tensor(x1_2[one_ind, :]).to(device),
               torch.Tensor(zero_ind).to(device),
               torch.Tensor(one_ind).to(device))
    return torch.squeeze(ys).type(torch.FloatTensor)


def _labels(triangle_ids):
    """Return ``y_train_true[triangle_ids]`` (column 3 of the triangle table) as float32."""
    # Uses the labels already loaded for data_tri. The previous code re-read the
    # hard-coded 'school_primary_80_100' dataset on every call.
    return torch.as_tensor(y_train_true[triangle_ids]).type(torch.FloatTensor)


def _pr_auc(labels, logits):
    """Return the area under the precision-recall curve of ``logits`` vs ``labels``."""
    precision, recall, _ = precision_recall_curve(labels.numpy(), logits.detach().numpy())
    return auc(recall, precision)


for w in range(1):
    test_auc_pr_kfold = []

    network = SA_MLP(d1=1, d2=32, d3=32, d4=2 * 32, d5=1)
    learning_rate = 1e-2
    # This line was previously commented out, which left ``optimizer``
    # undefined at the first ``optimizer.zero_grad()``.
    optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate, weight_decay=1e-3)
    # To resume from a checkpoint instead:
    #checkpoint = torch.load('model_path')
    #network.load_state_dict(checkpoint['model_state_dict'])
    #optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    #epoch = checkpoint['epoch']
    #loss = checkpoint['loss']
    # pos_weight up-weights positives. Each batch has 3 negatives per positive.
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([3.0]))
    start = time.time()
    for i in range(0, 10):
        for j in range(0, len(indices_0_train) // 32):
            # eval() is set for validation/testing at the end of every batch,
            # so switch BatchNorm back to training mode before each step.
            network.train()
            optimizer.zero_grad()

            # 24 negatives + 8 positives, each drawn without replacement.
            indices = np.append(np.random.choice(indices_0_train, 3 * int(32 / 4), replace=False),
                                np.random.choice(indices_1_train, int(32 / 4), replace=False))
            ys = _predict(network, indices)
            y_batch = _labels(indices)

            auc_precision_recall_tr = _pr_auc(y_batch, ys)
            loss = criterion(ys, y_batch)
            batchwise_training_loss.append(loss.item())
            batchwise_training_auc_pr.append(auc_precision_recall_tr)
            print("-----------epoch = %d | training_loss = %f |" % (i, loss.item()))
            print("--------------------- | auc-pr =%f |" % float(auc_precision_recall_tr))

            loss.backward()
            optimizer.step()

            network.eval()
            # No gradients are needed for evaluation; skip building the graph.
            with torch.no_grad():
                #validation
                ys_val = _predict(network, indices_val)
                y_val = _labels(indices_val)
                l_val = criterion(ys_val, y_val)
                batchwise_val_loss.append(l_val.item())
                auc_precision_recall_val = _pr_auc(y_val, ys_val)
                batchwise_val_auc_pr.append(auc_precision_recall_val)
                print("--------------------- | auc-pr_val = %f |" % float(auc_precision_recall_val))

                #testing
                ys_test = _predict(network, indices_test)
                y_test = _labels(indices_test)
                l = criterion(ys_test, y_test)
                batchwise_test_loss.append(l.item())
                auc_precision_recall_test = _pr_auc(y_test, ys_test)
                # Previously appended to the undefined name ``epochwise_test_auc_pr``.
                batchwise_test_auc_pr.append(auc_precision_recall_test)
                print("--------------------- | auc-pr_test = %f |" % float(auc_precision_recall_test))
                test_auc_pr_kfold.append(auc_precision_recall_test)
 
#torch.save({
#            'epoch': i,
#            'model_state_dict': network.state_dict(),
#            'optimizer_state_dict': optimizer.state_dict(),
#            'loss': loss,
#            }, 'model_path')            



