import random
import torch
import pandas as pd
import numpy as np
from torch import nn
from sklearn.metrics import confusion_matrix
from matplotlib import pyplot as plt
from sklearn.metrics import f1_score,accuracy_score
from torch.nn.parameter import Parameter
from tqdm import tqdm
import pickle
import matplotlib
import scipy.sparse as sp


# Number of noisy forward passes per optimizer step; Block.backward scales each
# accumulated gradient contribution by 1/num_copies and main runs this many copies.
num_copies = 100
# First CUDA device when available, otherwise CPU; default device for Block/GCN.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def normalize(x):
    """Row-normalize a 2-D array so every row sums to 1.

    Rows whose sum is zero are returned as all zeros instead of NaN.
    Integer input is promoted to float; float input keeps its dtype.

    Args:
        x: 2-D array-like of shape (n, d).

    Returns:
        A new (n, d) array ``x[i] / x[i].sum()`` row by row.
    """
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.floating):
        x = x.astype(float)
    row_sums = x.sum(axis=1, keepdims=True)
    # Broadcast division is O(n*d); the previous diag(1/s) @ x was O(n^2*d) and
    # produced inf*0 = NaN (plus a warning) for empty rows.
    return np.divide(x, row_sums, out=np.zeros_like(x), where=row_sums != 0)


def parse_data(content, cites):
    """Turn the raw Cora tables into features, a normalized adjacency and labels.

    Args:
        content: DataFrame whose column 0 is the node id, last column the class
            label and the columns in between the (bag-of-words) features.
        cites: DataFrame whose columns 0 and 1 are node ids of linked nodes.

    Returns:
        (features, edges, label_ids):
            features  -- row-normalized float array, shape (n, d);
            edges     -- row-normalized dense (A + I), shape (n, n), with A
                         symmetrized from ``cites``;
            label_ids -- list of n integer class ids.
    """
    features = normalize(content.values[:, 1:-1].astype(float))
    labels = content.values[:, -1]
    num_node = len(labels)

    table = {id_: index for index, id_ in enumerate(content[0])}
    # Use every row (the previous `[:-1]` slice skipped the last node's label),
    # and sort so the label -> index mapping does not depend on string hash
    # randomization between runs.
    label_table = {label: index for index, label in enumerate(sorted(set(labels)))}
    label_ids = [label_table[label] for label in labels]

    cite_cols = cites.to_dict('list')
    source_list = [table[i] for i in cite_cols[1]]
    target_list = [table[i] for i in cite_cols[0]]
    edges = np.zeros((num_node, num_node))
    # Symmetrize: citations are treated as undirected edges.
    edges[target_list, source_list] = 1
    edges[source_list, target_list] = 1
    edges = normalize(edges + np.eye(num_node))
    return features, edges, label_ids


