import random
import torch
import pandas as pd
import numpy as np
from torch import nn
from sklearn.metrics import confusion_matrix
from matplotlib import pyplot as plt
from sklearn.metrics import f1_score, accuracy_score
from torch.nn.parameter import Parameter
from tqdm import tqdm
import torch.nn.functional as F
import os
import scipy.sparse as sp
import pickle
import matplotlib

# Number of noisy forward passes accumulated per optimizer step in main();
# the layers' manual backward() scales every gradient contribution by
# 1 / num_copies so the accumulated value is an average over the copies.
num_copies = 1
# Use the first CUDA device when CUDA is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Option string for PyTorch's CUDA caching allocator (max_split_size_mb = 128).
# NOTE(review): presumably added to limit fragmentation-related OOMs — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"


def normalize(x):
    """Scale every row of a 2-D array so that it sums to one.

    Rows whose sum is zero are left as all zeros instead of turning into
    NaN/inf. Non-floating input (e.g. an integer 0/1 matrix) is converted to
    float first; floating input keeps its dtype.

    Args:
        x: 2-D array-like of shape (n_rows, n_cols).

    Returns:
        A new ``numpy.ndarray`` with the same shape as ``x``.
    """
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.floating):
        x = x.astype(float)
    row_sums = x.sum(1)
    # Reciprocal of each row sum; entries for empty rows stay at 0 so the
    # row is passed through unchanged rather than divided by zero.
    inverse = np.zeros_like(row_sums)
    np.divide(1.0, row_sums, out=inverse, where=row_sums != 0)
    # Broadcasting replaces the dense diag(1 / sum) @ x product.
    return x * inverse[:, None]


def parse_data(content, cites):
    """Turn the raw node and citation tables into dense model inputs.

    Args:
        content: DataFrame with one row per node; column 0 is the node id,
            the last column is the class label and the columns in between
            are numeric features.
        cites: DataFrame with one row per citation; columns 0 and 1 hold the
            ids of the two connected nodes.

    Returns:
        A tuple ``(features, edges, label_ids)``:
        ``features`` is the row-normalised feature matrix, ``edges`` is the
        row-normalised symmetric adjacency matrix with self-loops added, and
        ``label_ids`` is a list with one integer class index per node.

    Raises:
        KeyError: if ``cites`` references a node id missing from ``content``.
    """
    values = content.values
    features = normalize(values[:, 1:-1].astype(float))

    # Map raw node ids to consecutive row indices.
    table = {id_: index for index, id_ in enumerate(content[0])}

    # Build the label table from ALL rows (the last row used to be dropped)
    # and in first-appearance order, so the label -> index mapping is the
    # same on every run instead of following set iteration order.
    raw_labels = values[:, -1]
    label_table = {label: index for index, label in enumerate(dict.fromkeys(raw_labels))}
    label_ids = [label_table[label] for label in raw_labels]

    num_node = len(values)
    edges = np.zeros((num_node, num_node))
    source_list = [table[i] for i in cites[1].tolist()]
    target_list = [table[i] for i in cites[0].tolist()]
    # Citations are treated as undirected: set both directions.
    edges[target_list, source_list] = 1
    edges[source_list, target_list] = 1
    edges = normalize(edges + np.eye(num_node))
    return features, edges, label_ids


