#!/usr/bin/env python
# coding: utf-8

# In[ ]:


import sys
sys.path.append('./codes/forgraph/')
from config import args
from sklearn.metrics import roc_auc_score
from models import GCN2 as GCN
from metrics import *
import numpy as np
from Explainer import Explainer
from scipy.sparse import coo_matrix,csr_matrix
import networkx as nx
skip = 5  # motif size: acc() labels an edge as ground truth when both endpoints lie in [insert, insert + skip)
topk = 5  # plot() highlights mask entries whose weight reaches the topk-th largest upper-triangle weight
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import pickle as pkl
import torch 
import torch.optim

import time
# Root for per-run outputs; each run writes into save_map + "<iteration>/".
save_map = "LISA_TEST_LOGS/BA_2MOTIF/"

# Explainer hyper-parameters overriding the defaults from config.args.
args.elr = 0.00015  # learning rate of the explainer's Adam optimizer
# train() anneals the temperature geometrically from coff_t0 towards coff_te.
# NOTE(review): this used to read `args.coff_t0=5.0` immediately followed by
# `args.coff_t0=0.5`; the first store was dead, so only 0.5 ever took effect
# and is kept here. If 5.0 -> 0.5 was meant as the start/end schedule, set
# `args.coff_t0 = 5.0` and `args.coff_te = 0.5` instead -- confirm.
args.coff_t0 = 0.5
args.coff_size = 0.01
args.coff_ent = 0
# args.concat = True
# args.bn = True


# In[ ]:


def acc(adj,insert):
    """Record ground-truth labels and explainer scores for every edge of ``adj``.

    For each nonzero entry (row, col) of ``adj``, in row-major order, two values
    are appended to the module-level lists:

    * ``reals``: 1 if both endpoints fall in ``[insert, insert + skip)``
      (the planted motif), otherwise 0;
    * ``preds``: the corresponding weight of ``explainer.masked_adj``.

    The explainer must already have been run on the same graph.
    """
    scores = explainer.masked_adj.cpu().detach().numpy()
    edges = coo_matrix(adj)
    first, stop = insert, insert + skip
    for row, col in zip(edges.row, edges.col):
        in_motif = first <= row < stop and first <= col < stop
        reals.append(1 if in_motif else 0)
        preds.append(scores[row][col])


def plot(adj,label,graphid, iteration):
    """Save a drawing of graph ``graphid`` with its top explanation edges highlighted.

    The edge mask is read from the global ``explainer.masked_adj``, so the
    explainer must just have been run on this graph. The highlight threshold is
    the ``topk``-th largest entry of the mask's upper triangle (diagonal
    included). Every stored mask entry with row >= col whose weight reaches that
    threshold is drawn as a thick edge on top of the grey input graph.
    NOTE(review): the threshold comes from the upper triangle while the
    selection uses the lower one, which presumes a symmetric mask -- confirm.

    Args:
        adj: dense adjacency matrix of the graph; drives the layout and grey edges.
        label: unused; kept for call compatibility.
        graphid: graph index, used as the output file name.
        iteration: run index; the image goes to ``save_map/<iteration>/<graphid>.png``.
    """
    after_adj_dense = explainer.masked_adj.cpu().detach().numpy()
    after_adj = coo_matrix(after_adj_dense)

    # topk-th largest weight of the upper triangle (ascending sort, index from the end).
    edge_weights = np.triu(after_adj_dense).flatten()
    sorted_edge_weights = np.sort(edge_weights)
    thres_index = max(int(edge_weights.shape[0] - topk), 0)
    thres = sorted_edge_weights[thres_index]

    # Iterate row/col/data separately so node ids stay integers; the previous
    # version concatenated them into one float array and produced float ids.
    pos_edges = [
        (int(r), int(c))
        for r, c, d in zip(after_adj.row, after_adj.col, after_adj.data)
        if r >= c and d >= thres
    ]

    # nx.from_numpy_matrix was removed in networkx 3.0; from_numpy_array is the
    # equivalent constructor (available since networkx 2.0).
    G = nx.from_numpy_array(np.asarray(adj))
    pos = nx.kamada_kawai_layout(G)

    # Only draw node ids that exist both in the mask and in the input graph.
    nmb_nodes = after_adj_dense.shape[0]
    node_filter = [node for node in range(nmb_nodes) if node in G]

    nx.draw_networkx_nodes(G, pos,
                           nodelist=node_filter,
                           node_color='orange',
                           node_size=300)

    nx.draw_networkx_edges(G, pos, width=2, edge_color='grey')

    nx.draw_networkx_edges(G, pos,
                           edgelist=pos_edges,
                           width=7)

    plt.axis('off')
    plt.savefig(save_map + str(iteration) + "/" + str(graphid) + ".png")
    plt.clf()



def explain_graph(gid):
    """Run the explainer on graph ``gid`` in eval mode and record its edge scores.

    Inputs come from the module-level ``features``, ``embs``, ``adjs`` and
    ``labels``. The explainer is called with temperature 1.0. Then ``acc`` is
    invoked with the motif assumed to start at node 20.
    """
    graph_features = features[gid]
    graph_emb = embs[gid]
    graph_adj = torch.tensor(adjs[gid])
    graph_label = torch.tensor(labels[gid])

    explainer.eval()
    explainer((graph_features, graph_emb, graph_adj, 1.0, graph_label))

    motif_start = 20
    acc(graph_adj, motif_start)