class Block(torch.nn.Module):
    """Graph-convolution layer ``edges @ (x @ weight) + bias`` trained without autograd.

    With ``add_noise=True`` the forward pass adds mirrored unit-variance Gaussian
    noise to the output and caches the aggregated input ``edges @ x``.
    :meth:`backward` then accumulates into ``.grad`` the per-sample loss
    weighted by that input and noise (a likelihood-ratio / score-function
    style estimate), scaled by ``1 / (num_copies * batch)``.

    NOTE(review): ``noise_std`` is trainable and receives a gradient in
    :meth:`backward`, but ``forward`` never scales the noise by it, so it has
    no effect on the output. ``self.dropout`` is constructed but never applied.
    """

    def __init__(self, in_num, hidden_num, dropout_rate, init_std=1e-1, device=device):
        super().__init__()
        self.dropout = nn.Dropout(dropout_rate)  # NOTE(review): never applied
        self.weight = Parameter(torch.FloatTensor(in_num, hidden_num))
        self.noise_std = nn.Parameter(torch.full((hidden_num,), init_std, device=device))
        self.bias = Parameter(torch.FloatTensor(hidden_num))
        torch.nn.init.xavier_uniform_(self.weight.data, gain=1.41)
        self.bias.data.uniform_(-init_std, init_std)
        self.input_buf = None       # edges @ input of the last noisy forward
        self.epsilon_1_buf = None   # unused in this class
        self.epsilon_2_buf = None   # output noise of the last noisy forward
        self.bias_noise_std = 1e-1  # unused in this class
        self.device = device

    def encode(self, x, add_noise=True):
        """Return ``x @ weight``; when ``add_noise`` also cache ``x`` for backward."""
        self.input_cache = x
        logit_output = torch.mm(x, self.weight)
        if add_noise:
            self.input_buf = x
        return logit_output

    def aggregate(self, features, edges, add_noise=True):
        """Return ``edges @ features``; propagate the cached input the same way."""
        x = torch.mm(edges, features)
        if add_noise:
            self.input_buf = torch.mm(edges, self.input_buf)
        return x

    def forward(self, features, edges=None, add_noise=True):
        """Apply the layer, optionally adding mirrored Gaussian output noise.

        Args:
            features: (n, in_num) node features.
            edges: adjacency used for aggregation (``edges @ ...``).
            add_noise: if true, perturb the output and cache what
                :meth:`backward` needs.
        """
        x = self.encode(features, add_noise=add_noise)
        x = self.aggregate(x, edges, add_noise) + self.bias
        if not add_noise:
            return x
        bs, out_features = x.shape
        half = bs // 2
        # Draw on x's device so the noise follows the module after .to(...).
        epsilon = torch.zeros_like(x)
        # Mirrored noise: rows [half, 2*half) are the negation of rows [0, half).
        epsilon[:half] = torch.randn((half, out_features), device=x.device)
        epsilon[half:2 * half] = -epsilon[:half]
        if bs % 2:
            # Odd row count: the leftover last row gets its own independent draw
            # (previously a shape mismatch error).
            epsilon[-1] = torch.randn((out_features,), device=x.device)
        self.epsilon_2_buf = epsilon
        return x + epsilon

    @torch.no_grad()
    def backward(self, loss):
        """Accumulate the noise-based gradient estimate of the last noisy forward.

        Runs under ``torch.no_grad`` so the in-place ``.grad`` updates do not
        record (and keep alive) an autograd graph across copies.

        Args:
            loss: per-sample losses, shape (batch,) (main passes
                ``CrossEntropyLoss(reduction='none')`` output).
        """
        batch_size = self.input_buf.shape[0]
        loss = loss.unsqueeze(-1)
        for param in (self.weight, self.bias, self.noise_std):
            if param.grad is None:
                param.grad = torch.zeros_like(param)
        self.weight.grad += 1/num_copies * torch.einsum('ni,nj->ij', self.input_buf * loss, self.epsilon_2_buf) / (batch_size)
        self.bias.grad += 1/num_copies * torch.einsum('ni,nj->j', loss, self.epsilon_2_buf) / (batch_size)
        self.noise_std.grad += 1/num_copies * torch.einsum('ni,nj->j', loss, self.epsilon_2_buf ** 2 - 1) / (self.noise_std * batch_size)

        # Release the cached tensors (the old code cleared a misspelled attribute).
        self.input_buf = None
        self.epsilon_2_buf = None


class GCN(torch.nn.Module):
    """Two stacked ``Block`` layers with a LeakyReLU(0.01) in between.

    ``forward`` is noise-free unless ``add_noise=True``. ``backward`` forwards
    the per-sample loss to both layers' noise-based gradient estimators.
    """

    def __init__(self, in_num, hidden_num, out_num, dropout_rate, device=device):
        super().__init__()
        # Construction order is kept: it fixes the RNG draws used for init.
        self.net_1 = Block(in_num, hidden_num, dropout_rate, device=device)
        self.net_2 = Block(hidden_num, out_num, dropout_rate, device=device)
        self.act = nn.LeakyReLU(0.01)
        self.device = device

    def forward(self, features, edges, add_noise=False):
        hidden = self.net_1(features, edges, add_noise=add_noise)
        return self.net_2(self.act(hidden), edges, add_noise=add_noise)

    def backward(self, loss):
        for layer in (self.net_1, self.net_2):
            layer.backward(loss)