class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903

    Besides the usual forward pass, the layer can inject antithetic Gaussian
    noise into its intermediate activations (``add_noise=True``) and cache
    what it needs so that :meth:`backward` can accumulate a noise-based
    gradient estimate into ``.grad`` of ``W``, ``Q``, ``V`` and ``a`` without
    using autograd.
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True, device=device):
        """Create the projection matrices and the attention vector.

        Args:
            in_features: size of each input node feature vector.
            out_features: size of each projected node vector.
            dropout: stored on the instance; not used by this layer's forward.
            alpha: negative slope of the LeakyReLU applied to attention scores.
            concat: if True, apply ELU to the output (hidden-layer usage).
            device: device on which the parameters are created.
        """
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat
        self.device = device

        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features), device=device))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.Q = nn.Parameter(torch.zeros(size=(in_features, out_features), device=device))
        nn.init.xavier_uniform_(self.Q.data, gain=1.414)
        self.V = nn.Parameter(torch.zeros(size=(in_features, out_features), device=device))
        nn.init.xavier_uniform_(self.V.data, gain=1.414)
        self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1), device=device))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

        self._reset_caches()

    def _reset_caches(self):
        """Drop every tensor cached by the last noisy forward pass."""
        self.input_cache = None
        self.a_input_cache = None
        self.h_noise_cache = None
        self.q_noise_cache = None
        self.v_noise_cache = None
        self.a_noise_cache = None

    def generate_noise(self, x):
        """Return antithetic standard-normal noise shaped like the 2-D tensor ``x``.

        The first ``rows // 2`` rows are sampled from N(0, 1) and the next
        ``rows // 2`` rows are their negation. When the row count is odd, the
        remaining last row receives its own independent sample (previously an
        odd row count generally failed with a shape mismatch). The noise is
        created on ``x``'s device.
        """
        rows, out_features = x.shape
        half = rows // 2
        epsilon = torch.zeros_like(x)
        epsilon[:half] += torch.randn((half, out_features), device=x.device)
        epsilon[half:2 * half] -= epsilon[:half]
        if rows % 2:
            epsilon[-1] += torch.randn(out_features, device=x.device)
        return epsilon

    def forward(self, input, adj, add_noise=False):
        """Compute masked attention over ``adj`` and aggregate node values.

        Args:
            input: node features of shape (N, in_features).
            adj: (N, N) matrix; entries > 0 mark pairs allowed to attend.
            add_noise: if True, perturb the projections and the raw attention
                scores with noise and cache inputs/noise for :meth:`backward`.

        Returns:
            Tensor of shape (N, out_features); passed through ELU when
            ``concat`` is True.
        """
        h = torch.mm(input, self.W)
        q = torch.mm(input, self.Q)
        v = torch.mm(input, self.V)
        if add_noise:
            h_noise = self.generate_noise(h)
            q_noise = self.generate_noise(q)
            v_noise = self.generate_noise(v)
            h = h + h_noise
            q = q + q_noise
            v = v + v_noise
            # Cached tensors are only used by the manual backward(), so they
            # are detached to avoid keeping an autograd graph alive.
            self.input_cache = input.detach()
            self.h_noise_cache = h_noise
            self.q_noise_cache = q_noise
            self.v_noise_cache = v_noise
        N = h.size()[0]

        # a_input[i, j] = [h_i || q_j] for every ordered node pair (i, j).
        a_input = torch.cat([h.repeat(1, N).view(N * N, -1), q.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)
        aa = torch.matmul(a_input, self.a).squeeze(2)

        if add_noise:
            self.a_noise_cache = self.generate_noise(aa)
            aa = aa + self.a_noise_cache
            self.a_input_cache = a_input.detach().reshape(-1, 2 * self.out_features)
        e = self.leakyrelu(aa)

        # Pairs that are not connected get a huge negative score so that the
        # softmax assigns them (numerically) zero weight.
        zero_vec = torch.full_like(e, -9e15)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)

        h_prime = torch.matmul(attention, v)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

    @torch.no_grad()
    def backward(self, loss):
        """Accumulate the noise-based gradient estimate into the parameters' ``.grad``.

        Must be called once after a ``forward(..., add_noise=True)`` pass. The
        cached inputs and noise are released afterwards, so every call needs
        its own preceding noisy forward pass.

        Args:
            loss: per-node loss values of shape (N,), where N is the number of
                nodes of the cached forward pass.

        Raises:
            RuntimeError: if no noisy forward pass has been cached.
        """
        if self.input_cache is None or self.a_input_cache is None:
            raise RuntimeError('backward() requires a preceding forward(..., add_noise=True) call')

        N = len(self.input_cache)
        loss = loss.unsqueeze(-1)

        for param in (self.W, self.Q, self.V, self.a):
            if param.grad is None:
                param.grad = torch.zeros_like(param)

        scale = 1 / num_copies
        weighted_input = self.input_cache * loss
        self.W.grad += scale * torch.einsum('ni,nj->ij', weighted_input, self.h_noise_cache) / N
        self.Q.grad += scale * torch.einsum('ni,nj->ij', weighted_input, self.q_noise_cache) / N
        self.V.grad += scale * torch.einsum('ni,nj->ij', weighted_input, self.v_noise_cache) / N

        # The loss of node a weights every pair (a, b). Folding it into the
        # (N, N) noise avoids materialising another N x N x 2F temporary.
        weighted_noise = self.a_noise_cache * loss
        self.a.grad += scale * torch.einsum('abi,ab->i',
                                            self.a_input_cache.reshape(N, N, -1),
                                            weighted_noise).reshape(-1, 1) / (len(self.a_input_cache))

        # Release the cached tensors (a_input_cache alone holds N*N*2F values).
        self._reset_caches()

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


