#!/usr/bin/env python
# coding: utf-8

# In[ ]:


import sys
sys.path.append('./codes/')
import time
from config import args
# Root directory for this experiment's outputs. Per-iteration plots,
# checkpoints and LOG.txt are written under save_map + str(iteration) + "/".
save_map = "LISA_TEST_LOGS/TREE_GRID/"

# Experiment-specific overrides of the defaults in config.args.
# NOTE(review): these run before the project modules below are imported,
# presumably so that any module reading `args` at import time sees the
# overridden values -- confirm before reordering.
# Dataset name: selects ./dataset/<dataset>.pkl, the node set and plot colours.
args.dataset='syn4'
# Learning rate of the Adam optimizer that trains the explainer.
args.elr = 0.001
# Number of explainer training epochs.
args.eepochs = 30
# The three values below are not read in this file; presumably they are
# consumed by the Explainer (loss coefficients / budget) -- TODO confirm.
args.coff_size = 0.01
args.budget = -1.0
args.coff_ent = 1.0

# import tensorflow as tf
from utils import *
from models import GCN2 as GCN
from metrics import *
import numpy as np
from Extractor import Extractor
from Explainer import Explainer
from scipy.sparse import coo_matrix,csr_matrix
import networkx as nx
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score

import torch
import torch.optim


def plot(node,label, iteration):
    """Draw the explanation subgraph for one node and save it as a PNG.

    The edge mask is read from the module-level ``explainer.masked_adj``, so
    the explainer must already have been run on the subgraph being plotted.
    The figure is written to
    ``save_map + str(iteration) + "/" + str(node) + ".png"`` and the current
    matplotlib figure is cleared afterwards.

    Args:
        node: Identifier of the explained node; only used in the file name.
        label: Integer class label per subgraph node, indexed by the
            subgraph-local node id; used to pick node colours.
        iteration: Run index; selects the output sub-directory.
    """
    after_adj_dense = explainer.masked_adj.cpu().detach().numpy()
    after_adj = coo_matrix(after_adj_dense)
    edge_weights = after_adj.data
    n_edges = edge_weights.shape[0]

    pos_edges = []      # highlighted edges: the (up to) 24 largest mask weights
    filter_edges = []   # edges kept for the layout graph
    if n_edges > 0:  # an all-zero mask has no thresholds to compute
        sorted_edge_weights = np.sort(edge_weights)
        thres_index = max(n_edges - 24, 0)
        thres = sorted_edge_weights[thres_index]
        # Keep at most the upper half of the edges, capped at 100, but never
        # fewer than the highlighted ones.
        filter_thres_index = min(thres_index,
                                 max(int(n_edges - n_edges / 2), n_edges - 100))
        filter_thres = sorted_edge_weights[filter_thres_index]

        for r, c, d in zip(after_adj.row, after_adj.col, edge_weights):
            edge = (int(r), int(c))
            if d >= thres:
                pos_edges.append(edge)
            if d > filter_thres:
                filter_edges.append(edge)

    # Size the graph from the mask being plotted. The original read the
    # module-level `sub_adj`, i.e. whichever subgraph the main loop built last.
    num_nodes = after_adj_dense.shape[0]
    G = nx.Graph()
    G.add_nodes_from(range(num_nodes))
    G.add_edges_from(filter_edges)

    # Only the connected component containing node 0 is drawn.
    for cc in nx.connected_components(G):
        if 0 in cc:
            G = G.subgraph(cc).copy()
            break

    pos_edges = [(u, v) for (u, v) in pos_edges if u in G and v in G]
    pos = nx.kamada_kawai_layout(G)

    colors = ['orange', 'red', 'green', 'blue', 'maroon', 'brown', 'darkslategray', 'paleturquoise', 'darksalmon',
              'slategray', 'mediumseagreen', 'mediumblue', 'orchid', ]
    if args.dataset=='syn3':
        colors = ['orange', 'blue']
    elif args.dataset=='syn4':
        colors = ['orange', 'black','black','black','blue']

    # nodes, grouped by label so each class is drawn in its own colour
    labels = label
    max_label = np.max(labels)+1
    nmb_nodes = after_adj_dense.shape[0]
    label2nodes = [[] for _ in range(max_label)]
    for n in range(nmb_nodes):
        if n in G:
            label2nodes[labels[n]].append(n)

    for i in range(max_label):
        nx.draw_networkx_nodes(G, pos,
                               nodelist=label2nodes[i],
                               node_color=colors[i % len(colors)],
                               node_size=500)

    # Node 0 is drawn again, larger, on top of the others; the colour index
    # is wrapped like the per-class colours above.
    nx.draw_networkx_nodes(G, pos,
                           nodelist=[0],
                           node_color=colors[labels[0] % len(colors)],
                           node_size=1000)

    # All kept edges in grey, then the highlighted edges over them.
    nx.draw_networkx_edges(G, pos, width=7, alpha=0.5, edge_color='grey')
    nx.draw_networkx_edges(G, pos,
                           edgelist=pos_edges,
                           width=7, alpha=0.5)

    plt.axis('off')
    plt.savefig(save_map + str(iteration) + "/" + str(node) + ".png")
    plt.clf()