def read_data(data_dir='cora'):
    """Load Cora from ``data_dir`` and split the nodes 70/20/10 (train/val/test).

    Args:
        data_dir: directory holding ``cora.content`` and ``cora.cites``.

    Returns:
        (features, edges, label_ids, train_indice, val_indice, test_indice):
        FloatTensor features, FloatTensor normalized adjacency, LongTensor
        labels, and three NumPy index arrays.
    """
    content = pd.read_csv(f'{data_dir}/cora.content', header=None, sep='\t')
    cites = pd.read_csv(f'{data_dir}/cora.cites', header=None, sep='\t')
    features, edges, label_ids = parse_data(content, cites)
    num_nodes = features.shape[0]
    indice = np.arange(num_nodes)
    # Seeds NumPy's *global* RNG, so later np.random users in this process
    # (e.g. the edge attacks in main) are reproducible too.
    np.random.seed(0)
    np.random.shuffle(indice)
    train_num = int(round(num_nodes * 0.7))
    test_num = int(round(num_nodes * 0.1))
    # Explicit end index: `indice[-test_num:]` would select everything when
    # test_num == 0.
    val_end = num_nodes - test_num
    train_indice = indice[:train_num]
    val_indice = indice[train_num:val_end]
    test_indice = indice[val_end:]
    features = torch.FloatTensor(features)
    edges = torch.FloatTensor(edges)
    label_ids = torch.LongTensor(label_ids)
    return features, edges, label_ids, train_indice, val_indice, test_indice


def dice(ori_adj, labels, n_perturbations):
    """DICE-style attack: delete same-label edges and connect different-label nodes.

    Each of the ``n_perturbations`` units is randomly assigned to "remove" or
    "insert".
    - Removals pick existing upper-triangle, off-diagonal entries whose
      endpoints share a label, and zero them symmetrically. Self-loops are
      never touched, matching ``random(..., 'remove')``.
    - Insertions set symmetric entries to 1 for random node pairs with
      different labels that are currently unconnected. Each undirected pair
      is counted once.

    Args:
        ori_adj: square adjacency (torch tensor or array-like); not modified.
        labels: 1-D node labels (torch tensor or array-like).
        n_perturbations: total budget of removals plus insertions.

    Returns:
        A new CPU torch tensor with the perturbed adjacency.

    Raises:
        ValueError: if fewer unconnected different-label pairs exist than the
            number of insertions drawn (the sampler would otherwise loop forever).
    """
    print('number of pertubations: %s' % n_perturbations)
    if torch.is_tensor(ori_adj):
        adj = ori_adj.cpu().numpy().copy()
    else:
        adj = np.array(ori_adj, copy=True)
    labels = labels.cpu().numpy() if torch.is_tensor(labels) else np.asarray(labels)

    remove_or_insert = np.random.choice(2, n_perturbations)
    n_remove = int(remove_or_insert.sum())
    n_insert = n_perturbations - n_remove

    # Remove edges inside classes (k=1: self-loops are not graph edges).
    rows, cols = np.nonzero(np.triu(adj, k=1))
    same_label = labels[rows] == labels[cols]
    possible_indices = np.stack([rows[same_label], cols[same_label]], axis=1)
    remove_indices = np.random.permutation(possible_indices)[:n_remove]
    adj[remove_indices[:, 0], remove_indices[:, 1]] = 0
    adj[remove_indices[:, 1], remove_indices[:, 0]] = 0

    if n_insert > 0:
        cross_label = labels[:, None] != labels[None, :]
        unconnected = (adj == 0) & (adj.T == 0)
        n_available = int(np.triu(cross_label & unconnected, k=1).sum())
        if n_available < n_insert:
            raise ValueError(
                f'dice: {n_insert} insertions requested but only '
                f'{n_available} unconnected different-label pairs exist')

    # Insert edges across classes by rejection sampling.
    num_nodes = adj.shape[0]
    added_edges = 0
    while added_edges < n_insert:
        n_remaining = n_insert - added_edges
        us = np.random.choice(num_nodes, n_remaining)
        vs = np.random.choice(num_nodes, n_remaining)
        # Canonical (min, max) pairs so (u, v) and (v, u) count as one edge.
        candidates = {
            (min(u, v), max(u, v))
            for u, v in zip(us.tolist(), vs.tolist())
            if labels[u] != labels[v] and adj[u, v] == 0 and adj[v, u] == 0
        }
        if not candidates:
            continue
        new_edges = np.array(sorted(candidates))
        adj[new_edges[:, 0], new_edges[:, 1]] = 1
        adj[new_edges[:, 1], new_edges[:, 0]] = 1
        added_edges += len(new_edges)
    return torch.from_numpy(adj)


