import sys
sys.path.append('./codes/forgraph/')
# sys.path.append('codes/forgraph/')
from config import args
import matplotlib.pyplot as plt
from utils import get_graph_data
from models import GCN2 as GCN
from metrics import *
import numpy as np
from Explainer import Explainer
from scipy.sparse import coo_matrix
import networkx as nx
from sklearn.metrics import roc_auc_score
import pickle as pkl
import torch
import gc

import time
# Root directory under which each iteration writes its log, checkpoint and figures.
save_map = "LISA_TEST_LOGS/MUTAG/"

# Gradient clipping bounds. Only clip_value_max is read by the training loop
# below; clip_value_min is referenced solely by commented-out legacy code.
clip_value_max = 2.0
clip_value_min = -2.0

# Dataset selection and explainer hyper-parameters, set on the shared `args`
# object imported from config.
args.dataset = "Mutagenicity"
args.eepochs = 2  # number of explainer training epochs (read by train())
args.elr = 0.0003  # learning rate handed to the explainer's Adam optimizer
args.coff_t0 = 5.0 # For this dataset, high temperature works well.
args.coff_te = 5.0
# NOTE(review): coff_size / coff_ent are not read anywhere in this file;
# presumably they are loss coefficients consumed inside Explainer - confirm.
args.coff_size = 0.005
args.coff_ent = 1.0

# FUNCTIONS
def acc(explainer, gid):
    """Collect ground-truth edge labels and predicted edge weights for one graph.

    Reads the explainer's current ``masked_adj`` and, for every labelled edge
    ``(r, c)`` of graph ``gid`` with ``r > c``, appends the label to the
    module-level list ``reals`` and the mask entry at ``[r][c]`` to the
    module-level list ``preds``. Edges with ``r <= c`` are skipped, so each
    undirected edge stored in both directions is counted once.

    Returns nothing; the two global lists are mutated in place.
    """
    edge_mask = explainer.masked_adj.cpu().detach().numpy()
    truth_labels = selected_edge_label_lists[gid]
    edges = selected_edge_lists[gid]
    for (src, dst), truth in zip(edges, truth_labels):
        if src <= dst:
            continue
        reals.append(truth)
        preds.append(edge_mask[src, dst])


def _split_edges_by_topk(mask_dense, topk):
    """Return ``(top_edges, all_edges)`` extracted from a dense edge mask.

    ``all_edges`` holds every non-zero entry ``(r, c)`` of ``mask_dense`` with
    ``r >= c``. ``top_edges`` is the subset whose weight is at least the
    ``topk``-th largest value of the flattened upper triangle of the mask.
    When ``topk`` is not positive, or the mask has no entries, ``top_edges``
    is empty instead of raising ``IndexError``.
    """
    sparse_mask = coo_matrix(mask_dense)
    sorted_weights = np.sort(np.triu(mask_dense).flatten())
    if topk > 0 and sorted_weights.shape[0] > 0:
        thres = sorted_weights[max(int(sorted_weights.shape[0] - topk), 0)]
    else:
        # Nothing can reach an infinite threshold, so no edge is highlighted.
        thres = np.inf

    top_edges = []
    all_edges = []
    for r, c, d in zip(sparse_mask.row, sparse_mask.col, sparse_mask.data):
        r = int(r)
        c = int(c)
        if r < c:
            # Keep a single orientation of each undirected edge.
            continue
        if float(d) >= thres:
            top_edges.append((r, c))
        all_edges.append((r, c))
    return top_edges, all_edges


