import os
import copy
import warnings
import time
import math
import numpy as np
import argparse
import seaborn as sns

import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.nn.functional as F
from torch_geometric.utils import to_undirected
from torch_geometric.nn import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm

from models import *
from utils import set_seed
from dataset_loader import HeterophilousGraphDataset
from eval import unsupervised_eval_linear
warnings.filterwarnings("ignore")


def parse_args():
    """Define the command-line interface of the script and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding every option below (general run settings,
        encoder sizes, data-split rates, unsupervised training settings and the
        augmentation / graph-learner hyper-parameters).
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # General run configuration.
    add('--seed', type=int, default=42, help='seed.')
    add('--dataset', type=str, default='Cora')
    add('--device', type=int, default=0, help='GPU device.')
    add('--runs', type=int, default=10, help='number of runs.')
    add('--net', type=str, default='GCN')

    # Encoder architecture.
    add('--num_layers', type=int, default=2)
    add('--hidden', type=int, default=512, help='hidden units.')
    add('--dropout', type=float, default=0.5, help='dropout for neural networks.')

    # Data split rates.
    add('--train_rate', type=float, default=0.6, help='train set rate.')
    add('--val_rate', type=float, default=0.2, help='val set rate.')

    # Unsupervised pre-training and linear evaluation.
    add("--unsup_epochs", type=int, default=10000, help="Unupservised training epochs.")
    add("--patience", type=int, default=50, help="Patient epochs to wait before early stopping.")
    add("--lr1", type=float, default=0.001, help="Learning rate of the unsupervised model.")
    add("--lr2", type=float, default=0.01, help="Learning rate of linear evaluator.")
    add("--wd1", type=float, default=0.0, help="Weight decay of the unsupervised model.")
    add("--wd2", type=float, default=0.0, help="Weight decay of linear evaluator.")

    # Model / augmentation hyper-parameters.
    add('--mm', type=float, default=0.8)
    add('--lr_gamma', type=float, default=0.005)
    add('--output_size', type=int, default=512)
    add('--feat_drop', type=float, default=0.2)
    add('--edge_drop', type=float, default=0.2)
    add('--K', type=int, default=3, help="number of layer in hierarchial_n2n loss")
    add('--alpha', type=float, default=0.9)
    add('--scale', type=float, default=0.01, help='factor that scales the loss of CCA')
    add('--k', type=int, default=8, help="number of neighbors of knn augmentation")
    add('--activation_learner', type=str, default='relu', choices=["relu", "tanh"])
    add('--sparse', action='store_true', default=False)
    add('--knn_metric', type=str, default='cosine', choices=['cosine', 'minkowski'])
    add('--Init', type=str, default='random', help='the init method gamma logits')
    add('-ta', '--topology_augmentation', type=str, default='learned',
        choices=['learned', 'knn', 'init', 'drop'])

    return parser.parse_args()


def CCA_SSG(z1, z2, beta=0.1):
    """CCA-SSG style objective between two views of node embeddings.

    Each view is column-standardised (zero mean, unit unbiased std) and scaled by
    ``1 / sqrt(N)``. The loss is the invariance term ``||Z1 - Z2||_F^2`` plus
    ``beta`` times the decorrelation terms ``||Z1^T Z1 - I||_F^2 + ||Z2^T Z2 - I||_F^2``.
    Columns with zero standard deviation produce inf/NaN through the division.

    Args:
        z1, z2: embedding matrices of the two views, indexed (node, dim).
        beta: weight of the decorrelation terms.

    Returns:
        Scalar loss tensor.
    """
    n_nodes = z1.size(0)
    n_dims = z1.size(1)
    root_n = math.sqrt(n_nodes)

    def _standardize(z):
        return ((z - z.mean(0)) / z.std(0)) / root_n

    a = _standardize(z1)
    b = _standardize(z2)

    identity = torch.eye(n_dims, device=z1.device)
    invariance = (a - b).pow(2).sum()
    decorrelation_a = (torch.mm(a.T, a) - identity).pow(2).sum()
    decorrelation_b = (torch.mm(b.T, b) - identity).pow(2).sum()

    return invariance + beta * (decorrelation_a + decorrelation_b)


def KNN_graph(x, k=12):
    """Build an undirected cosine-similarity kNN graph over the rows of ``x``.

    Rows are L2-normalised, the full N x N similarity matrix is formed, and every
    row keeps an edge to its ``k`` highest-scoring columns. A node's own column
    typically scores highest, so self-loops are usually present. The directed
    edge list is finally passed through torch_geometric's ``to_undirected``.

    Args:
        x: node feature matrix with one row per node.
        k: number of neighbours selected per node.

    Returns:
        ``edge_index`` tensor of shape (2, E).
    """
    unit = F.normalize(x, dim=-1)
    n_nodes = unit.shape[0]
    scores = unit @ unit.t()
    neighbour_idx = scores.topk(k, dim=-1).indices

    adjacency = torch.zeros(n_nodes, n_nodes, dtype=torch.int64, device=x.device)
    adjacency.scatter_(1, neighbour_idx, 1)

    directed = adjacency.nonzero().t()
    return to_undirected(directed)


def edge_drop(edge_index, p=0.4):
    """Randomly remove a fraction ``p`` of the edges (columns) of ``edge_index``.

    ``int(num_edges * p)`` edges are dropped. The surviving columns are taken in
    the order of a random permutation drawn from torch's default generator.

    Args:
        edge_index: (2, E) edge list.
        p: fraction of edges to drop, in ``[0, 1]``.

    Returns:
        A new (2, E - int(E * p)) edge list; the input is not modified.

    Raises:
        ValueError: if ``p`` is outside ``[0, 1]``.  Previously ``p > 1`` made the
            kept count negative, and the slice then silently kept the wrong edges.
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"edge drop rate p must be in [0, 1], got {p}")
    num_edges = edge_index.size(1)
    num_kept = num_edges - int(num_edges * p)
    perm = torch.randperm(num_edges)
    # Advanced indexing already returns a fresh tensor, so no defensive copy is needed.
    return edge_index[:, perm[:num_kept]]