class GAT(nn.Module):
    """Dense GAT: several attention heads in parallel followed by one output layer."""

    def __init__(self, nfeat, nhid, nclass, dropout, alpha=0.01, nheads=4):
        """Build ``nheads`` hidden heads (nfeat -> nhid) and the output layer.

        The heads are kept in the plain list ``self.attentions`` and are also
        registered as sub-modules named ``attention_0``, ``attention_1``, ...
        The output layer maps the concatenated heads (nhid * nheads) to
        ``nclass`` scores.
        """
        super().__init__()
        self.dropout = dropout

        heads = []
        for _ in range(nheads):
            heads.append(GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True))
        self.attentions = heads
        for head_index, head in enumerate(heads):
            self.add_module('attention_{}'.format(head_index), head)

        self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)

    def forward(self, x, adj, add_noise=False):
        """Return per-node log-probabilities of shape (N, nclass)."""
        head_outputs = []
        for head in self.attentions:
            head_outputs.append(head(x, adj, add_noise))
        hidden = torch.cat(head_outputs, dim=1)
        scores = self.out_att(hidden, adj, add_noise)
        return F.log_softmax(F.elu(scores), dim=1)

    def backward(self, loss):
        """Forward the per-node loss to every head and then to the output layer."""
        for layer in (*self.attentions, self.out_att):
            layer.backward(loss)


def dice(ori_adj, labels, n_perturbations):
    """Perturb a graph by removing same-label edges and inserting different-label edges.

    Each of the ``n_perturbations`` is randomly assigned to "remove" or
    "insert". Removals pick existing off-diagonal edges whose endpoints share
    a label (fewer are removed if not enough exist); insertions add missing
    edges whose endpoints have different labels. Self-loops are never
    touched, matching ``random(..., 'remove')``. The input is not modified.

    Args:
        ori_adj: (N, N) torch tensor; non-zero entries are edges.
        labels: one label per node (tensor or sequence).
        n_perturbations: total number of removals plus insertions.

    Returns:
        A new CPU tensor with the perturbed, symmetric adjacency matrix.

    Note:
        The insertion loop keeps sampling until enough edges were added, so
        it does not terminate if fewer than the requested number of
        different-label non-edges exist.
    """
    print('number of pertubations: %s' % n_perturbations)
    modified = ori_adj.detach().cpu().numpy().copy()
    label_arr = torch.as_tensor(labels).detach().cpu().numpy()
    n_nodes = modified.shape[0]

    remove_or_insert = np.random.choice(2, n_perturbations)
    n_remove = int(remove_or_insert.sum())
    n_insert = n_perturbations - n_remove

    # Candidate removals: strictly-upper-triangular edges with equal labels.
    rows, cols = sp.triu(modified, k=1).nonzero()
    same_label = label_arr[rows] == label_arr[cols]
    # Always 2-D (possibly with zero rows), so an empty candidate list is safe.
    possible_indices = np.stack([rows[same_label], cols[same_label]], axis=1)
    remove_indices = np.random.permutation(possible_indices)[:n_remove]
    modified[remove_indices[:, 0], remove_indices[:, 1]] = 0
    modified[remove_indices[:, 1], remove_indices[:, 0]] = 0

    # sample edges to add
    added_edges = 0
    while added_edges < n_insert:
        n_remaining = n_insert - added_edges
        # sample random pairs
        sources = np.random.choice(n_nodes, n_remaining)
        targets = np.random.choice(n_nodes, n_remaining)
        # Keep only pairs with different labels and no existing edge; store
        # them as (min, max) so (u, v) and (v, u) count as one edge.
        candidates = {
            (min(u, v), max(u, v))
            for u, v in zip(sources.tolist(), targets.tolist())
            if label_arr[u] != label_arr[v] and modified[u, v] == 0 and modified[v, u] == 0
        }
        # if none is found, try again
        if not candidates:
            continue
        candidate_edges = np.array(sorted(candidates))
        modified[candidate_edges[:, 0], candidate_edges[:, 1]] = 1
        modified[candidate_edges[:, 1], candidate_edges[:, 0]] = 1
        added_edges += len(candidates)
    return torch.from_numpy(modified)