def explain_graph(fea, emb, adj, label, graphid, iteration, needplot=True, topk=4):
    """Run the explainer on one graph, record its edge scores, optionally plot it.

    The module-level ``explainer`` is switched to eval mode and called on
    ``(fea, emb, adj, 1.0, label)``; ``acc`` then appends this graph's edge
    labels and mask weights to the global ``reals`` / ``preds`` lists.

    Args:
        fea, emb, adj, label: per-graph inputs forwarded unchanged to the
            explainer.
        graphid: index into the ``selected_*`` module-level collections.
        iteration: run index; selects the sub-directory of ``save_map`` the
            figure is written to.
        needplot: when False, return right after the scores are recorded.
        topk: number of highest-weighted edges to highlight in the figure.

    Returns:
        None. With ``needplot=True`` a PNG named ``<graphid>.png`` is written
        to ``save_map + str(iteration) + "/"`` (the directory must exist).

    NOTE: calling this leaves ``explainer`` in eval mode.
    """
    explainer.eval()
    explainer((fea, emb, adj, 1.0, label))
    acc(explainer, graphid)
    if not needplot:
        return

    mask_dense = explainer.masked_adj.cpu().detach().numpy()
    pos_edges, filter_edges = _split_edges_by_topk(mask_dense, topk)

    node_label = selected_node_label_lists[graphid]
    max_label = np.max(node_label) + 1
    nmb_nodes = len(node_label)

    G = nx.Graph()
    G.add_nodes_from(range(nmb_nodes))
    G.add_edges_from(filter_edges)
    pos = nx.kamada_kawai_layout(G)

    colors = ['orange','red','lime','green','blue','orchid','darksalmon','darkslategray','gold','bisque','tan','lightseagreen','indigo','navy']

    # Bucket node ids by label so each label is drawn in its own colour.
    label2nodes = [[] for _ in range(max_label)]
    for node, lab in enumerate(node_label):
        label2nodes[lab].append(node)

    for lab, nodes in enumerate(label2nodes):
        nx.draw_networkx_nodes(G, pos,
                               nodelist=nodes,
                               # Wrap around so more labels than colours cannot raise IndexError.
                               node_color=colors[lab % len(colors)],
                               node_size=300)

    # All edges in grey first, then the top-k explanation edges drawn thicker on top.
    nx.draw_networkx_edges(G, pos, width=2,  edge_color='grey')
    nx.draw_networkx_edges(G, pos,
                           edgelist=pos_edges,
                           width=7)

    plt.title('Graph: '+str(graphid)+' label: '+str(selected_graph_labels[graphid]))
    plt.axis('off')
    plt.savefig(save_map + str(iteration) + "/" + str(graphid) + ".png")
    # Clear the shared figure so the next graph starts from a blank canvas.
    plt.clf()


# For this dataset, a small batch size works better when explaining; here we simply use batch size = 1.
def train(iteration):
    """Train the module-level ``explainer`` and checkpoint the best-AUC weights.

    For ``args.eepochs`` epochs, every selected graph triggers one optimizer
    step (batch size 1). The value ``tmp`` passed to the explainer follows the
    schedule ``t0 * (t1 / t0) ** (epoch / epochs)`` with ``t0 = args.coff_t0``
    and ``t1 = args.coff_te``. After each epoch the first tenth of the
    selected graphs is scored with ``explain_graph`` and the ROC-AUC of the
    collected ``reals`` / ``preds`` is printed. Whenever that AUC exceeds the
    best one seen so far, the explainer weights are saved both to
    ``model_weights/MUTAG_BESTAUC.pt`` and to the iteration's directory under
    ``save_map``.

    Args:
        iteration: run index, used for the per-iteration checkpoint path and
            forwarded to ``explain_graph``.

    Returns:
        None. Rebinds the module-level ``reals`` and ``preds`` lists.
    """
    global reals
    global preds

    epochs = args.eepochs
    t0 = args.coff_t0
    t1 = args.coff_te
    best_auc = 0
    for epoch in range(epochs):
        # explain_graph() puts the explainer into eval mode during validation,
        # so training mode has to be restored at the start of every epoch.
        explainer.train()
        loss = 0
        tmp = t0 * np.power(t1 / t0, epoch / epochs)
        for gid in range(selected_adjs.shape[0]):
            pred = explainer.forward((selected_features_tensor[gid], selected_embs[gid],
                                      selected_adjs_tensor[gid], tmp, selected_labels_tensor[gid]),
                                     training=True)
            cl = explainer.loss(pred, selected_pred_label[gid])
            # `loss` carries the detached running total of the epoch, so its
            # value is cumulative while gradients flow through `cl` only.
            loss = loss + cl
            optimizer.zero_grad()
            loss.backward()
            final_loss = loss.cpu().detach()
            loss = loss.detach()
            torch.nn.utils.clip_grad_value_(explainer.parameters(), clip_value_max)
            optimizer.step()
            gc.collect()

        print('epoch',epoch,'loss',final_loss.numpy())

        # Validate on the first tenth of the selected graphs.
        reals = []
        preds = []
        for gid in range(selected_adjs.shape[0] // 10):
            fea, emb, adj, label = selected_features_tensor[gid], selected_embs[gid], \
                                   selected_adjs_tensor[gid], selected_labels_tensor[gid]
            explain_graph(fea, emb, adj, label, gid, iteration, needplot=False)

        auc = roc_auc_score(reals, preds)
        print(auc)
        if auc > best_auc:
            # Remember the new best so later, worse epochs cannot overwrite it.
            best_auc = auc
            torch.save(explainer.state_dict(), 'model_weights/MUTAG_BESTAUC.pt')
            torch.save(explainer.state_dict(), save_map + str(iteration) + "/" + 'MUTAG_BESTAUC.pt')

# TRAIN LOOP
#
# Ten independent runs. Each run reloads the data, rebuilds the explainer,
# trains it, then evaluates on all selected graphs and writes a log plus one
# figure per graph to save_map/<iteration>/.

import os  # only needed by this driver, to create the per-run output directory

for iteration in range(10):
    # Create the output directory up front so the log, checkpoint and figures
    # written during this run cannot fail on a missing path.
    run_dir = save_map + str(iteration) + "/"
    os.makedirs(run_dir, exist_ok=True)

    edge_lists, graph_labels, edge_label_lists, node_label_lists = get_graph_data(args.dataset)
    print('********** opening dataset *************** \n')
    # NOTE: the data is read from pre-split .npy chunks; the full pickle
    # './dataset/Mutagenicity.pkl' is the alternative source.
    original_adjs = np.load('./dataset/mutag_chunks/sub_adjs_1.npy')
    original_features = np.load('./dataset/mutag_chunks/sub_feas_1.npy')
    original_labels = np.load('./dataset/mutag_chunks/sub_labels_1.npy')
    print('********** finished opening dataset *************** \n')

    # we only consider the mutagen graphs with NO2 and NH2.
    selected = [
        gid for gid in range(original_adjs.shape[0])
        if np.argmax(original_labels[gid]) == 0 and np.sum(edge_label_lists[gid]) > 0
    ]
    print('number of mutagen graphs with NO2 and NH2',len(selected))
    selected_adjs = original_adjs[selected]
    selected_features = original_features[selected]
    selected_labels = original_labels[selected]
    selected_edge_lists = [edge_lists[i] for i in selected]
    selected_graph_labels = graph_labels[selected]
    selected_edge_label_lists = [edge_label_lists[i] for i in selected]
    selected_node_label_lists = [node_label_lists[i] for i in selected]

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = GCN(input_dim=selected_features.shape[1:], output_dim=selected_labels.shape[1], device=device)
    model.to(device)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load('./model_weights/GCN_Mutagenicity_BEST.pt', map_location=device))
    model.eval()

    selected_adjs_tensor = torch.tensor(selected_adjs,dtype=torch.float32)
    selected_features_tensor = torch.tensor(selected_features,dtype=torch.float32)
    selected_labels_tensor = torch.tensor(selected_labels,dtype=torch.float32)
    # The pretrained model is only queried, never trained here, so no autograd
    # graph is needed for its outputs or embeddings.
    with torch.no_grad():
        selected_output = model.forward((selected_features_tensor,selected_adjs_tensor),training=False)
        selected_embs = model.getNodeEmb((selected_features_tensor, selected_adjs_tensor), training=False)
    selected_acc = accuracy(selected_output, selected_labels_tensor)
    selected_pred_label = torch.argmax(selected_output, 1)

    # Seed BEFORE the explainer is built so that any random initialisation in
    # Explainer, and not only the training itself, is reproducible per run.
    np.random.seed(iteration)
    torch.manual_seed(iteration)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(iteration)
        torch.cuda.manual_seed_all(iteration)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    explainer = Explainer(model=model,nodesize=selected_adjs.shape[1])
    explainer.to(device)
    optimizer = torch.optim.Adam(explainer.parameters(), lr=args.elr)

    # TRAINING
    train(iteration)

    explainer.load_state_dict(torch.load('model_weights/MUTAG_BESTAUC.pt', map_location=device))

    # EVALUATION: running AUC over all selected graphs, with wall-clock timing.
    reals = []
    preds = []

    # The context manager closes the log even if evaluation raises.
    with open(run_dir + "LOG.txt", "w") as f:
        tik = time.time()
        for gid in range(selected_adjs.shape[0]):
            fea, emb, adj, label = selected_features[gid], selected_embs[gid], selected_adjs[gid], selected_labels[gid]
            explain_graph(fea,emb,adj,label,gid, iteration,needplot=False)
            # AUC over every edge collected so far (reals/preds keep growing).
            auc = roc_auc_score(reals,preds)
            tok = time.time()
            f.write("gid,{},auc,{},time,{}\n".format(gid, auc, tok-tik))

        tok = time.time()
        f.write("time,{}".format(tok-tik) + "\n")

    # Second pass, outside the timed section, only to render the figures.
    for gid in range(selected_adjs.shape[0]):
        fea, emb, adj, label = selected_features[gid], selected_embs[gid], selected_adjs[gid], selected_labels[gid]
        explain_graph(fea,emb,adj,label,gid, iteration,needplot=True)
