#!/usr/bin/env python
# coding: utf-8

# In[ ]:


import os
import sys
# NOTE(review): this only binds a Python variable named CUDA_LAUNCH_BLOCKING;
# it does not set the process environment variable of the same name.
CUDA_LAUNCH_BLOCKING=1
# Add ./codes/ to the import path (presumably where config, utils, models,
# metrics, Extractor and Explainer live — confirm against the repository).
sys.path.append('./codes/')
import time
from config import args
# Root directory for the per-iteration plots, checkpoints and LOG.txt files.
save_map = "LISA_TEST_LOGS/BA_SHAPES/"

# Overrides applied to the shared config object before the remaining project
# modules are imported.
# Dataset name: selects ./dataset/<name>.pkl and the colour palette in plot().
args.dataset='syn1'
# Learning rate of the Adam optimizer that trains the explainer.
args.elr = 0.003
# Number of explainer training epochs used by train().
args.eepochs = 10
# The three settings below are not read anywhere in this file; presumably they
# are consumed by the Explainer's loss — TODO confirm.
args.coff_size = 0.05
args.budget = -1
args.coff_ent = 1.0

from tqdm import tqdm
# import tensorflow as tf
from utils import *
from models import GCN2 as GCN
from metrics import *
import numpy as np
from Extractor import Extractor
from Explainer import Explainer
from scipy.sparse import coo_matrix,csr_matrix
import networkx as nx
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score

import torch
import torch.optim


def plot(node,label, iteration):
    """Draw the explainer's current edge mask for one node and save it as a PNG.

    Reads the module-level ``explainer.masked_adj``, so the explainer must
    have been run on the subgraph of ``node`` immediately before this call.
    The drawing is restricted to the connected component that contains local
    node 0, and the figure is written to
    ``save_map + str(iteration) + "/" + str(node) + ".png"``.

    Args:
        node: Identifier of the explained node; only used for the file name.
        label: Integer class label per subgraph node, indexable by local node
            id (``np.argmax`` output in ``explain_test``).
        iteration: Run index; selects the output sub-directory.
    """
    after_adj_dense = explainer.masked_adj.cpu().detach().numpy()
    after_adj = coo_matrix(after_adj_dense)

    edge_weights = after_adj.data
    num_edges = edge_weights.shape[0]
    if num_edges:
        sorted_edge_weights = np.sort(edge_weights)
        # Entries at or above `thres` (the 12 largest, more on ties) are the
        # highlighted explanation edges.
        thres_index = max(int(num_edges-12),0)
        thres = sorted_edge_weights[thres_index]
        # Entries strictly above `filter_thres` form the context graph.
        filter_thres_index = min(thres_index,max(int(num_edges-num_edges/2),num_edges-100))
        filter_thres = sorted_edge_weights[filter_thres_index]
    else:
        # An empty mask has nothing to rank; avoid indexing an empty array.
        thres = filter_thres = np.inf

    pos_edges = []
    filter_edges = []
    for r, c, d in zip(after_adj.row, after_adj.col, edge_weights):
        edge = (int(r), int(c))
        if d >= thres:
            pos_edges.append(edge)
        if d > filter_thres:
            filter_edges.append(edge)

    # Size the graph from the mask being drawn rather than from the
    # module-level `sub_adj`, which holds whichever subgraph was extracted last.
    num_nodes = after_adj_dense.shape[0]
    G = nx.Graph()
    G.add_nodes_from(range(num_nodes))
    G.add_edges_from(filter_edges)

    # Keep only the component around local node 0.
    for cc in nx.connected_components(G):
        if 0 in cc:
            G = G.subgraph(cc).copy()
            break

    pos_edges = [(u, v) for (u, v) in pos_edges if u in G.nodes() and v in G.nodes()]
    pos = nx.kamada_kawai_layout(G)

    colors = ['orange', 'red', 'green', 'blue', 'maroon', 'brown', 'darkslategray', 'paleturquoise', 'darksalmon',
              'slategray', 'mediumseagreen', 'mediumblue', 'orchid', ]
    if args.dataset=='syn3':
        colors = ['orange', 'blue']
    if args.dataset=='syn4':
        colors = ['orange', 'black','black','black','blue']

    # Group node ids by class label so each class is drawn in its own colour.
    labels = label
    max_label = np.max(labels)+1
    label2nodes = [[] for _ in range(max_label)]
    for i in range(num_nodes):
        label2nodes[labels[i]].append(i)

    for i in range(max_label):
        node_filter = [n for n in label2nodes[i] if n in G.nodes()]
        nx.draw_networkx_nodes(G, pos,
                               nodelist=node_filter,
                               node_color=colors[i % len(colors)],
                               node_size=500)

    # Local node 0 is drawn larger; wrap the colour index the same way as for
    # the other nodes so a label beyond the palette cannot raise IndexError.
    nx.draw_networkx_nodes(G, pos,
                           nodelist=[0],
                           node_color=colors[labels[0] % len(colors)],
                           node_size=1000)

    # Context edges in grey, then the top-weighted edges on top of them.
    nx.draw_networkx_edges(G, pos, width=7, alpha=0.5, edge_color='grey')
    nx.draw_networkx_edges(G, pos,
                           edgelist=pos_edges,
                           width=7, alpha=0.5)

    plt.axis('off')
    plt.savefig(save_map + str(iteration) + "/" + str(node) + ".png")
    # Clear the shared pyplot figure so the next call starts from a blank canvas.
    plt.clf()