def random_sample_edges(adj, n, exclude):
    """Draw ``n`` node pairs from ``sample_forever``.

    ``exclude`` is handed to the sampler unchanged and is updated by it as
    pairs are drawn. Returns the pairs as a list, in the order produced.
    """
    sampler = sample_forever(adj, exclude=exclude)
    picked = []
    for _ in range(n):
        picked.append(next(sampler))
    return picked


def sample_forever(adj, exclude):
    """Endlessly yield random pairs of distinct node indices.

    A pair is skipped when it is already in ``exclude``. After a pair has
    been yielded (i.e. when the generator is resumed), the pair and its
    reverse are added to ``exclude`` so that they are not produced again.
    ``adj`` is only used for its number of rows.
    """
    while True:
        first, second = np.random.choice(adj.shape[0], 2, replace=False)
        pair = (first, second)
        if pair in exclude:
            continue
        yield pair
        exclude.add(pair)
        exclude.add((second, first))


def random(adj, n_perturbations, type='add'):
    """Randomly perturb the edges of a symmetric adjacency matrix.

    NOTE: this function's name shadows the standard-library ``random``
    module imported at the top of the file; the name is kept because
    callers use it.

    Args:
        adj: (N, N) torch tensor; non-zero entries are edges. Not modified.
        n_perturbations: number of node pairs to change.
        type: case-insensitive mode:
            ``'add'``    - insert edges between pairs that are not connected;
            ``'flip'``   - toggle random pairs (an existing edge, whatever its
                           weight, becomes 0; a missing edge becomes 1);
            ``'remove'`` - delete existing off-diagonal edges (at most as
                           many as exist).
            Any other value returns an unmodified copy.

    Returns:
        A new CPU tensor with the perturbed adjacency matrix.
    """
    modified = adj.detach().cpu().numpy().copy()

    type = type.lower()
    if type == 'flip':
        # sample edges to flip
        edges = random_sample_edges(adj, n_perturbations, exclude=set())
        for n1, n2 in edges:
            # Toggle on "is there an edge" rather than `1 - value`, so that
            # weighted (e.g. row-normalised) edges are really removed. For a
            # 0/1 matrix this is identical to `1 - value`.
            modified[n1, n2] = 0 if modified[n1, n2] != 0 else 1
            modified[n2, n1] = 0 if modified[n2, n1] != 0 else 1
    elif type == 'add':
        # sample edges to add
        # Existing edges as (row, col) pairs of plain ints; iterating a torch
        # `nonzero()` result with zip(*...) does not yield such pairs.
        rows, cols = np.nonzero(modified)
        existing = set(zip(rows.tolist(), cols.tolist()))
        edges = random_sample_edges(adj, n_perturbations, exclude=existing)
        for n1, n2 in edges:
            modified[n1, n2] = 1
            modified[n2, n1] = 1
    elif type == 'remove':
        # sample edges to remove
        upper_edges = np.array(sp.triu(modified, k=1).nonzero()).T
        chosen = np.random.permutation(upper_edges)[:n_perturbations]
        modified[chosen[:, 0], chosen[:, 1]] = 0
        modified[chosen[:, 1], chosen[:, 0]] = 0
    return torch.from_numpy(modified)