def feat_drop(x, p=0.2):
    """Zero each feature column of ``x`` independently with probability ``p``.

    The column mask is drawn from ``uniform[0, 1) < p`` on ``x``'s device, so
    ``p=0`` keeps every column and ``p=1`` zeroes every column.

    Args:
        x: node feature matrix, indexed (node, feature).
        p: per-column drop probability.

    Returns:
        A new tensor; ``x`` itself is left unchanged.
    """
    # clone() instead of copy.deepcopy: deepcopy of a leaf that requires grad
    # yields another grad-requiring leaf, and the in-place assignment below would
    # then raise. The clone stays attached to autograd, so gradients reach ``x``.
    out = x.clone()
    drop_mask = torch.empty((x.size(1),), dtype=torch.float32, device=x.device).uniform_(0, 1) < p
    out[:, drop_mask] = 0
    return out


def unsupervised_learning(data, args):
    """Pre-train the module-level ``model`` (and ``graph_learner``) without labels.

    Relies on the globals ``model``, ``optimizer`` and ``graph_learner`` created in
    the ``__main__`` section.

    Each epoch:
      * builds two feature views: the raw features and a column-dropped copy;
      * encodes both views with ``model``;
      * propagates the first view's encoding over the topology selected by
        ``args.topology_augmentation``;
      * minimises ``alpha * hierarchial_n2n + (1 - alpha) * scale * CCA_SSG``.

    The lowest-loss weights are checkpointed to ``unsup_pkl/``. Training stops
    once ``args.patience`` consecutive epochs pass without improvement.

    Args:
        data: graph with ``x`` and ``edge_index`` attributes.
        args: parsed command-line namespace.

    Returns:
        ``model.get_embedding(data.x)`` computed with the best checkpoint.

    Raises:
        ValueError: if ``args.topology_augmentation`` is not recognised.
        RuntimeError: if no checkpoint was ever written (e.g. the loss was NaN from
            the first epoch, or ``args.unsup_epochs`` is 0).
    """
    augmentation = args.topology_augmentation
    if augmentation not in ('learned', 'knn', 'drop', 'init'):
        raise ValueError("Unrecognized augmentation")

    unsup_tag = str(int(time.time()))
    ckpt_dir = 'unsup_pkl'
    os.makedirs(ckpt_dir, exist_ok=True)
    ckpt_path = os.path.join(
        ckpt_dir, 'hgrl_' + args.net + '_best_model_' + args.dataset + unsup_tag + '.pkl')

    # The kNN graph depends only on the fixed input features, so build it once.
    knn_edge_index = KNN_graph(data.x, k=args.k) if augmentation == 'knn' else None

    best = float("inf")
    cnt_wait = 0
    saved = False
    for _ in range(args.unsup_epochs):
        model.train()
        optimizer.zero_grad()

        x1, x2 = data.x, feat_drop(data.x, p=args.feat_drop)

        # Build only the topology that is actually used: the dense learned graph
        # is O(N^2) per epoch. (As a consequence edge_drop no longer consumes the
        # default RNG when another augmentation is selected.)
        edge_weight = None
        if augmentation == 'learned':
            learned_adj = graph_learner(data.x)
            edge_index = torch.nonzero(learned_adj).t()
            edge_weight = learned_adj[edge_index[0], edge_index[1]]
        elif augmentation == 'knn':
            edge_index = knn_edge_index
        elif augmentation == 'drop':
            edge_index = edge_drop(data.edge_index, p=args.edge_drop)
        else:  # 'init'
            edge_index = data.edge_index

        h1 = model(x1)
        h2 = model(x2)
        hs1 = model.prop(h1, edge_index, edge_weight)

        loss = args.alpha * model.hierarchial_n2n(h1, hs1) \
            + (1 - args.alpha) * args.scale * CCA_SSG(h1, h2, beta=0)
        loss.backward()
        optimizer.step()

        # Track the best loss as a Python float rather than keeping a reference to
        # the graph-attached loss tensor.
        loss_value = loss.item()
        if loss_value < best:
            best = loss_value
            cnt_wait = 0
            torch.save(model.state_dict(), ckpt_path)
            saved = True
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            break

    if not saved:
        raise RuntimeError(
            "unsupervised training never improved on an infinite loss "
            "(NaN loss or zero epochs); no checkpoint was saved")

    try:
        model.load_state_dict(torch.load(ckpt_path))
    finally:
        # The checkpoint is only a temporary best-model store for this run.
        os.remove(ckpt_path)
    model.eval()
    return model.get_embedding(data.x)


