import os
import warnings
import time
import random
import copy
import argparse
import numpy as np
import seaborn as sns
import networkx as nx

import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl.nn import GraphConv
from torch_geometric.utils import to_undirected

from models import *
from utils import set_seed
from dataset_loader import HeterophilousGraphDataset
from eval import unsupervised_eval_linear
warnings.filterwarnings("ignore")



def contrastive_loss(projected_emd, v_emd, sampled_embeddings_u, sampled_embeddings_neg_v, tau, lambda_loss):
    """Compute the neighbourhood contrastive loss averaged over anchors and samples.

    For every anchor, the positive logits are the temperature-scaled dot
    products between ``projected_emd`` and each row of
    ``sampled_embeddings_u``. The negative mass is the sum of the
    exponentiated, temperature-scaled dot products between ``v_emd`` and each
    row of ``sampled_embeddings_neg_v``. The per-anchor loss is
    ``lambda_loss * sum_s log(exp(pos_s) + neg_mass) - sum_s pos_s``.

    Args:
        projected_emd: 3-D tensor batched against ``sampled_embeddings_u``
            with ``torch.bmm``; callers in this file pass shape (N, 1, D).
        v_emd: 3-D tensor batched against ``sampled_embeddings_neg_v``;
            callers in this file pass shape (N, 1, D).
        sampled_embeddings_u: positive samples, shape (N, S, D).
        sampled_embeddings_neg_v: negative samples, shape (N, S', D).
        tau: temperature dividing every similarity.
        lambda_loss: weight applied to the log-denominator term.

    Returns:
        Scalar tensor: the summed loss divided by N and then by S.
    """
    # Squeeze only the singleton anchor axis. A bare ``squeeze()`` would also
    # drop the batch axis (N == 1) or the sample axis (S == 1) and make the
    # ``dim=1`` reductions below fail.
    pos_logits = torch.bmm(projected_emd, sampled_embeddings_u.transpose(-1, -2)).squeeze(1) / tau
    neg_logits = torch.bmm(v_emd, sampled_embeddings_neg_v.transpose(-1, -2)).squeeze(1) / tau

    # One negative mass per anchor, broadcast across that anchor's positives.
    neg_mass = torch.exp(neg_logits).sum(dim=1, keepdim=True)
    neg_score = torch.log(torch.exp(pos_logits) + neg_mass).sum(dim=1)

    # log(exp(x)) == x, so the positive term is just the summed logits; this
    # avoids an unnecessary exp/log round trip.
    pos_score = pos_logits.sum(dim=1)

    total_loss = torch.sum(lambda_loss * neg_score - pos_score)
    num_anchors, num_samples = sampled_embeddings_u.shape[0], sampled_embeddings_u.shape[1]
    return total_loss / num_anchors / num_samples


class PairNorm(nn.Module):
    """Normalisation layer with four selectable modes.

    Modes (``x`` is a 2-D tensor; the column mean is taken over ``dim=0``):
      * ``'None'``   - return the input untouched.
      * ``'PN'``     - centre the columns, then divide by the root of the mean
                       squared row norm and multiply by ``scale``.
      * ``'PN-SI'``  - centre the columns, then rescale every row by its own
                       norm times ``scale``.
      * ``'PN-SCS'`` - rescale every (uncentred) row by its own norm times
                       ``scale``, then subtract the column mean of the input.
    """

    def __init__(self, mode='PN', scale=10):
        assert mode in ['None', 'PN', 'PN-SI', 'PN-SCS']
        super(PairNorm, self).__init__()
        self.mode = mode
        self.scale = scale

    def forward(self, x):
        """Apply the configured normalisation to ``x`` and return the result."""
        mode = self.mode
        if mode == 'None':
            return x

        feature_mean = x.mean(dim=0)

        if mode == 'PN':
            centred = x - feature_mean
            # A single global scale shared by all rows.
            global_norm = (1e-6 + centred.pow(2).sum(dim=1).mean()).sqrt()
            return self.scale * centred / global_norm

        if mode == 'PN-SI':
            centred = x - feature_mean
            # One scale per row (kept 2-D so it broadcasts over the columns).
            per_row_norm = (1e-6 + centred.pow(2).sum(dim=1, keepdim=True)).sqrt()
            return self.scale * centred / per_row_norm

        if mode == 'PN-SCS':
            per_row_norm = (1e-6 + x.pow(2).sum(dim=1, keepdim=True)).sqrt()
            return self.scale * x / per_row_norm - feature_mean

        # Any other mode value falls through unchanged.
        return x