# In[ ]:


def train(iteration):
    """Train the module-level ``explainer`` and checkpoint its best/last weights.

    Every epoch sums the explainer loss over all prepared subgraphs, takes a
    single optimizer step, then re-runs ``explain_test`` on every node in
    ``allnodes`` (which also saves a plot per node) and computes a ROC-AUC
    over the pooled edge scores in the module-level ``reals``/``preds``.
    The state dict with the highest AUC is saved as ``BA-shapes_BEST.pt`` and
    the final one as ``BA-shapes_LAST.pt``, both under ``model_weights/`` and
    under ``save_map + str(iteration) + "/"``.

    Relies on module-level state set up by the main loop: ``model``,
    ``explainer``, ``optimizer``, ``clip_value_max``, ``allnodes`` and the
    ``sub_*`` lists.

    Args:
        iteration: Run index; selects the per-run output directory.
    """
    global reals
    global preds

    epochs = args.eepochs
    model.eval()
    explainer.train()
    best_auc = 0
    for epoch in range(epochs):
        loss = 0
        # Scalar handed to the explainer as its last input, decayed
        # geometrically from 1.0 towards 0.05 over the epochs (presumably the
        # sampling temperature — confirm in Explainer).
        tmp = float(1.0*np.power(0.05,epoch/epochs))
        subgraphs = zip(sub_features, sub_support_tensors, sub_adjs,
                        sub_embeds, sub_label_tensors)
        for x, support, adj, sub_embed, label_tensor in subgraphs:
            # The GCN's own prediction is the target, so no gradient is needed.
            with torch.no_grad():
                output = model((x, support), training=False)
            pred_label = torch.argmax(output, 1)

            nodeid = 0
            embed = torch.Tensor(sub_embed)
            pred = explainer((x,adj,nodeid,embed,tmp), training=True)
            # Pass this subgraph's labels; the module-level `sub_label_tensor`
            # only holds the labels of the last subgraph that was extracted.
            l, _, _ = explainer.loss(pred, pred_label, label_tensor, 0)
            loss = loss + l

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_value_(explainer.parameters(), clip_value_max)
        optimizer.step()

        # Reset the pooled scores, then let explain_test/acc refill them.
        reals = []
        preds = []
        for node in allnodes:
            explain_test(explainer, node, iteration, needplot=True)
        auc = roc_auc_score(reals, preds)
        # explain_test switches the explainer to eval mode.
        explainer.train()

        if auc > best_auc:
            best_auc = auc
            torch.save(explainer.state_dict(), 'model_weights/BA-shapes_BEST.pt')
            torch.save(explainer.state_dict(), save_map + str(iteration) + "/" + 'BA-shapes_BEST.pt')

    torch.save(explainer.state_dict(), 'model_weights/BA-shapes_LAST.pt')
    torch.save(explainer.state_dict(), save_map + str(iteration) + "/" + 'BA-shapes_LAST.pt')