class GCN_prop(MessagePassing):
    """Parameter-free K-step propagation with ``gcn_norm`` edge weights.

    ``forward`` returns a list whose i-th entry is the node representation after
    ``i + 1`` propagation steps (sum aggregation of normalised neighbour messages).
    """

    def __init__(self, K, **kwargs):
        super().__init__(aggr='add', **kwargs)
        self.K = K

    def forward(self, x, edge_index, edge_weight=None):
        edge_index, norm = gcn_norm(edge_index, edge_weight, num_nodes=x.size(0), dtype=x.dtype)
        hop_outputs = []
        current = x
        for _ in range(self.K):
            current = self.propagate(edge_index, x=current, norm=norm)
            hop_outputs.append(current)
        return hop_outputs

    def message(self, x_j, norm):
        # Scale every incoming neighbour feature row by its edge's normalisation weight.
        return x_j * norm.unsqueeze(-1)

    def __repr__(self):
        return f'{self.__class__.__name__}(K={self.K})'


class HGRL(nn.Module):
    """Feed-forward node encoder trained with a hierarchical node-to-neighbourhood loss.

    ``forward`` maps node features through ``FFN``
    (dropout -> Linear -> ReLU -> dropout -> Linear). The caller applies ``prop``,
    a ``GCN_prop`` with ``K`` steps, to obtain 1..K-hop representations.
    ``hierarchial_n2n`` sums one contrastive term per hop, weighted by
    ``softmax(self.logits)``.

    Args:
        dataset: object exposing ``num_node_features``.
        args: namespace providing ``K``, ``dropout``, ``hidden``, ``output_size``
            and ``Init``. ``'random'`` gives trainable hop logits; any other value
            gives fixed logits that put all weight on the first hop.
    """

    def __init__(self, dataset, args):
        super().__init__()
        self.K = args.K
        self.dropout = args.dropout
        self.hidden_size = args.hidden
        self.output_size = args.output_size
        self.input_size = dataset.num_node_features

        if args.Init == 'random':
            # Trainable hop logits (float64) with an L1-normalised random start.
            self.logits = Parameter(torch.tensor(self._random_logits(self.K)))
        else:
            # Fixed hop weights: softmax([1, -inf, ..., -inf]) = [1, 0, ..., 0].
            # One entry per hop so gamma lines up with the K propagated outputs for
            # any K. Registered as a buffer so `.to(device)` moves it and it is kept
            # in the state_dict; it is not a Parameter, so it is never trained.
            fixed = np.array([1.0] + [float('-inf')] * (self.K - 1))
            self.register_buffer('logits', torch.tensor(fixed))

        self.FFN = nn.Sequential(
            nn.Dropout(self.dropout),
            nn.Linear(self.input_size, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.hidden_size, self.output_size)
        )
        self.prop = GCN_prop(self.K)

    @staticmethod
    def _random_logits(K):
        """Draw ``K`` values from U(-sqrt(3/K), sqrt(3/K)) and divide by their L1 norm."""
        bound = np.sqrt(3 / K)
        logits = np.random.uniform(-bound, bound, K)
        return logits / np.sum(np.abs(logits))

    def forward(self, x):
        return self.FFN(x)

    @torch.no_grad()
    def get_embedding(self, x):
        """Encode ``x`` with the FFN in eval mode (dropout off) and without autograd."""
        self.FFN.eval()
        return self.FFN(x)

    def reset_parameters(self):
        """Re-draw trainable hop logits and reset every ``nn.Linear`` layer.

        Fixed (non-Parameter) logits are deliberately left untouched so that the
        "all weight on the first hop" initialisation survives a reset.
        """
        if isinstance(self.logits, Parameter):
            new_logits = self._random_logits(self.K)
            with torch.no_grad():
                self.logits.copy_(torch.as_tensor(
                    new_logits, dtype=self.logits.dtype, device=self.logits.device))
        # kaiming_uniform (nn.Linear's default initialisation)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                m.reset_parameters()

    def n2n_loss(self, h1, h2, gamma, temperature=1, bias=1e-8):
        """Contrastive loss between nodes (``h1``) and their hop representations (``h2``).

        With z1 = normalize(h1) and z2 = gamma * normalize(h2), row i contributes
        ``-log(exp(z1_i . z2_i / t) / (sum_j exp(z1_i . z1_j / t) + bias) + bias)``.
        The denominator includes j = i. The mean over rows is returned.
        """
        z1 = F.normalize(h1, dim=-1, p=2)
        z2 = gamma * F.normalize(h2, dim=-1, p=2)
        numerator = torch.exp(torch.sum(z1 * z2, dim=1, keepdim=True) / temperature)
        E_1 = torch.matmul(z1, torch.transpose(z1, 1, 0))
        denominator = torch.sum(torch.exp(E_1 / temperature), dim=1, keepdim=True)
        return -torch.mean(torch.log(numerator / (denominator + bias) + bias))

    def hierarchial_n2n(self, h0, hs):
        """Sum of ``n2n_loss(h0, hs[i], gamma[i])`` with ``gamma = softmax(self.logits)``.

        Args:
            h0: encoder output for every node.
            hs: list of propagated representations (1-hop, 2-hop, ...).
        """
        loss = torch.tensor(0, dtype=torch.float32, device=h0.device)
        gamma = F.softmax(self.logits, dim=0)
        for hop, h in enumerate(hs):
            loss = loss + self.n2n_loss(h0, h, gamma[hop])
        return loss