class MLP_generator(nn.Module):
    """Stack of linear layers with ReLU between them (none after the last).

    The first layer maps ``input_dim`` to ``output_dim``; each of the
    ``num_layers - 1`` following layers maps ``output_dim`` to ``output_dim``.
    """

    def __init__(self, input_dim, output_dim, num_layers):
        super().__init__()
        # Build in first-to-last order so parameter initialisation consumes
        # the RNG in the same sequence as layer indices.
        stack = [nn.Linear(input_dim, output_dim)]
        stack.extend(nn.Linear(output_dim, output_dim) for _ in range(num_layers - 1))
        self.linears = nn.ModuleList(stack)
        self.num_layers = num_layers

    def forward(self, embedding):
        """Run ``embedding`` through the stack and return the final linear output."""
        last = self.num_layers - 1
        hidden = embedding
        for linear in self.linears[:last]:
            hidden = F.relu(linear(hidden))
        return self.linears[last](hidden)
    

class EMA():
    """Exponential-moving-average blender with a fixed decay ``beta``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        """Return ``beta * old + (1 - beta) * new``, or ``new`` when ``old`` is None."""
        if old is not None:
            retained = old * self.beta
            incoming = (1 - self.beta) * new
            return retained + incoming
        return new

def update_moving_average(target_ema_updater, ma_model, current_model):
    """Blend ``current_model``'s parameters into ``ma_model`` in place.

    Parameters are paired positionally; each parameter of ``ma_model`` has its
    ``.data`` replaced by ``target_ema_updater.update_average(old, new)``,
    where ``old`` is its current data and ``new`` is the matching parameter's
    data from ``current_model``.
    """
    online_params = current_model.parameters()
    averaged_params = ma_model.parameters()
    for online, averaged in zip(online_params, averaged_params):
        averaged.data = target_ema_updater.update_average(averaged.data, online.data)

def set_requires_grad(model, val):
    """Set ``requires_grad`` to ``val`` on every parameter yielded by ``model.parameters()``."""
    for param in model.parameters():
        param.requires_grad = val


class GraphACL(nn.Module):
    """Graph encoder trained with a neighbourhood contrastive objective.

    An online encoder (``graphconv1`` followed by ``graphconv2``) produces
    node embeddings. A gradient-free deep copy of both layers
    (``target_graphconv1`` / ``target_graphconv2``), refreshed through
    :meth:`update_moving_average`, produces target embeddings. A projector MLP
    maps online embeddings into the space compared against the targets of
    sampled neighbours by ``contrastive_loss``.
    """

    def __init__(self, in_dim, hidden_dim, sample_size, tau, norm_mode="PN-SCS", norm_scale=20, lambda_loss=1, moving_average_decay=0.0, num_MLP=3):
        """Build the online encoder, the frozen target copy and the projector.

        Args:
            in_dim: input feature dimension.
            hidden_dim: width of both graph-convolution layers and the projector.
            sample_size: number of positives and of negatives drawn per node.
            tau: temperature passed to ``contrastive_loss``.
            norm_mode: mode handed to ``PairNorm``.
            norm_scale: scale handed to ``PairNorm``.
            lambda_loss: weight passed to ``contrastive_loss``.
            moving_average_decay: decay of the target encoder's EMA update.
            num_MLP: number of linear layers in the projector.
        """
        super(GraphACL, self).__init__()
        # NOTE: ``self.norm`` is constructed but not applied by any method of
        # this class; it is kept so the attribute stays available.
        self.norm = PairNorm(norm_mode, norm_scale)
        self.out_dim = hidden_dim
        self.lambda_loss = lambda_loss
        self.tau = tau
        # GNN Encoder
        self.graphconv1 = GraphConv(in_dim, hidden_dim)
        self.graphconv2 = GraphConv(hidden_dim, hidden_dim)
        # The target encoder starts as an exact copy and never receives gradients.
        self.target_graphconv1 = copy.deepcopy(self.graphconv1)
        self.target_graphconv2 = copy.deepcopy(self.graphconv2)
        set_requires_grad(self.target_graphconv1, False)
        set_requires_grad(self.target_graphconv2, False)

        self.target_ema_updater = EMA(moving_average_decay)
        self.num_MLP = num_MLP
        self.projector = MLP_generator(hidden_dim, hidden_dim, num_MLP)

        self.in_dim = in_dim
        self.sample_size = sample_size

    def update_moving_average(self):
        """Move both target layers towards their online counterparts via EMA."""
        # Both layers must exist. The previous ``a or b is not None`` parsed as
        # ``a or (b is not None)`` and therefore never tested the first layer
        # against None.
        assert self.target_graphconv1 is not None and self.target_graphconv2 is not None, 'target encoder has not been created yet'
        update_moving_average(self.target_ema_updater, self.target_graphconv1, self.graphconv1)
        update_moving_average(self.target_ema_updater, self.target_graphconv2, self.graphconv2)

    def forward_encoder(self, g, h):
        """Encode features ``h`` on graph ``g``.

        Returns:
            Tuple ``(v_emd, u_emd, projected_emd)``: the raw online
            embeddings, the L2-normalised target embeddings, and the
            L2-normalised projector output of the online embeddings.
        """
        v_emd = self.graphconv2(g, self.graphconv1(g, h))
        u_emd = F.normalize(self.target_graphconv2(g, self.target_graphconv1(g, h)), p=2, dim=-1)
        projected_emd = F.normalize(self.projector(v_emd), p=2, dim=-1)

        return v_emd, u_emd, projected_emd

    def _sample_indexes(self, neighbor_indexes):
        """Return exactly ``self.sample_size`` indices drawn from ``neighbor_indexes``.

        Pools smaller than ``sample_size`` are kept whole and padded with
        draws (with replacement) from the pool; larger pools are sampled
        without replacement. The input sequence is never modified.

        Raises:
            ValueError: if padding is required but the pool is empty.
        """
        pool_size = len(neighbor_indexes)
        if pool_size < self.sample_size:
            if pool_size == 0:
                raise ValueError('cannot sample neighbors from an empty index set')
            # Build a NEW list. In-place ``+=`` on the caller's list would
            # permanently grow the neighbour lists, and fails outright when
            # the pool is a ``range``.
            padding = np.random.choice(neighbor_indexes, self.sample_size - pool_size).tolist()
            return list(neighbor_indexes) + padding
        return random.sample(neighbor_indexes, self.sample_size)

    def get_emb(self, neighbor_indexes, gt_embeddings):
        """Gather ``sample_size`` rows of ``gt_embeddings`` chosen from ``neighbor_indexes``.

        Returns:
            Tensor whose first dimension is ``sample_size`` and whose rows are
            the selected rows of ``gt_embeddings``.
        """
        sample_indexes = self._sample_indexes(neighbor_indexes)
        index_tensor = torch.as_tensor(sample_indexes, dtype=torch.long, device=gt_embeddings.device)
        return gt_embeddings[index_tensor]

    def sample_neighbors(self, neighbor_dict, u_emd, v_emd):
        """Draw positive and negative samples for every row of ``u_emd``.

        Positives for node ``i`` come from ``neighbor_dict[i]`` and are read
        from ``u_emd``; negatives are drawn from ``range(len(neighbor_dict))``
        and are read from ``v_emd``. Pools shorter than ``sample_size`` are
        padded by resampling.

        Raises:
            KeyError: if a row index of ``u_emd`` has no entry in ``neighbor_dict``.

        Returns:
            Tuple of two tensors, each indexed as (node, sample, feature).
        """
        negative_pool = range(len(neighbor_dict))
        positive_rows = []
        negative_rows = []
        # Sample positives then negatives per node, so the random streams are
        # consumed in a fixed, reproducible order.
        for index in range(u_emd.shape[0]):
            positive_rows.append(self._sample_indexes(neighbor_dict[index]))
            negative_rows.append(self._sample_indexes(negative_pool))

        # One indexed gather per tensor instead of one ``torch.stack`` per node.
        positive_index = torch.as_tensor(positive_rows, dtype=torch.long, device=u_emd.device)
        negative_index = torch.as_tensor(negative_rows, dtype=torch.long, device=v_emd.device)
        return u_emd[positive_index], v_emd[negative_index]

    def reconstruction_neighbors(self, projected_emd, u_emd, v_emd, neighbor_dict):
        """Sample neighbours and evaluate ``contrastive_loss`` on them."""
        # Normalise once and reuse for both the negatives and the anchor side.
        v_normalized = F.normalize(v_emd, p=2, dim=-1)
        sampled_embeddings_u, sampled_embeddings_neg_v = self.sample_neighbors(neighbor_dict, u_emd, v_normalized)
        neighbor_recons_loss = contrastive_loss(
            projected_emd.unsqueeze(1),
            v_normalized.unsqueeze(1),
            sampled_embeddings_u,
            sampled_embeddings_neg_v,
            self.tau,
            self.lambda_loss,
        )
        return neighbor_recons_loss

    def neighbor_decoder(self, v_emd, u_emd, projected_emd, neighbor_dict):
        """Thin wrapper around :meth:`reconstruction_neighbors` (argument order differs)."""
        return self.reconstruction_neighbors(projected_emd, u_emd, v_emd, neighbor_dict)

    def get_embeddings(self, g, h):
        """Return the online embeddings for ``h`` on ``g`` without gradient tracking."""
        with torch.no_grad():
            v_emd, _, _ = self.forward_encoder(g, h)
        return v_emd.detach()

    def forward(self, g, h, neighbor_dict):
        """Return the scalar training loss for features ``h`` on graph ``g``."""
        v_emd, u_emd, projected_emd = self.forward_encoder(g, h)
        return self.neighbor_decoder(v_emd, u_emd, projected_emd, neighbor_dict)

    

def unsupervised_learning(graph, args):
    """Train a ``GraphACL`` model on ``graph`` and return its best embeddings.

    The model is optimised with Adam for at most ``args.unsup_epochs`` epochs.
    Whenever the training loss improves, the model state is written to a
    temporary checkpoint under ``unsup_pkl/``; training stops early after
    ``args.patience`` consecutive epochs without improvement. The best
    checkpoint is then reloaded, embeddings are computed in eval mode, and the
    checkpoint file is removed.

    Note that the checkpoint is written after the optimiser step and the
    target-encoder update of the epoch whose loss improved.

    Args:
        graph: graph object exposing ``ndata['attr']`` (node features) and
            ``edges()``; it is passed to the model unchanged.
        args: namespace providing ``hidden``, ``sample_size``, ``lambda_loss``,
            ``tau``, ``moving_average_decay``, ``num_MLP``, ``lr1``, ``wd1``,
            ``unsup_epochs``, ``patience``, ``net`` and ``dataset``.

    Returns:
        Detached node-embedding tensor from ``model.get_embeddings``.
    """
    feats = graph.ndata['attr']
    # Take the device from the data instead of a module-level ``device`` that
    # only exists when this file is run as a script.
    device = feats.device

    # Adjacency lists keyed by the first endpoint returned by ``graph.edges()``.
    # A single bulk ``tolist()`` per tensor avoids one ``.item()`` call per edge.
    in_nodes, out_nodes = graph.edges()
    neighbor_dict = {}
    for in_node, out_node in zip(in_nodes.tolist(), out_nodes.tolist()):
        neighbor_dict.setdefault(in_node, []).append(out_node)

    in_dim = feats.shape[1]
    model = GraphACL(in_dim, args.hidden, args.sample_size, lambda_loss=args.lambda_loss, tau=args.tau,
                     moving_average_decay=args.moving_average_decay, num_MLP=args.num_MLP).to(device)
    opt = torch.optim.Adam([{'params': model.parameters()}], lr=args.lr1, weight_decay=args.wd1)

    unsup_tag = str(int(time.time()))
    # ``torch.save`` does not create missing directories.
    os.makedirs('unsup_pkl', exist_ok=True)
    ckpt_path = 'unsup_pkl/' + 'graphacl_' + args.net + '_best_model_' + args.dataset + unsup_tag + '.pkl'

    best = float("inf")
    cnt_wait = 0
    try:
        for _ in range(args.unsup_epochs):
            loss = model(graph, feats, neighbor_dict)
            opt.zero_grad()
            loss.backward()

            opt.step()
            model.update_moving_average()

            # Compare plain floats so ``best`` does not keep a tensor (and the
            # autograd graph hanging off it) alive between epochs.
            loss_value = loss.item()
            if loss_value < best:
                best = loss_value
                cnt_wait = 0
                torch.save(model.state_dict(), ckpt_path)
            else:
                cnt_wait += 1

            if cnt_wait == args.patience:
                break

        # No checkpoint exists when zero epochs ran or the loss never beat
        # +inf (e.g. NaN); fall back to the in-memory weights in that case.
        if os.path.exists(ckpt_path):
            model.load_state_dict(torch.load(ckpt_path, map_location=device))
        model.eval()
        embeds = model.get_embeddings(graph, feats)
    finally:
        # Always clean up the temporary checkpoint, even if training failed.
        if os.path.exists(ckpt_path):
            os.remove(ckpt_path)
    return embeds


def parse_args():
    """Parse the command line and return the resulting ``argparse.Namespace``.

    Options are registered from a table, in the order listed, covering general
    settings, encoder size, split rates, and the unsupervised-training
    hyperparameters.
    """
    option_table = (
        # general
        ('--seed', dict(type=int, default=42, help='seed.')),
        ('--dataset', dict(type=str, default='Cora')),
        ('--device', dict(type=int, default=0, help='GPU device.')),
        ('--runs', dict(type=int, default=10, help='number of runs.')),
        ('--net', dict(type=str, default='GCN')),
        # encoder
        ('--num_layers', dict(type=int, default=2)),
        ('--hidden', dict(type=int, default=512, help='hidden units.')),
        ('--dropout', dict(type=float, default=0.0, help='dropout for neural networks.')),
        # data split
        ('--train_rate', dict(type=float, default=0.6, help='train set rate.')),
        ('--val_rate', dict(type=float, default=0.2, help='val set rate.')),
        # unsupervised learning
        ("--patience", dict(type=int, default=50, help="Patient epochs to wait before early stopping.")),
        ("--unsup_epochs", dict(type=int, default=1000, help="Unupservised training epochs.")),
        ("--lr1", dict(type=float, default=5e-4, help="Learning rate of the unsupervised model.")),
        ("--lr2", dict(type=float, default=0.01, help="Learning rate of linear evaluator.")),
        ("--wd1", dict(type=float, default=0.0003, help="Weight decay of the unsupervised model.")),
        ("--wd2", dict(type=float, default=0.0, help="Weight decay of linear evaluator.")),
        ('--tau', dict(type=float, default=0.5)),
        ('--temp', dict(type=float, default=0.5, help='Temperature hyperparameter.')),
        ('--moving_average_decay', dict(type=float, default=0.9)),
        ('--num_MLP', dict(type=int, default=1)),
        ('--lambda_loss', dict(type=float, default=1)),
        ('--sample_size', dict(type=int, default=5)),
    )

    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()


# Script entry point: load the dataset, train the unsupervised encoder, then
# run the linear evaluation and report mean test accuracy with an uncertainty.
if __name__ == '__main__':
    args = parse_args()
    print(args)
    print("---------------------------------------------")
    
    set_seed(args.seed)
    # 10 fixed seeds for random splits, taken from BernNet.
    # NOTE(review): SEEDS is not referenced anywhere else in this file.
    SEEDS=[1941488137,4198936517,983997847,4023022221,4019585660,2108550661,1648766618,629014539,3212139042,2424918363]
    # Module-level ``device``: unsupervised_learning() reads this global, so it
    # must be defined before that call.
    device = torch.device('cuda:'+str(args.device) if torch.cuda.is_available() else 'cpu')

    root = './data/'
    dataset = HeterophilousGraphDataset(root=root, name=args.dataset)
    data = dataset[0]
    # Pass the edge list through torch_geometric's to_undirected, then rebuild
    # the graph as a DGL graph via an intermediate NetworkX graph.
    data.edge_index = to_undirected(data.edge_index)
    G = nx.from_edgelist(data.edge_index.transpose(0, 1).numpy().tolist())
    graph = dgl.from_networkx(G)
    # NOTE(review): this assumes the DGL graph has one node per row of data.x,
    # in the same order; nodes that appear in no edge are absent from G - TODO confirm.
    graph.ndata['attr'] = data.x
    graph = graph.to(device)

    # NOTE(review): percls_trn and val_lb are computed but not used below.
    percls_trn = int(round(args.train_rate * len(data.y) / dataset.num_classes))
    val_lb = int(round(args.val_rate * len(data.y)))

    embeds = unsupervised_learning(graph=graph, args=args)
    
    results = unsupervised_eval_linear(data=data, embeds=embeds, args=args, device=device)
    # Each result exposes .item(); convert to plain Python numbers.
    results = [v.item() for v in results]
    test_acc_mean = np.mean(results, axis=0) * 100
    values = np.asarray(results, dtype=object)
    # Uncertainty = largest absolute distance between the bounds returned by
    # sns.utils.ci (called with 95 on 1000 bootstrap means) and the sample mean.
    uncertainty = np.max(
        np.abs(sns.utils.ci(sns.algorithms.bootstrap(values, func=np.mean, n_boot=1000), 95) - values.mean()))
    print(f'test acc mean = {test_acc_mean:.4f} ± {uncertainty * 100:.4f}')