# Ground-truth edge labels and mask scores pooled over all explained nodes.
# acc() extends these lists; train() and the main loop rebind them to empty
# lists before each evaluation pass.
reals = []
preds = []

def acc(sub_adj,sub_edge_label):
    """Score the current explainer mask against one subgraph's edge labels.

    For every stored entry ``(r, c)`` of ``sub_adj`` the ground truth is 1
    when the edge label is non-zero in either direction, and the prediction is
    the sum of the mask values in both directions. The per-edge values are
    also appended to the module-level ``reals`` and ``preds`` lists.

    Args:
        sub_adj: Subgraph adjacency exposing ``row`` and ``col`` arrays.
        sub_edge_label: Sparse edge-label matrix supporting ``todense()``.

    Returns:
        The ROC-AUC over this subgraph's edges, or -1 when either the labels
        or the scores contain a single distinct value.
    """
    dense_labels = sub_edge_label.todense()
    mask = explainer.masked_adj.cpu().detach().numpy()

    real = []
    pred = []
    for r, c in zip(sub_adj.row, sub_adj.col):
        label_sum = dense_labels[r,c] + dense_labels[c,r]
        real.append(0 if label_sum==0 else 1)
        pred.append(mask[r][c]+mask[c][r])

    # Feed the pooled, cross-node evaluation done by the callers.
    reals.extend(real)
    preds.extend(pred)

    single_class = len(np.unique(real))==1
    constant_score = len(np.unique(pred))==1
    if single_class or constant_score:
        return -1
    return roc_auc_score(real,pred)


def explain_test(explainer,node,iteration,needplot=True):
    """Run the explainer on one node's subgraph and record its edge scores.

    Looks up the subgraph prepared for ``node`` through the module-level
    ``remap`` and ``sub_*`` lists, evaluates the explainer on it in eval mode,
    optionally plots the result, and calls ``acc`` so the edge labels and
    scores are appended to the module-level ``reals``/``preds``.

    Note that ``plot`` and ``acc`` read the module-level ``explainer``, not
    the argument passed here.

    Args:
        explainer: Explainer module to evaluate.
        node: Node identifier; must be a key of ``remap``.
        iteration: Run index forwarded to ``plot``.
        needplot: When true, save a figure of the explanation.

    Returns:
        None.
    """
    idx = remap[node]
    sub_adj = sub_adjs[idx]
    sub_feature = sub_features[idx]
    sub_embed = sub_embeds[idx]
    sub_label = sub_labels[idx]
    sub_edge_label = sub_edge_labels[idx]

    explainer.eval()
    # Local node id 0 and a fixed last input of 1.0 are used for evaluation.
    explainer((sub_feature,sub_adj,0,sub_embed,1.0))

    label = np.argmax(sub_label,-1)
    if needplot:
        plot(node,label,iteration)
    acc(sub_adj,sub_edge_label)


# In[ ]:


# Main experiment: ten independent runs. Each run loads the dataset and the
# pretrained GCN, prepares one subgraph per node to explain, trains a fresh
# explainer and logs the cumulative ROC-AUC and timing of the best checkpoint.
for iteration in range(10):
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    print("Starting iteration: {}".format(iteration))

    # Plots, checkpoints and the log of this run all go here; create it up
    # front so none of the later writes fail on a missing directory.
    run_dir = save_map + str(iteration)
    os.makedirs(run_dir, exist_ok=True)

    #CELL 1
    # NOTE(review): unpickling executes arbitrary code; only load trusted
    # dataset files.
    with open('./dataset/' + args.dataset + '.pkl', 'rb') as fin:
        adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask, edge_label_matrix  = pkl.load(fin)

    adj = csr_matrix(adj)
    support = preprocess_adj(adj)

    features_tensor = torch.tensor(features).type(torch.float32)
    i = torch.LongTensor([*support[0]])
    v = torch.FloatTensor([*support[1]])
    # NOTE: the index tensor must be transposed to build a sparse tensor with
    # PyTorch.
    support_tensor = torch.sparse_coo_tensor(i.t(), v, torch.Size([*support[2]]), dtype=torch.float32)

    model = GCN(input_dim=features.shape[1], output_dim=y_train.shape[1], device=device)
    model.to(device)
    # map_location lets weights saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load('model_weights/GCN_syn1_BEST.pt', map_location=device))

    explainer = Explainer(model=model)
    explainer.to(device)
    embeds = model.embedding((features_tensor,support_tensor)).cpu().detach().numpy()

    all_label = np.logical_or(y_train,np.logical_or(y_val,y_test))
    single_label = np.argmax(all_label,axis=-1)
    # One hop per entry of the '-'-separated hidden-size string.
    hops = len(args.hiddens.split('-'))
    extractor = Extractor(adj,features,edge_label_matrix,embeds,all_label,hops)

    # Choose which nodes get explained.
    if args.setting==1:
        if args.dataset=='syn3':
            allnodes = [i for i in range(511,871,6)]
        elif args.dataset=='syn4':
            allnodes = [i for i in range(511,800,1)]
        else:
            allnodes = [i for i in range(400,700,5)] # setting from their original paper
    elif args.setting==2:
        allnodes = [i for i in range(single_label.shape[0]) if single_label[i] ==1]
    elif args.setting==3:
        if args.dataset == 'syn2':
            allnodes = [i for i in range(single_label.shape[0]) if single_label[i] != 0 and single_label[i] != 4]
        else:
            allnodes = [i for i in range(single_label.shape[0]) if single_label[i] != 0]
    else:
        # Fail here instead of with a NameError on `allnodes` further down.
        raise ValueError("Unsupported args.setting: {}".format(args.setting))

    optimizer = torch.optim.Adam(explainer.parameters(), lr=args.elr)
    # train() clips gradients with clip_value_max; clip_value_min is not read
    # anywhere in this file.
    clip_value_min = -2.0
    clip_value_max = 2.0

    # Parallel per-node lists, indexed through `remap[node]`.
    sub_support_tensors = []
    sub_label_tensors = []
    sub_features = []
    sub_embeds = []
    sub_adjs = []
    sub_edge_labels = []
    sub_labels = []
    remap = {}

    #CELL 2
    for node in allnodes:
        sub_adj,sub_feature, sub_embed, sub_label,sub_edge_label_matrix = extractor.subgraph(node)
        remap[node]=len(sub_adjs)
        sub_support = preprocess_adj(sub_adj)
        i = torch.LongTensor([*sub_support[0]])
        v = torch.FloatTensor([*sub_support[1]])
        # NOTE: the index tensor must be transposed to build a sparse tensor
        # with PyTorch.
        sub_support_tensor = torch.sparse_coo_tensor(i.t(), v, torch.Size([*sub_support[2]]), dtype=torch.float32)
        sub_label_tensor = torch.Tensor(sub_label).type(torch.float32)

        sub_adjs.append(sub_adj)
        sub_features.append(torch.Tensor(sub_feature).type(torch.float32))
        sub_embeds.append(sub_embed)
        sub_labels.append(sub_label)
        sub_edge_labels.append(sub_edge_label_matrix)
        sub_label_tensors.append(sub_label_tensor)
        sub_support_tensors.append(sub_support_tensor)

    # NOTE(review): seeding happens after Explainer(...) has been constructed,
    # so any random initialisation of its parameters is not controlled by this
    # seed — confirm whether that is intended.
    random.seed(iteration)
    np.random.seed(iteration)
    torch.manual_seed(iteration)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(iteration)
        torch.cuda.manual_seed_all(iteration)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # The context manager closes the log even if training or evaluation raises.
    with open(run_dir + "/" + "LOG.txt", "w") as f:
        train(iteration)

        # Evaluate the checkpoint with the best training-time AUC.
        explainer.load_state_dict(torch.load(run_dir + "/" + 'BA-shapes_BEST.pt', map_location=device))

        reals = []
        preds = []

        tik = time.time()
        for node in allnodes:
            explain_test(explainer,node, iteration, needplot=True)
            # AUC and elapsed time are cumulative over the nodes seen so far.
            auc = roc_auc_score(reals, preds)
            tok = time.time()
            f.write("node,{},auc,{},time,{}\n".format(node, auc, tok-tik))

        tok = time.time()
        f.write("time,{}".format(tok-tik) + "\n")