def random_sample_edges(adj, n, exclude):
    """Return ``n`` random node pairs of ``adj`` that are not in ``exclude``.

    The pairs come from :func:`sample_forever`. That function mutates
    ``exclude`` as it goes: each yielded pair and its reverse are added when
    the sampler resumes, so the final pair is not added. Sampling never
    terminates if fewer than ``n`` eligible pairs exist.
    """
    sampler = sample_forever(adj, exclude=exclude)
    sampled = []
    for _ in range(n):
        sampled.append(next(sampler))
    return sampled


def sample_forever(adj, exclude):
    """Yield random ordered node pairs ``(i, j)`` with ``i != j``, forever.

    ``exclude`` is a set of pairs that must not be produced. After a pair is
    yielded, it and its reverse are added to ``exclude`` (once the generator
    resumes), so no unordered pair is produced twice. The loop never ends on
    its own, even when every pair has been excluded.
    """
    while True:
        pair = tuple(np.random.choice(adj.shape[0], 2, replace=False))
        if pair in exclude:
            continue
        yield pair
        exclude.update({pair, (pair[1], pair[0])})


def random(adj, n_perturbations, type='add'):
    """Randomly perturb ``n_perturbations`` node pairs of a dense adjacency.

    NOTE: this function shadows the stdlib ``random`` module imported at the top
    of the file. The name is kept because callers use it.

    Args:
        adj: square torch tensor (copied to CPU); not modified.
        n_perturbations: number of undirected node pairs to perturb.
        type: case-insensitive mode.
            - 'flip': replace ``w`` by ``1 - w`` in both directions of
              random pairs.
            - 'add': set both directions to 1 for pairs not already present.
            - 'remove': zero random existing upper-triangle, off-diagonal
              entries and their mirrors.

    Returns:
        A new CPU tensor with the perturbed adjacency.

    Raises:
        ValueError: unknown ``type``, or (for 'flip'/'add') more perturbations
            than candidate pairs (the sampler would otherwise loop forever).
    """
    adj_np = adj.cpu().numpy()
    modified_adj = torch.from_numpy(adj_np.copy())
    num_nodes = adj_np.shape[0]

    type = type.lower()
    if type == 'flip':
        n_available = num_nodes * (num_nodes - 1) // 2
        if n_perturbations > n_available:
            raise ValueError(
                f'random: cannot flip {n_perturbations} of {n_available} node pairs')
        edges = random_sample_edges(adj, n_perturbations, exclude=set())
        for n1, n2 in edges:
            modified_adj[n1, n2] = 1 - modified_adj[n1, n2]
            modified_adj[n2, n1] = 1 - modified_adj[n2, n1]
    elif type == 'add':
        # Build explicit (row, col) pairs: zipping a torch `nonzero()` result
        # iterates its (N, 2) rows, so existing edges were never excluded.
        rows, cols = np.nonzero(adj_np)
        nonzero = set(zip(rows.tolist(), cols.tolist()))
        n_available = int(np.triu((adj_np == 0) | (adj_np.T == 0), k=1).sum())
        if n_perturbations > n_available:
            raise ValueError(
                f'random: cannot add {n_perturbations} edges, only '
                f'{n_available} unconnected pairs exist')
        edges = random_sample_edges(adj, n_perturbations, exclude=nonzero)
        for n1, n2 in edges:
            modified_adj[n1, n2] = 1
            modified_adj[n2, n1] = 1
    elif type == 'remove':
        nonzero = np.array(np.nonzero(np.triu(adj_np, k=1))).T
        indices = np.random.permutation(nonzero)[: n_perturbations].T
        modified_adj[indices[0], indices[1]] = 0
        modified_adj[indices[1], indices[0]] = 0
    else:
        raise ValueError(f'random: unknown perturbation type {type!r}')
    return modified_adj