def test():
    """Explain every graph in ``allnodes`` and return the edge-level ROC-AUC.

    Resets the module-level ``preds``/``reals`` accumulators first, so the score
    covers only this pass.
    """
    global preds, reals
    preds, reals = [], []
    for graph_id in allnodes:
        explain_graph(graph_id)
    return roc_auc_score(reals, preds)



def train(iteration):
    """Train the global ``explainer`` and checkpoint it whenever the test AUC improves.

    Each epoch is a single full-batch step:

    * the explainer loss is summed over every graph in ``adjs``, visited in an
      order shuffled with ``np.random``;
    * one ``optimizer`` step is taken;
    * the temperature passed to the explainer is annealed geometrically from
      ``args.coff_t0`` (epoch 0) towards ``args.coff_te``.

    After every step ``test()`` computes the edge-level AUC. On improvement the
    weights are saved to ``model_weights/BA2motif_BESTAUC.pt`` and to
    ``save_map/<iteration>/BA2motif_BESTAUC.pt``.

    Args:
        iteration: index of the current run; only used for the per-run checkpoint path.
    """
    epochs = args.eepochs
    t0 = args.coff_t0
    t1 = args.coff_te
    best_auc = 0
    run_checkpoint = save_map + str(iteration) + "/" + 'BA2motif_BESTAUC.pt'

    # NOTE(review): `optimizer` is built by the caller over explainer.parameters().
    # The previous version also collected the "elayers" parameters here but never
    # used them. If only those layers should be optimised, build the optimizer
    # from them instead -- confirm against Explainer.
    for epoch in range(epochs):
        # test() -> explain_graph() switches the explainer to eval mode, so training
        # mode must be restored every epoch. Previously it was set only once, and
        # every epoch after the first trained in eval mode.
        explainer.train()
        loss = 0
        tmp = float(t0 * np.power(t1 / t0, epoch / epochs))
        train_instances = list(range(adjs.shape[0]))
        np.random.shuffle(train_instances)
        for gid in train_instances:
            pred = explainer((torch.Tensor(features[gid]), embs[gid], torch.Tensor(adjs[gid]), tmp, torch.Tensor(labels[gid])))
            loss = loss + explainer.loss(pred, pred_label[gid])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        auc = test()
        if auc > best_auc:
            best_auc = auc
            torch.save(explainer.state_dict(), 'model_weights/BA2motif_BESTAUC.pt')
            torch.save(explainer.state_dict(), run_checkpoint)



import os

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# The dataset and the evaluated graph ids are identical for every run, so they
# are resolved once. An unsupported setting now fails immediately instead of
# raising NameError on `allnodes` after a full training run.
with open('./dataset/BA-2motif.pkl','rb') as fin:
    adjs, features, labels = pkl.load(fin)

if args.setting==1:
    allnodes = list(range(0, 100))
elif args.setting==2:
    allnodes = list(range(0, 100)) + list(range(500, 600))
elif args.setting==3:
    allnodes = list(range(1000))
else:
    raise ValueError("Unsupported args.setting: {}".format(args.setting))

for iteration in range(10):
    print("Starting iteration: {}".format(iteration))

    # Seed before any model/explainer is constructed. Previously seeding happened
    # after Explainer(...) was built, so its weight initialisation was not
    # reproducible per run.
    np.random.seed(iteration)
    torch.manual_seed(iteration)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(iteration)
        torch.cuda.manual_seed_all(iteration)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Reload the trained GCN for every run so each run starts from the same weights.
    model = GCN(input_dim=features.shape[1:], output_dim=labels.shape[1], device=device)
    model.to(device)

    if args.bn and args.concat:
        weights_path = 'model_weights/GCN_BA2motif_bn_concat.pt'
    elif args.bn:
        weights_path = 'model_weights/GCN_BA2motif_bn.pt'
    elif args.concat:
        weights_path = 'model_weights/GCN_BA2motif_concat.pt'
    else:
        weights_path = 'model_weights/GCN_BA2motif_BEST.pt'
    # map_location lets GPU-saved checkpoints load on CPU-only machines.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()

    with torch.no_grad():
        embs = model.getNodeEmb((torch.tensor(features).type(torch.float32), torch.tensor(adjs).type(torch.float32)), training=False)

        output = model((torch.tensor(features).type(torch.float32), torch.tensor(adjs).type(torch.float32)), training=False)
    pred_label = torch.argmax(output, 1)

    explainer = Explainer(model=model,nodesize=adjs.shape[1])
    explainer.to(device)
    optimizer = torch.optim.Adam(explainer.parameters(), lr=args.elr)

    # Every per-run artefact (log, checkpoint, plots) lives here; create it up front.
    run_dir = save_map + str(iteration) + "/"
    os.makedirs(run_dir, exist_ok=True)

    with open(run_dir + "LOG.txt", "w") as f:
        train(iteration)

        explainer.load_state_dict(torch.load('model_weights/BA2motif_BESTAUC.pt', map_location=device))

        # Module-level accumulators filled by acc() via explain_graph().
        preds = []
        reals = []

        # Log the cumulative AUC and elapsed time after each explained graph.
        tik = time.time()
        for gid in allnodes:
            explain_graph(gid)
            auc = roc_auc_score(reals,preds)
            tok = time.time()
            f.write("gid,{}".format(gid) + ",auc,{}".format(auc) + ",time,{}".format(tok-tik) + "\n")

        tok = time.time()
        f.write("time,{}".format(tok-tik) + "\n")

        for gid in allnodes:
            fea,emb,adj,label,graphid = features[gid], embs[gid], adjs[gid], torch.Tensor(labels[gid]), gid
            explainer((fea,emb,torch.Tensor(adj),1.0,label))
            plot(adj,label,graphid, iteration)