def read_data():
    """Load the Cora files and produce tensors plus a fixed 70/20/10 split.

    Reads ``cora/cora.content`` and ``cora/cora.cites`` (tab-separated, no
    header), parses them with ``parse_data`` and shuffles the node indices
    after seeding NumPy's global RNG with 0.

    Returns:
        ``(features, edges, label_ids, train_indice, val_indice, test_indice)``
        where the first three are torch tensors (float, float, long) and the
        last three are NumPy index arrays: the first 70% of the shuffled
        nodes for training, the last 10% for testing and the rest for
        validation.
    """
    content = pd.read_csv('cora/cora.content', sep='\t', header=None)
    cites = pd.read_csv('cora/cora.cites', sep='\t', header=None)
    features, edges, label_ids = parse_data(content, cites)

    num_nodes = features.shape[0]
    indice = np.arange(num_nodes)
    # Seeding the global RNG makes the split identical on every run.
    np.random.seed(0)
    np.random.shuffle(indice)

    train_num = int(round(num_nodes * 0.7))
    test_num = int(round(num_nodes * 0.1))
    train_indice = indice[:train_num]
    val_indice = indice[train_num:-test_num]
    test_indice = indice[-test_num:]

    return (
        torch.FloatTensor(features),
        torch.FloatTensor(edges),
        torch.LongTensor(label_ids),
        train_indice,
        val_indice,
        test_indice,
    )