# Small constant added to degree terms in `normalize` so that nodes with zero degree
# do not cause a division by zero.
EOS = 1e-10
class Attentive(nn.Module):
    """Learnable per-feature scaling layer: ``forward(x) = x * w``.

    ``w`` has one entry per input feature and is initialised to ones, so the layer
    starts as the identity.
    """

    def __init__(self, isize):
        super(Attentive, self).__init__()
        self.w = nn.Parameter(torch.ones(isize))

    def forward(self, x):
        # Right-multiplying by diag(w) only scales column j by w[j]. Broadcasting does
        # the same without materialising an isize x isize matrix and an O(n * isize^2)
        # matmul, and it does not turn inf inputs into NaN via inf * 0.
        return x * self.w


class ATT_learner(nn.Module):
    """Graph structure learner that builds a normalised adjacency from node features.

    ``forward`` runs the following pipeline:
      1. Pass the features through ``nlayers`` ``Attentive`` scaling layers, with
         ``mlp_act`` between layers.
      2. L2-normalise each row.
      3. Build the dot-product similarity matrix.
      4. Keep the ``k + 1`` largest entries per row (presumably the extra one
         accounts for each node's self-similarity).
      5. Apply ReLU (``non_linearity`` is fixed to ``'relu'``).
      6. Symmetrise, then apply symmetric degree normalisation.

    ``knn_metric`` is stored but not referenced by any method of this class.
    NOTE(review): ``symmetrize`` yields a dense tensor, yet
    ``normalize(..., sparse=True)`` calls ``.coalesce()`` as if it were sparse, so
    ``sparse=True`` looks unusable here; confirm before relying on it.
    """

    def __init__(self, nlayers, isize, k, knn_metric, i, sparse, mlp_act):
        super(ATT_learner, self).__init__()
        self.i = i
        self.layers = nn.ModuleList([Attentive(isize) for _ in range(nlayers)])
        self.k = k
        self.knn_metric = knn_metric
        self.non_linearity = 'relu'
        self.sparse = sparse
        self.mlp_act = mlp_act

    def internal_forward(self, h):
        """Apply the Attentive layers, with ``mlp_act`` between (not after) them."""
        last_index = len(self.layers) - 1
        for idx, layer in enumerate(self.layers):
            h = layer(h)
            if idx == last_index:
                continue
            if self.mlp_act == "relu":
                h = F.relu(h)
            elif self.mlp_act == "tanh":
                h = torch.tanh(h)
        return h

    def forward(self, features):
        unit_embeddings = F.normalize(self.internal_forward(features), dim=1, p=2)
        sparse_scores = top_k(cal_similarity_graph(unit_embeddings), self.k + 1)
        activated = apply_non_linearity(sparse_scores, self.non_linearity, self.i)
        return normalize(symmetrize(activated), 'sym', sparse=self.sparse)

