#!/usr/bin/env python
# coding: utf-8

# In[1]:


# Extend the module search path before the project-local imports below.
# NOTE(review): presumably config/utils/models/... live in ./codes/ — confirm.
import sys
sys.path.append('./codes/')
import time
from config import args
# Path prefix under which per-iteration plots and explainer checkpoints are
# written (see plot() and train()).
save_map = "LISA_TEST_LOGS/TREE_CYCLE/"

# Experiment-specific overrides applied to the shared `args` object.
# These are set before the remaining project modules are imported.
args.dataset='syn3'      # selects ./dataset/<name>.pkl, the node range and plot colours
args.elr = 0.003         # learning rate handed to the explainer's Adam optimizer
args.eepochs = 20        # number of explainer training epochs used by train()
# The next three are not read in this file; presumably consumed by Explainer's
# loss — TODO confirm.
args.coff_size = 0.0001
args.budget = -1
args.coff_ent = 0.01

# import tensorflow as tf
# NOTE(review): the star imports supply helpers used below (preprocess_adj,
# accuracy) and presumably also the `pkl` and `random` names — confirm in utils.
from utils import *
from models import GCN2 as GCN
from metrics import *
import numpy as np
from Extractor import Extractor
from Explainer import Explainer
from scipy.sparse import coo_matrix,csr_matrix
import networkx as nx
import matplotlib
# Select the non-interactive Agg backend before pyplot is imported so figures
# can be written to disk without a display.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score

import torch
import torch.optim


# In[2]:


# with open('./dataset/' + args.dataset + '.pkl', 'rb') as fin:
#     adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask, edge_label_matrix  = pkl.load(fin)

# adj = csr_matrix(adj)
# support = preprocess_adj(adj)

# features_tensor = torch.tensor(features).type(torch.float32)
# i = torch.LongTensor([*support[0]])
# v = torch.FloatTensor([*support[1]])
# # NOTE: i must be transposed to build a sparse tensor with PyTorch
# support_tensor = torch.sparse.FloatTensor(i.t(), v, torch.Size([*support[2]]))
# support_tensor = support_tensor.type(torch.float32)

# model = GCN(input_dim=features.shape[1], output_dim=y_train.shape[1])
# model.load_state_dict(torch.load('model_weights/GCN_syn3_LAST.pt'))

# explainer = Explainer(model=model)
# embeds = model.embedding((features_tensor,support_tensor)).detach().numpy()

# all_label = np.logical_or(y_train,np.logical_or(y_val,y_test))
# single_label = np.argmax(all_label,axis=-1)
# hops = len(args.hiddens.split('-'))
# extractor = Extractor(adj,features,edge_label_matrix,embeds,all_label,hops)
# if args.setting==1:
#     if args.dataset=='syn3':
#         allnodes = [i for i in range(511,871,6)]
#     elif args.dataset=='syn4':
#         allnodes = [i for i in range(511,800,1)]
#     else:
#         allnodes = [i for i in range(400,700,5)] # setting from their original paper
# elif args.setting==2:
#     allnodes = [i for i in range(single_label.shape[0]) if single_label[i] ==1]
# elif args.setting==3:
#     if args.dataset == 'syn2':
#         allnodes = [i for i in range(single_label.shape[0]) if single_label[i] != 0 and single_label[i] != 4]
#     else:
#         allnodes = [i for i in range(single_label.shape[0]) if single_label[i] != 0]

# optimizer = torch.optim.Adam(explainer.parameters(), lr=args.elr)
# clip_value_min = -2.0
# clip_value_max = 2.0

# sub_support_tensors = []
# sub_label_tensors = []
# sub_features = []
# sub_embeds = []
# sub_adjs = []
# sub_edge_labels = []
# sub_labels = []
# remap = {}


# In[3]:


# for node in allnodes:
#     sub_adj,sub_feature, sub_embed, sub_label,sub_edge_label_matrix = extractor.subgraph(node)
#     remap[node]=len(sub_adjs)
#     sub_support = preprocess_adj(sub_adj)
#     i = torch.LongTensor([*sub_support[0]])
#     v = torch.FloatTensor([*sub_support[1]])
#     # NOTE: i must be transposed to build a sparse tensor with PyTorch
#     sub_support_tensor = torch.sparse.FloatTensor(i.t(), v, torch.Size([*sub_support[2]])).type(torch.float32) 
#     sub_label_tensor = torch.Tensor(sub_label).type(torch.float32)

#     sub_adjs.append(sub_adj)
#     sub_features.append(torch.Tensor(sub_feature).type(torch.float32))
#     sub_embeds.append(sub_embed)
#     sub_labels.append(sub_label)
#     sub_edge_labels.append(sub_edge_label_matrix)
#     sub_label_tensors.append(sub_label_tensor)
#     sub_support_tensors.append(sub_support_tensor)
# best_auc = 0.0


# In[4]:



def plot(node,label, iteration):
    """Draw the explainer's current edge mask around local node 0 and save it.

    The mask is read from the module-level ``explainer.masked_adj``, so the
    explainer must already have been called on the subgraph of interest
    (``explain_test`` does this right before calling ``plot``).

    Args:
        node: Identifier of the explained node; only used for the file name.
        label: Per-node integer class labels of the subgraph, indexable by
            local node id.
        iteration: Experiment iteration; selects the output sub-directory.

    Side effects:
        Writes ``<save_map><iteration>/<node>.png`` (creating the directory
        if needed) and clears the current matplotlib figure.
    """
    import os  # local import: only needed to create the output directory

    after_adj_dense = explainer.masked_adj.detach().numpy()
    after_adj = coo_matrix(after_adj_dense)
    # Size the graph from the mask itself.  The previous code read the global
    # `sub_adj`, which holds whatever subgraph the main loop extracted last.
    nmb_nodes = after_adj_dense.shape[0]

    edge_weights = after_adj.data
    n_edges = edge_weights.shape[0]
    pos_edges = []     # strongest edges, drawn highlighted
    filter_edges = []  # edges kept in the plotted graph at all

    # Guard: with an all-zero mask there is nothing to threshold.
    if n_edges > 0:
        sorted_edge_weights = np.sort(edge_weights)
        # Highlight edges at least as heavy as the 12th-largest weight.
        thres_index = max(n_edges - 12, 0)
        thres = sorted_edge_weights[thres_index]
        # Keep roughly the heaviest half of the edges, capped at about 100,
        # but never fewer than the highlighted ones.
        filter_thres_index = min(thres_index,
                                 max(int(n_edges - n_edges / 2), n_edges - 100))
        filter_thres = sorted_edge_weights[filter_thres_index]

        for r, c, d in zip(after_adj.row, after_adj.col, edge_weights):
            edge = (int(r), int(c))
            if d >= thres:
                pos_edges.append(edge)
            if d > filter_thres:
                filter_edges.append(edge)

    G = nx.Graph()
    G.add_nodes_from(range(nmb_nodes))
    G.add_edges_from(filter_edges)

    # Only the connected component containing the explained node (local id 0)
    # is drawn.
    for cc in nx.connected_components(G):
        if 0 in cc:
            G = G.subgraph(cc).copy()
            break

    pos_edges = [(u, v) for (u, v) in pos_edges if u in G and v in G]
    pos = nx.kamada_kawai_layout(G)

    colors = ['orange', 'red', 'green', 'blue', 'maroon', 'brown', 'darkslategray', 'paleturquoise', 'darksalmon',
              'slategray', 'mediumseagreen', 'mediumblue', 'orchid', ]
    if args.dataset=='syn3':
        colors = ['orange', 'blue']

    if args.dataset=='syn4':
        colors = ['orange', 'black','black','black','blue']

    # nodes, grouped by class label so each class gets one colour
    labels = label
    max_label = np.max(labels)+1

    label2nodes = [[] for _ in range(max_label)]
    for i in range(nmb_nodes):
        label2nodes[labels[i]].append(i)

    for i in range(max_label):
        node_filter = [n for n in label2nodes[i] if n in G]
        nx.draw_networkx_nodes(G, pos,
                               nodelist=node_filter,
                               node_color=colors[i % len(colors)],
                               node_size=500)

    # Draw the explained node again, larger; the modulo matches the colour
    # lookup used for the other nodes and avoids an out-of-range index.
    nx.draw_networkx_nodes(G, pos,
                           nodelist=[0],
                           node_color=colors[labels[0] % len(colors)],
                           node_size=1000)

    nx.draw_networkx_edges(G, pos, width=7, alpha=0.5, edge_color='grey')

    nx.draw_networkx_edges(G, pos,
                           edgelist=pos_edges,
                           width=7, alpha=0.5)

    plt.axis('off')
    out_dir = save_map + str(iteration) + "/"
    # savefig does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(out_dir + str(node) + ".png")
    plt.clf()