# In[ ]:



# Running ground-truth / score buffers shared across acc() calls; callers
# rebind them to fresh lists before each evaluation pass.
reals = []
preds = []
def acc(sub_adj,sub_edge_label):
    """Score the current explainer mask against ground-truth edge labels.

    For every (row, col) entry of ``sub_adj`` the ground truth is 1 when the
    edge is labelled in either direction, and the prediction is the sum of
    the two directed mask weights read from ``explainer.masked_adj``. The
    per-edge values are also appended to the module-level ``reals`` and
    ``preds`` lists.

    Returns:
        The ROC-AUC for this subgraph, or -1 when either the ground truth or
        the predictions take a single distinct value.
    """
    dense_labels = sub_edge_label.todense()
    mask = explainer.masked_adj.cpu().detach().numpy()

    truth = []
    scores = []
    for row, col in zip(sub_adj.row, sub_adj.col):
        labelled = dense_labels[row, col] + dense_labels[col, row]
        truth.append(1 if labelled != 0 else 0)
        scores.append(mask[row][col] + mask[col][row])

    reals.extend(truth)
    preds.extend(scores)

    degenerate = len(np.unique(truth)) == 1 or len(np.unique(scores)) == 1
    if degenerate:
        return -1
    return roc_auc_score(truth, scores)


# In[ ]:


def train(iteration):
    """Train the explainer and checkpoint its best and last weights.

    Each epoch accumulates the explainer loss over every prepared subgraph,
    takes a single optimizer step with value-clipped gradients, then
    re-evaluates all nodes via ``explain_test`` and computes one ROC-AUC over
    the pooled ``reals`` / ``preds``. Whenever that AUC improves, the weights
    are saved as the BEST checkpoint; after the last epoch the weights are
    saved as the LAST checkpoint. Both are written to ``model_weights/`` and
    to ``save_map + str(iteration) + "/"``.

    Relies on module-level state prepared by the main script: ``model``,
    ``explainer``, ``optimizer``, ``clip_value_max``, ``allnodes`` and the
    ``sub_*`` lists.

    Args:
        iteration: Run index; selects the checkpoint sub-directory and is
            forwarded to ``explain_test``.
    """
    global reals
    global preds

    epochs = args.eepochs
    model.eval()
    explainer.train()
    best_auc = 0
    for _ in range(epochs):
        loss = 0
        # Fixed temperature. The original also computed an annealed value
        # from args.coff_t0 / args.coff_te, but immediately overwrote it with
        # this constant, so that dead computation has been removed.
        tmp = 5.0

        for i in range(len(allnodes)):
            # The GCN prediction is the target the explanation must reproduce.
            with torch.no_grad():
                output = model((sub_features[i],sub_support_tensors[i]), training=False)
            pred_label = torch.argmax(output, 1)

            nodeid = 0
            pred = explainer((sub_features[i],sub_adjs[i],nodeid,sub_embeds[i],tmp),training=True)
            # Use this subgraph's own labels. The original passed the stale
            # module-level `sub_label_tensor`, i.e. the labels of whichever
            # subgraph the main loop happened to build last.
            l, _, _ = explainer.loss(pred, pred_label, sub_label_tensors[i], 0)
            loss = loss + l

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_value_(explainer.parameters(), clip_value_max)
        optimizer.step()

        # Evaluate on all nodes; acc() (called by explain_test) refills the
        # freshly reset global buffers.
        # NOTE(review): this also re-plots every node every epoch.
        reals = []
        preds = []
        for node in allnodes:
            explain_test(node, iteration, needplot=True)
        auc = roc_auc_score(reals, preds)
        explainer.train()

        if auc > best_auc:
            best_auc = auc
            torch.save(explainer.state_dict(), 'model_weights/Tree-Grid_BEST.pt')
            torch.save(explainer.state_dict(), save_map + str(iteration) + "/" + 'Tree_Grid_BEST.pt')

    torch.save(explainer.state_dict(), 'model_weights/Tree-Grid_LAST.pt')
    torch.save(explainer.state_dict(), save_map + str(iteration) + "/" + 'Tree_Grid_LAST.pt')



def explain_test(node,iteration,needplot=True):
    """Run the explainer on one node's subgraph and record its edge scores.

    The explainer is switched to eval mode for a single forward pass (with a
    temperature of 1.0) and put back into train mode afterwards. When
    ``needplot`` is true the resulting mask is rendered via ``plot``. The
    mask is then scored by ``acc``, which extends the global ``reals`` and
    ``preds`` buffers. Nothing is returned.

    Args:
        node: Original node id; mapped to a subgraph index through ``remap``.
        iteration: Run index forwarded to ``plot``.
        needplot: Whether to save a figure of the explanation.
    """
    idx = remap[node]
    sub_adj = sub_adjs[idx]
    sub_feature = sub_features[idx]
    sub_embed = sub_embeds[idx]
    sub_label = sub_labels[idx]
    sub_edge_label = sub_edge_labels[idx]

    explainer.eval()
    explainer((sub_feature, sub_adj, 0, sub_embed, 1.0), training=False)
    explainer.train()

    label = np.argmax(sub_label, -1)
    if needplot:
        plot(node, label, iteration)
    acc(sub_adj, sub_edge_label)


if __name__ == '__main__':
    # Driver: runs 10 seeded repetitions of explainer training + evaluation.
    # Everything here deliberately stays at module scope, because plot(),
    # acc(), train() and explain_test() read these names as globals.

    # `pkl` and `random` have no explicit import elsewhere in this file
    # (presumably they leaked in through the star imports); import them
    # explicitly. `os` is needed to create the output directory.
    import os
    import pickle as pkl
    import random

    for iteration in range(10):
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
        print("Starting iteration: {}".format(iteration))

        # Seed before anything random is constructed. The original seeded only
        # after the explainer had been created, so its initial weights were
        # not controlled by the iteration seed.
        random.seed(iteration)
        np.random.seed(iteration)
        torch.manual_seed(iteration)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(iteration)
            torch.cuda.manual_seed_all(iteration)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

        # Plots, checkpoints and the log of this run all go here.
        run_dir = save_map + str(iteration) + "/"
        os.makedirs(run_dir, exist_ok=True)

        #CELL 1
        # NOTE: unpickling can execute arbitrary code; only load trusted
        # dataset files.
        with open('./dataset/' + args.dataset + '.pkl', 'rb') as fin:
            adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask, edge_label_matrix  = pkl.load(fin)

        adj = csr_matrix(adj)
        support = preprocess_adj(adj)

        features_tensor = torch.tensor(features).type(torch.float32)
        indices = torch.LongTensor([*support[0]])
        values = torch.FloatTensor([*support[1]])
        # NOTE: the index array must be transposed to build a sparse tensor
        # with PyTorch.
        support_tensor = torch.sparse.FloatTensor(indices.t(), values, torch.Size([*support[2]]))
        support_tensor = support_tensor.type(torch.float32)

        model = GCN(input_dim=features.shape[1], output_dim=y_train.shape[1], device=device)
        model.to(device)
        # map_location lets checkpoints saved on another device load here.
        model.load_state_dict(torch.load('model_weights/GCN_syn4_BEST.pt', map_location=device))

        explainer = Explainer(model=model)
        explainer.to(device)
        embeds = model.embedding((features_tensor,support_tensor)).cpu().detach().numpy()

        all_label = np.logical_or(y_train,np.logical_or(y_val,y_test))
        single_label = np.argmax(all_label,axis=-1)
        hops = len(args.hiddens.split('-'))
        extractor = Extractor(adj,features,edge_label_matrix,embeds,all_label,hops)

        # Choose the nodes to explain.
        if args.setting==1: # setting from their original paper
            if args.dataset=='syn3':
                allnodes = list(range(511,871,6))
            elif args.dataset=='syn4':
                allnodes = list(range(511,800,1))
            else:
                allnodes = list(range(400,700,5))
        elif args.setting==2:
            allnodes = [n for n in range(single_label.shape[0]) if single_label[n] ==1]
        elif args.setting==3:
            if args.dataset == 'syn2':
                allnodes = [n for n in range(single_label.shape[0]) if single_label[n] != 0 and single_label[n] != 4]
            else:
                allnodes = [n for n in range(single_label.shape[0]) if single_label[n] != 0]
        else:
            # Previously this fell through and failed later with a NameError.
            raise ValueError("Unsupported args.setting: {!r}".format(args.setting))

        optimizer = torch.optim.Adam(explainer.parameters(), lr=args.elr)
        clip_value_min = -2.0
        clip_value_max = 2.0

        sub_support_tensors = []
        sub_label_tensors = []
        sub_features = []
        sub_embeds = []
        sub_adjs = []
        sub_edge_labels = []
        sub_labels = []
        remap = {}  # original node id -> index into the sub_* lists

        #CELL 2
        # Pre-extract the subgraph of every node to explain.
        for node in allnodes:
            sub_adj,sub_feature, sub_embed, sub_label,sub_edge_label_matrix = extractor.subgraph(node)
            remap[node]=len(sub_adjs)
            sub_support = preprocess_adj(sub_adj)
            indices = torch.LongTensor([*sub_support[0]])
            values = torch.FloatTensor([*sub_support[1]])
            # NOTE: the index array must be transposed to build a sparse
            # tensor with PyTorch.
            sub_support_tensor = torch.sparse.FloatTensor(indices.t(), values, torch.Size([*sub_support[2]])).type(torch.float32)
            sub_label_tensor = torch.Tensor(sub_label).type(torch.float32)

            sub_adjs.append(sub_adj)
            sub_features.append(torch.Tensor(sub_feature).type(torch.float32))
            sub_embeds.append(sub_embed)
            sub_labels.append(sub_label)
            sub_edge_labels.append(sub_edge_label_matrix)
            sub_label_tensors.append(sub_label_tensor)
            sub_support_tensors.append(sub_support_tensor)

        # The context manager closes the log even if training or evaluation
        # raises.
        with open(run_dir + "LOG.txt", "w") as f:
            train(iteration)

            # Evaluate with the best checkpoint found during training.
            explainer.load_state_dict(torch.load(run_dir + 'Tree_Grid_BEST.pt', map_location=device))

            reals= []
            preds = []

            tik = time.time()
            for node in allnodes:
                explain_test(node, iteration, needplot=True)
                # AUC and elapsed time are cumulative over the nodes
                # processed so far.
                auc = roc_auc_score(reals, preds)
                tok = time.time()
                f.write("node,{}".format(node) + ",auc,{}".format(auc) + ",time,{}".format(tok-tik) + "\n")

            tok = time.time()
            f.write("time,{}".format(tok-tik) + "\n")