def _accuracy(net, features, edges, labels):
    """Return the accuracy of ``net``'s noise-free predictions against NumPy ``labels``."""
    with torch.no_grad():
        logits = net(features, edges)
    return accuracy_score(labels, logits.cpu().numpy().argmax(1))


def main():
    """Train the GCN on Cora, then report test metrics and edge-attack robustness.

    Training runs 50 epochs x 500 iterations. Each iteration averages
    ``num_copies`` noisy forward passes; their gradients are accumulated by
    ``GCN.backward`` (not autograd) before a single Adam step.

    Returns:
        (net, metrics): the trained network and a dict with:
        - per-iteration lists ``train_f1``, ``val_f1``, ``train_loss``,
          ``val_loss``;
        - scalar ``test_f1`` and mean ``test_loss``;
        - ``test_confusion_matrix``;
        - ``best_val_loss``.
    """
    features, edges, label_ids, train_indice, val_indice, test_indice = read_data()
    net = GCN(features.shape[1], 32, 7, 0.1).to(device)
    loss_func = torch.nn.CrossEntropyLoss(reduction='none')
    optim = torch.optim.Adam(net.parameters(), 1e-3)
    train_losses, val_losses = [], []
    train_f1s, val_f1s = [], []
    best_val_loss = None

    features = features.to(device)
    edges = edges.to(device)
    label_ids = label_ids.to(device)
    train_features = features[train_indice]
    train_edges = edges[train_indice][:, train_indice]
    train_label_ids = label_ids[train_indice]
    val_features = features[val_indice]
    val_edges = edges[val_indice][:, val_indice]
    val_label_ids = label_ids[val_indice]
    # Host copies of the labels, reused by every sklearn metric call.
    train_labels_np = train_label_ids.cpu().numpy()
    val_labels_np = val_label_ids.cpu().numpy()

    n_iters = 500
    for e in range(50):
        with tqdm(range(1, n_iters + 1)) as tqdm_range:
            for _step in tqdm_range:
                net.train(True)
                optim.zero_grad()

                copy_train_loss = 0.
                # Gradients come from GCN.backward, not autograd, so no graph
                # needs to be recorded for these forward passes.
                with torch.no_grad():
                    for _ in range(num_copies):
                        output = net(train_features, train_edges, add_noise=True)
                        train_loss = loss_func(output, train_label_ids)
                        copy_train_loss += train_loss.mean().item()
                        net.backward(train_loss)
                copy_train_loss /= num_copies
                optim.step()
                train_losses.append(copy_train_loss)

                net.train(False)
                with torch.no_grad():
                    output_val = net(val_features, val_edges, add_noise=False)
                    val_loss_value = loss_func(output_val, val_label_ids).mean().cpu().numpy()
                val_losses.append(val_loss_value)
                # NOTE: train F1 is measured on the last *noisy* copy's predictions.
                train_f1 = f1_score(train_labels_np, output.cpu().numpy().argmax(1), average='macro')
                val_f1 = f1_score(val_labels_np, output_val.cpu().numpy().argmax(1), average='macro')
                train_f1s.append(train_f1)
                val_f1s.append(val_f1)
                if best_val_loss is None or best_val_loss >= val_loss_value:
                    best_val_loss = val_loss_value
                tqdm_range.set_description(
                    'Epoch: {}, Train Loss: {:.4f}, Val Loss: {:.4f}, Train F1: {:.4f}, Val F1: {:.4f}'.format(
                        e, copy_train_loss, val_loss_value, train_f1, val_f1))

    net.train(False)
    test_labels = label_ids[test_indice].cpu()
    with torch.no_grad():
        output_test = net(features[test_indice], edges[test_indice][:, test_indice], add_noise=False).cpu()
    test_labels_np = test_labels.numpy()
    test_pred = output_test.numpy().argmax(1)
    test_f1 = f1_score(test_labels_np, test_pred, average='macro')
    # Mean over nodes, consistent with val_loss (was a per-sample tensor).
    test_loss = loss_func(output_test, test_labels).mean().item()
    metrics = {
        "train_f1": train_f1s, "val_f1": val_f1s, "test_f1": test_f1,
        "train_loss": train_losses, "val_loss": val_losses, "test_loss": test_loss,
        "test_confusion_matrix": confusion_matrix(test_labels_np, test_pred),
        "best_val_loss": best_val_loss,
    }

    # NOTE(review): val_edges is a slice of the row-normalized (A + I) built by
    # parse_data, not a 0/1 matrix. The perturbation budget below is therefore
    # a fraction of the sum of those weights, and the attacks write raw 0/1
    # values into it. Confirm this is intended.
    val_edges_cpu = val_edges.cpu()
    val_labels_cpu = val_label_ids.cpu()
    edge_mass = torch.sum(val_edges)
    clean_acc = _accuracy(net, val_features, val_edges, val_labels_np)
    add_acc = flip_edge_acc = dice_acc = remove_edge_acc = 0.
    for r in range(1, 11):
        n_perturbations = int(r / 10 * edge_mass)
        # Same draw order as before: add, flip, remove, then dice.
        add_edge = random(val_edges_cpu, n_perturbations, 'add').to(device)
        flip_edge = random(val_edges_cpu, n_perturbations, 'flip').to(device)
        remove_edge = random(val_edges_cpu, n_perturbations, 'remove').to(device)
        dice_edge = dice(val_edges_cpu, val_labels_cpu, n_perturbations).to(device)
        add_acc += _accuracy(net, val_features, add_edge, val_labels_np)
        flip_edge_acc += _accuracy(net, val_features, flip_edge, val_labels_np)
        dice_acc += _accuracy(net, val_features, dice_edge, val_labels_np)
        remove_edge_acc += _accuracy(net, val_features, remove_edge, val_labels_np)
    print('clean acc:', clean_acc)
    print('add acc:', add_acc / 10)
    print('flip acc:', flip_edge_acc / 10)
    print('dice acc:', dice_acc / 10)
    print('remove acc:', remove_edge_acc / 10)
    return net, metrics


if __name__ == '__main__':
    import os

    # Create the output directory before training, so a missing folder cannot
    # discard the results of a long run at save time.
    os.makedirs('result', exist_ok=True)
    net, metrics = main()
    # NOTE: this pickles the whole module object, so unpickling needs these
    # classes importable. The state_dict saved below is the portable copy.
    with open('result/lr_metrics.pkl', 'wb') as f:
        pickle.dump([net, metrics], f)

    matplotlib.use('Agg')

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    ax1.plot(metrics['train_f1'], label='train')
    ax1.plot(metrics['val_f1'], label='val')
    ax1.legend()
    ax1.set_title('f1')

    ax2.plot(metrics['train_loss'], label='train')
    ax2.plot(metrics['val_loss'], label='val')
    ax2.legend()
    ax2.set_title('loss')

    fig.savefig('result/lr_metrics.png')
    plt.close(fig)
    torch.save(net.state_dict(), 'result/llr_gcn_model.pth')