def main():
    """Train the GAT with the noise-based gradient estimator and evaluate it.

    Steps:
        1. Load the data and train for 50 epochs of 500 steps each. Every
           step runs ``num_copies`` noisy forward passes, lets the network
           accumulate its manual gradient estimate and applies one Adam step.
        2. Evaluate on the test split (macro F1, per-node loss, confusion
           matrix).
        3. Measure validation accuracy under random add / flip / remove and
           DICE edge perturbations at 10 increasing budgets and print the
           averages.

    Returns:
        ``(net, metrics)`` where ``metrics`` holds per-step ``train_f1``,
        ``val_f1``, ``train_loss``, ``val_loss`` lists plus ``test_f1``,
        ``test_loss`` (per-node CPU tensor) and ``test_confusion_matrix``.
    """
    features, edges, label_ids, train_indice, val_indice, test_indice = read_data()
    net = GAT(features.shape[1], 32, 7, 0.1)
    loss_func = torch.nn.CrossEntropyLoss(reduction='none')
    optim = torch.optim.Adam(net.parameters(), 1e-3)
    train_losses = []
    val_losses = []
    train_f1s = []
    val_f1s = []

    net = net.to(device)
    features = features.to(device)
    edges = edges.to(device)
    label_ids = label_ids.to(device)

    # Sub-graphs are loop invariant: slice them once instead of every step.
    train_features = features[train_indice]
    train_edges = edges[train_indice][:, train_indice]
    train_label_ids = label_ids[train_indice]
    val_features = features[val_indice]
    val_edges = edges[val_indice][:, val_indice]
    val_label_ids = label_ids[val_indice]
    train_targets = train_label_ids.cpu().numpy()
    val_targets = val_label_ids.cpu().numpy()

    n_iters = 500
    for epoch in range(50):
        with tqdm(range(1, n_iters + 1)) as tqdm_range:
            for _step in tqdm_range:
                net.train(True)
                optim.zero_grad()
                copy_train_loss = 0.
                # Gradients come from net.backward(), not autograd, so no
                # graph needs to be recorded for the forward pass.
                with torch.no_grad():
                    for _ in range(num_copies):
                        output = net(train_features, train_edges, add_noise=True)
                        train_loss = loss_func(output, train_label_ids)
                        copy_train_loss += train_loss.mean().item()
                        net.backward(train_loss)
                copy_train_loss /= num_copies
                optim.step()
                train_losses.append(copy_train_loss)

                net.train(False)
                with torch.no_grad():
                    output_val = net(val_features, val_edges, add_noise=False)
                    val_loss_value = loss_func(output_val, val_label_ids).mean().item()
                val_losses.append(val_loss_value)

                train_f1 = f1_score(train_targets, output.cpu().numpy().argmax(1), average='macro')
                val_f1 = f1_score(val_targets, output_val.cpu().numpy().argmax(1), average='macro')
                train_f1s.append(train_f1)
                val_f1s.append(val_f1)
                tqdm_range.set_description(
                    'Epoch: {}, Train Loss: {:.4f}, Val Loss: {:.4f}, Train F1: {:.4f}, Val F1: {:.4f}'.format(
                        epoch, copy_train_loss, val_loss_value, train_f1, val_f1))

    # ---- test-set evaluation -------------------------------------------
    test_label_ids = label_ids[test_indice]
    with torch.no_grad():
        output_test = net(features[test_indice], edges[test_indice][:, test_indice], add_noise=False)
        # Kept per node (reduction='none') but moved to the CPU so the
        # pickled metrics do not depend on a CUDA device.
        test_loss = loss_func(output_test, test_label_ids).cpu()
    test_targets = test_label_ids.cpu().numpy()
    test_pred = output_test.cpu().numpy().argmax(1)
    test_f1 = f1_score(test_targets, test_pred, average='macro')
    test_confusion = confusion_matrix(test_targets, test_pred)
    metrics = {
        "train_f1": train_f1s, "val_f1": val_f1s, "test_f1": test_f1,
        "train_loss": train_losses, "val_loss": val_losses, "test_loss": test_loss,
        "test_confusion_matrix": test_confusion,
    }

    # ---- robustness against edge perturbations -------------------------
    val_edges_cpu = val_edges.cpu()
    val_labels_cpu = val_label_ids.cpu()
    # Perturbation budgets are fractions (10% .. 100%) of the summed
    # validation adjacency weights.
    total_weight = torch.sum(val_edges)

    def _accuracy(adjacency):
        """Validation accuracy of the trained net for the given adjacency."""
        with torch.no_grad():
            prediction = net(val_features, adjacency).cpu().numpy().argmax(1)
        return accuracy_score(val_targets, prediction)

    clean_acc = _accuracy(val_edges)
    attack_acc = {'add': 0., 'flip': 0., 'remove': 0., 'dice': 0.}
    n_rates = 10
    for r in range(1, n_rates + 1):
        n_perturbations = int(r / 10 * total_weight)
        perturbed = {
            'add': random(val_edges_cpu, n_perturbations, 'add'),
            'flip': random(val_edges_cpu, n_perturbations, 'flip'),
            'remove': random(val_edges_cpu, n_perturbations, 'remove'),
            # dice attack
            'dice': dice(val_edges_cpu, val_labels_cpu, n_perturbations),
        }
        for attack, adjacency in perturbed.items():
            attack_acc[attack] += _accuracy(adjacency.to(device))

    print('clean acc:', clean_acc)
    print('add acc:', attack_acc['add'] / n_rates)
    print('flip acc:', attack_acc['flip'] / n_rates)
    print('dice acc:', attack_acc['dice'] / n_rates)
    print('remove acc:', attack_acc['remove'] / n_rates)
    return net, metrics


if __name__ == '__main__':
    # Make sure the output directory exists BEFORE the long training run, so
    # a missing folder cannot make the script fail only when saving results.
    os.makedirs('result', exist_ok=True)

    net, metrics = main()
    # NOTE: this pickles the whole module object; loading it later requires
    # the class definitions from this file to be importable.
    with open('result/lr_gat_metrics.pkl', 'wb') as f:
        pickle.dump([net, metrics], f)

    # Non-interactive backend: figures are only written to disk.
    matplotlib.use('Agg')

    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    ax1.plot(metrics['train_f1'], label='train')
    ax1.plot(metrics['val_f1'], label='val')

    ax1.legend()
    ax1.set_title('f1')

    ax2.plot(metrics['train_loss'], label='train')
    ax2.plot(metrics['val_loss'], label='val')

    ax2.legend()
    ax2.set_title('loss')

    plt.savefig('result/lr_gat_metrics.png')
    torch.save(net.state_dict(), 'result/llr_gat_model.pth')