# In[5]:


def train(iteration):
    """Train the module-level ``explainer`` and checkpoint best/last weights.

    Each epoch sums the explainer loss over every subgraph in ``allnodes``,
    takes one optimizer step, then scores the explainer with the edge-level
    ROC-AUC pooled over all subgraphs (via ``explain_test``).

    Args:
        iteration: Experiment iteration; selects the per-iteration output
            directory ``<save_map><iteration>/``.

    Side effects:
        Rebinds the module-level ``reals`` / ``preds`` lists every epoch,
        updates ``explainer`` in place, prints one line per epoch and writes
        four checkpoint files (best and last weights, each both under
        ``model_weights/`` and under the per-iteration directory).
    """
    import copy  # local imports: keep this block self-contained
    import os

    global reals
    global preds

    epochs = args.eepochs
    model.eval()
    explainer.train()
    best_auc = 0
    # Start from a snapshot of the initial weights so there is always
    # something to save, even if no epoch improves on best_auc.
    best_state_dict = copy.deepcopy(explainer.state_dict())

    # The sampling temperature is fixed.  The original code computed an
    # annealed value from args.coff_t0 / args.coff_te and then immediately
    # overwrote it with this constant, so only the constant ever took effect.
    tmp = 5.0

    for epoch in range(epochs):
        loss = 0
        for i in range(len(allnodes)):
            # Predictions of the frozen model are the targets to explain.
            with torch.no_grad():
                output = model((sub_features[i],sub_support_tensors[i]), training=False)
            pred_label = torch.argmax(output, 1)

            x = sub_features[i]
            adj = sub_adjs[i]
            nodeid = 0
            embed = torch.Tensor(sub_embeds[i])
            pred = explainer((x,adj,nodeid,embed,tmp),training=True)
            # Pass this subgraph's own labels; the previous code passed the
            # global `sub_label_tensor`, i.e. the labels of whichever subgraph
            # the main loop built last.
            node_loss, _, _ = explainer.loss(pred, pred_label, sub_label_tensors[i], 0)
            loss = loss + node_loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_value_(explainer.parameters(), clip_value_max)
        optimizer.step()

        # explain_test() -> acc() appends to these module-level lists.
        reals = []
        preds = []
        for node in allnodes:
            explain_test(node, iteration, needplot=False)
        auc = roc_auc_score(reals, preds)
        explainer.train()  # explain_test() switched the explainer to eval mode
        print("epoch,{}".format(epoch) + ",auc,{}".format(auc))
        if auc > best_auc:
            print("better auc")
            best_auc = auc
            # state_dict() returns references to the live parameter tensors,
            # which later optimizer steps overwrite in place; deep-copy to
            # actually freeze the best weights.
            best_state_dict = copy.deepcopy(explainer.state_dict())

    out_dir = save_map + str(iteration) + "/"
    # torch.save does not create missing directories.
    os.makedirs(out_dir, exist_ok=True)
    # File names (including their spelling) are kept unchanged so that
    # existing consumers of these checkpoints keep working.
    torch.save(best_state_dict, 'model_weights/Tree-Cycles_BEST.pt')
    torch.save(best_state_dict, out_dir + 'Tree_Cycles_BESt.pt')
    torch.save(explainer.state_dict(), 'model_weights/Tree-Cycles_LAST.pt')
    torch.save(explainer.state_dict(), out_dir + 'Tree_Cycles_LAST.pt')
            