def apply_non_linearity(tensor, non_linearity, i):
    """Apply the named element-wise activation to ``tensor``.

    Args:
        tensor: input tensor.
        non_linearity: one of ``'elu'`` (computes ``elu(tensor * i - i) + 1``),
            ``'relu'`` or ``'none'`` (returns ``tensor`` itself).
        i: scale/shift used only by the ``'elu'`` variant.

    Raises:
        NameError: for any other ``non_linearity`` value.
    """
    if non_linearity == 'none':
        return tensor
    if non_linearity == 'relu':
        return F.relu(tensor)
    if non_linearity == 'elu':
        return F.elu(tensor * i - i) + 1
    raise NameError('We dont support the non-linearity yet')

def cal_similarity_graph(node_embeddings):
    """Return the dense matrix of pairwise dot products ``E @ E^T`` of a 2-D embedding matrix."""
    return node_embeddings.mm(node_embeddings.t())

def top_k(raw_graph, K):
    """Keep the ``K`` largest entries along the last dimension and zero the rest.

    Works for a single N x M score matrix and, since the mask is built with
    ``scatter_`` along the last dimension, for batched (..., N, M) inputs too.
    The 0/1 mask is a constant, so gradients flow only to the kept entries.

    Args:
        raw_graph: score tensor.
        K: number of entries kept per row (converted with ``int``).

    Returns:
        ``raw_graph * mask``, same shape as ``raw_graph``.
    """
    _, indices = raw_graph.topk(k=int(K), dim=-1)
    # The mask uses torch's default float dtype; the product then follows normal
    # type promotion.
    mask = torch.zeros(raw_graph.shape, device=raw_graph.device).scatter_(-1, indices, 1.0)
    return raw_graph * mask

def symmetrize(adj):  # only for non-sparse
    """Return the symmetric part ``(adj + adj^T) / 2`` of a dense matrix."""
    transposed = adj.T
    return adj.add(transposed).div(2)

def normalize(adj, mode, sparse=False):
    """Degree-normalise an adjacency matrix.

    Degrees are the row sums of ``adj``. ``EOS`` is added before inverting, so
    zero-degree rows do not divide by zero.

    Args:
        adj: dense 2-D tensor when ``sparse`` is False, otherwise a torch sparse
            COO tensor.
        mode: ``"sym"`` for ``D^{-1/2} A D^{-1/2}`` or ``"row"`` for ``D^{-1} A``.
        sparse: whether ``adj`` is a sparse COO tensor.

    Returns:
        Normalised tensor in the same layout as the input (dense, or sparse COO).

    Raises:
        ValueError: if ``mode`` is neither ``"sym"`` nor ``"row"``.  Raising lets
            callers handle the error, unlike the process-terminating ``exit()``.
    """
    if mode not in ("sym", "row"):
        raise ValueError("wrong norm mode")

    if not sparse:
        degree = adj.sum(dim=1)
        if mode == "sym":
            inv_sqrt_degree = 1. / (torch.sqrt(degree) + EOS)
            return inv_sqrt_degree[:, None] * adj * inv_sqrt_degree[None, :]
        inv_degree = 1. / (degree + EOS)
        return inv_degree[:, None] * adj

    adj = adj.coalesce()
    rows, cols = adj.indices()
    # to_dense() yields one degree per row. `.values()` of the sparse sum omits
    # rows without entries, which would misalign the row-index lookups below.
    degree = torch.sparse.sum(adj, dim=1).to_dense()
    if mode == "sym":
        inv_sqrt_degree = 1. / (torch.sqrt(degree) + EOS)
        scale = inv_sqrt_degree[rows] * inv_sqrt_degree[cols]
    else:
        scale = (1. / (degree + EOS))[rows]
    return torch.sparse_coo_tensor(adj.indices(), adj.values() * scale, adj.size())
    

if __name__ == '__main__':
    # Script entry point: parse CLI args, load the dataset, then build the encoder,
    # graph learner and optimizer as MODULE GLOBALS. unsupervised_learning reads
    # `model`, `optimizer` and `graph_learner` from module scope, so these names and
    # this construction order must stay as they are. Finally pre-train and evaluate.
    args = parse_args()
    print(args)
    print("---------------------------------------------")
    
    # set_seed comes from utils; presumably it seeds python/numpy/torch RNGs — confirm there.
    set_seed(args.seed)
    #10 fixed seeds for random splits from BernNet
    # NOTE(review): SEEDS (and percls_trn / val_lb / args.runs) are not used anywhere in
    # this block; they look like leftovers from a supervised split setup — confirm.
    SEEDS=[1941488137,4198936517,983997847,4023022221,4019585660,2108550661,1648766618,629014539,3212139042,2424918363]
    device = torch.device('cuda:'+str(args.device) if torch.cuda.is_available() else 'cpu')

    root = './data/'
    dataset = HeterophilousGraphDataset(root=root, name=args.dataset)
    data = dataset[0]
    data = data.to(device)

    percls_trn = int(round(args.train_rate * len(data.y) / dataset.num_classes))
    val_lb = int(round(args.val_rate * len(data.y)))
    
    model = HGRL(dataset, args).to(device)
    graph_learner = ATT_learner(2, isize=dataset.num_node_features, k=args.k, knn_metric=args.knn_metric,
                                i=6, sparse=args.sparse, mlp_act=args.activation_learner).to(device)

    # Three parameter groups:
    #   1. encoder weights (excluding the hop logits), trained at lr1 / wd1;
    #   2. hop logits, trained at lr_gamma with no weight decay;
    #   3. graph-learner weights, trained at lr1 / wd1.
    # When --Init is not 'random', model.logits is not an nn.Parameter. It then gets
    # no gradient, so group 2 is effectively inert.
    optimizer = torch.optim.Adam([
        {'params': filter(lambda x: x is not model.logits, model.parameters()),
         'lr': args.lr1,
         'weight_decay': args.wd1
        },
        {'params': [model.logits],
         'lr': args.lr_gamma,
         'weight_decay': 0.0
        },
        {'params': graph_learner.parameters(),
         'lr': args.lr1,
         'weight_decay': args.wd1
        }
    ], lr=args.lr1, weight_decay=args.wd1)

    embeds = unsupervised_learning(data=data, args=args)
    # Assumes unsupervised_eval_linear returns an iterable of scalar tensors (one
    # accuracy per evaluation split/run?) — TODO confirm against eval.py.
    results = unsupervised_eval_linear(data=data, embeds=embeds, args=args, device=device)
    results = [v.item() for v in results]
    test_acc_mean = np.mean(results, axis=0) * 100
    # NOTE(review): dtype=object looks unnecessary for a list of floats; kept as-is.
    values = np.asarray(results, dtype=object)
    # Uncertainty = largest distance between the mean and either bound of a 95%
    # bootstrap CI of the mean (1000 resamples), via seaborn's internal helpers.
    uncertainty = np.max(
        np.abs(sns.utils.ci(sns.algorithms.bootstrap(values, func=np.mean, n_boot=1000), 95) - values.mean()))
    print(f'test acc mean = {test_acc_mean:.4f} ± {uncertainty * 100:.4f}')