# Module-level accumulators: every acc() call appends its per-edge ground
# truth and mask scores here so callers can compute a pooled ROC-AUC.
reals = []
preds = []


def acc(sub_adj,sub_edge_label):
    """Score the explainer's current edge mask against ground-truth edges.

    Args:
        sub_adj: Subgraph adjacency exposing ``.row`` / ``.col`` index arrays;
            one entry is evaluated per (row, col) pair.
        sub_edge_label: Sparse ground-truth edge labels supporting
            ``.todense()``.

    Returns:
        The ROC-AUC over this subgraph's edges, or ``-1`` when either the
        ground truth or the scores contain a single distinct value.

    Side effects:
        Extends the module-level ``reals`` and ``preds`` lists.
    """
    dense_labels = sub_edge_label.todense()
    edge_mask = explainer.masked_adj.detach().numpy()

    truth = []
    scores = []
    for row, col in zip(sub_adj.row, sub_adj.col):
        # Both directions are summed for the label and for the mask score.
        labelled = (dense_labels[row, col] + dense_labels[col, row]) != 0
        truth.append(1 if labelled else 0)
        scores.append(edge_mask[row][col] + edge_mask[col][row])

    reals.extend(truth)
    preds.extend(scores)

    single_class = len(np.unique(truth)) == 1
    single_score = len(np.unique(scores)) == 1
    if single_class or single_score:
        return -1
    return roc_auc_score(truth, scores)


# In[6]:


def explain_test(node,iteration,needplot=True):
    """Run the explainer on one node's cached subgraph and record its scores.

    Args:
        node: Node id; translated through ``remap`` to an index into the
            module-level ``sub_*`` lists.
        iteration: Experiment iteration, forwarded to ``plot``.
        needplot: When true, also render the explanation with ``plot``.

    Side effects:
        Puts ``explainer`` into eval mode, runs it on the subgraph, and calls
        ``acc`` for that subgraph.  Returns nothing.
    """
    idx = remap[node]
    graph_adj = sub_adjs[idx]
    graph_features = sub_features[idx]
    graph_embed = sub_embeds[idx]
    graph_labels = sub_labels[idx]
    graph_edge_labels = sub_edge_labels[idx]

    target = 0  # local id handed to the explainer as the node to explain
    explainer.eval()
    explainer((graph_features, graph_adj, target, graph_embed, 1.0), training=False)

    class_ids = np.argmax(graph_labels, axis=-1)
    if needplot:
        plot(node, class_ids, iteration)
    acc(graph_adj, graph_edge_labels)


# In[7]:


# Run the full experiment ten times: load data and the pretrained GCN, extract
# one subgraph per target node, train a fresh explainer, reload its best
# checkpoint and report the pooled edge-level ROC-AUC and the explanation time.
for iteration in range(10):
    print("Starting iteration: {}".format(iteration))

    # Seed every RNG first, before the explainer is constructed, so that its
    # initial weights are governed by this iteration's seed as well.
    # (Previously seeding happened only just before train().)
    random.seed(iteration)
    np.random.seed(iteration)
    torch.manual_seed(iteration)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(iteration)
        torch.cuda.manual_seed_all(iteration)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # CELL 1
    # NOTE: unpickling can execute arbitrary code; only load trusted datasets.
    with open('./dataset/' + args.dataset + '.pkl', 'rb') as fin:
        adj, features, y_train, y_val, y_test, train_mask, val_mask, test_mask, edge_label_matrix  = pkl.load(fin)

    adj = csr_matrix(adj)
    support = preprocess_adj(adj)

    features_tensor = torch.tensor(features).type(torch.float32)
    i = torch.LongTensor([*support[0]])
    v = torch.FloatTensor([*support[1]])
    # NOTE: i must be transposed to build a sparse tensor with PyTorch
    support_tensor = torch.sparse.FloatTensor(i.t(), v, torch.Size([*support[2]]))
    support_tensor = support_tensor.type(torch.float32)

    model = GCN(input_dim=features.shape[1], output_dim=y_train.shape[1])
    model.load_state_dict(torch.load('model_weights/GCN_syn3_BEST.pt'))

    explainer = Explainer(model=model)
    embeds = model.embedding((features_tensor,support_tensor)).detach().numpy()

    all_label = np.logical_or(y_train,np.logical_or(y_val,y_test))
    single_label = np.argmax(all_label,axis=-1)
    hops = len(args.hiddens.split('-'))
    extractor = Extractor(adj,features,edge_label_matrix,embeds,all_label,hops)

    # Choose the nodes to explain.
    if args.setting==1:
        if args.dataset=='syn3':
            allnodes = list(range(511,871,6))
        elif args.dataset=='syn4':
            allnodes = list(range(511,800,1))
        else:
            allnodes = list(range(400,700,5)) # setting from their original paper
    elif args.setting==2:
        allnodes = [i for i in range(single_label.shape[0]) if single_label[i] ==1]
    elif args.setting==3:
        if args.dataset == 'syn2':
            allnodes = [i for i in range(single_label.shape[0]) if single_label[i] != 0 and single_label[i] != 4]
        else:
            allnodes = [i for i in range(single_label.shape[0]) if single_label[i] != 0]
    else:
        # Fail loudly instead of hitting a NameError (first iteration) or
        # silently reusing the previous iteration's node list.
        raise ValueError("Unsupported args.setting: {!r}".format(args.setting))

    optimizer = torch.optim.Adam(explainer.parameters(), lr=args.elr)
    clip_value_min = -2.0
    clip_value_max = 2.0  # read by train() for gradient value clipping

    # Per-node subgraph caches, read by train() and explain_test().
    sub_support_tensors = []
    sub_label_tensors = []
    sub_features = []
    sub_embeds = []
    sub_adjs = []
    sub_edge_labels = []
    sub_labels = []
    remap = {}  # node id -> index into the sub_* lists

    # CELL 2
    for node in allnodes:
        sub_adj,sub_feature, sub_embed, sub_label,sub_edge_label_matrix = extractor.subgraph(node)
        remap[node]=len(sub_adjs)
        sub_support = preprocess_adj(sub_adj)
        i = torch.LongTensor([*sub_support[0]])
        v = torch.FloatTensor([*sub_support[1]])
        # NOTE: i must be transposed to build a sparse tensor with PyTorch
        sub_support_tensor = torch.sparse.FloatTensor(i.t(), v, torch.Size([*sub_support[2]])).type(torch.float32)
        sub_label_tensor = torch.Tensor(sub_label).type(torch.float32)

        sub_adjs.append(sub_adj)
        sub_features.append(torch.Tensor(sub_feature).type(torch.float32))
        sub_embeds.append(sub_embed)
        sub_labels.append(sub_label)
        sub_edge_labels.append(sub_edge_label_matrix)
        sub_label_tensors.append(sub_label_tensor)
        sub_support_tensors.append(sub_support_tensor)
    best_auc = 0.0

    # TRAIN LOOP
    train(iteration)

    # Evaluate with the best checkpoint that train() just wrote.
    explainer.load_state_dict(torch.load('model_weights/Tree-Cycles_BEST.pt'))

    reals= []
    preds = []

    tik = time.time()
    for node in allnodes:
        explain_test(node, iteration, needplot=False)
        # Running AUC over all edges collected so far, not a per-node score.
        auc = roc_auc_score(reals, preds)
        print("node,{}".format(node) + ",auc,{}".format(auc))

    tok = time.time()
    print("time,{}".format(tok-tik))


# In[ ]:




