import os
import warnings
import time
import numpy as np
import argparse
import seaborn as sns
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl.nn import GraphConv
from torch_geometric.data import Data

from models import *
from utils import set_seed, random_splits
from dataset_loader import DataLoader
from eval import unsupervised_test_linear
warnings.filterwarnings("ignore")


def pyg_to_dgl(data):
    """Convert a PyTorch Geometric ``Data`` object into a DGL graph.

    The edge list is taken from ``data.edge_index`` (row 0 = source nodes,
    row 1 = destination nodes). Node features ``data.x`` and edge features
    ``data.edge_attr`` are copied to ``ndata['feat']`` / ``edata['feat']``
    when they are present.

    Args:
        data: a ``torch_geometric.data.Data`` instance.

    Returns:
        A ``dgl.DGLGraph`` with ``data.num_nodes`` nodes.

    Raises:
        TypeError: if ``data`` is not a ``Data`` instance.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not isinstance(data, Data):
        raise TypeError(
            'pyg_to_dgl expects a torch_geometric.data.Data, got '
            + type(data).__name__
        )

    edge_index = data.edge_index
    # Hand the index tensors to DGL directly instead of going through NumPy:
    # ``Tensor.numpy()`` fails for CUDA tensors, and DGL places the graph on
    # the device of the given tensors, so the features assigned below live on
    # the same device as the graph.
    src = edge_index[0]
    dst = edge_index[1]
    g = dgl.graph((src, dst), num_nodes=data.num_nodes)

    if data.x is not None:
        g.ndata['feat'] = data.x
    if data.edge_attr is not None:
        g.edata['feat'] = data.edge_attr
    return g


class MLP(nn.Module):
    """Two-layer perceptron with an optional batch norm after the first layer.

    Args:
        nfeat: size of the input features.
        nhid: size of the hidden layer.
        nclass: size of the output layer.
        use_bn: when true, apply ``BatchNorm1d`` to the hidden activations
            before the ReLU.
    """

    def __init__(self, nfeat, nhid, nclass, use_bn=True):
        super().__init__()
        self.use_bn = use_bn

        # Linear layers, in input -> output order.
        self.layer1 = nn.Linear(nfeat, nhid, bias=True)
        self.layer2 = nn.Linear(nhid, nclass, bias=True)

        # The batch-norm module is always created; ``use_bn`` only decides
        # whether ``forward`` applies it.
        self.bn = nn.BatchNorm1d(nhid)
        self.act_fn = nn.ReLU()

    def forward(self, _, x):
        """Map ``x`` to output scores.

        The first positional argument is ignored; it lets the module be
        called with the same ``(graph, features)`` signature as ``GCN``.
        """
        hidden = self.layer1(x)
        if self.use_bn:
            hidden = self.bn(hidden)
        return self.layer2(self.act_fn(hidden))


class GCN(nn.Module):
    """Stack of DGL ``GraphConv`` layers (``norm='both'``).

    Args:
        in_dim: input feature size of the first layer.
        hid_dim: output size of every layer.
        n_layers: number of layers; one layer is always created, even when
            ``n_layers`` is less than 1.
        dropout: dropout probability applied after every layer but the last.
    """

    def __init__(self, in_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.n_layers = n_layers
        self.dropout = dropout

        # Input sizes per layer: the first layer consumes the raw features,
        # the remaining (n_layers - 1) layers consume hidden features.
        input_sizes = [in_dim] + [hid_dim] * (n_layers - 1)
        self.convs = nn.ModuleList(
            [GraphConv(size, hid_dim, norm='both') for size in input_sizes]
        )

    def forward(self, graph, x):
        """Run the layers over ``graph`` starting from node features ``x``."""
        h = x
        hidden_count = self.n_layers - 1
        for layer_idx in range(hidden_count):
            conv = self.convs[layer_idx]
            h = F.relu(conv(graph, h))
            h = F.dropout(h, p=self.dropout, training=self.training)
        # The last layer is applied without activation or dropout.
        return self.convs[-1](graph, h)


class EMA():
    """Exponential-moving-average helper with decay factor ``beta``."""

    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        """Blend ``old`` and ``new`` as ``old * beta + (1 - beta) * new``.

        When there is no previous value (``old is None``) the new value is
        returned unchanged.
        """
        if old is not None:
            return old * self.beta + (1 - self.beta) * new
        return new

def update_moving_average(target_ema_updater, ma_model, current_model):
    """Move every parameter of ``ma_model`` towards ``current_model``.

    Parameters of the two models are paired positionally (via ``zip`` over
    ``parameters()``), and each averaged parameter's ``.data`` is replaced by
    ``target_ema_updater.update_average(averaged, current)``.
    """
    param_pairs = zip(current_model.parameters(), ma_model.parameters())
    for online_param, averaged_param in param_pairs:
        averaged_param.data = target_ema_updater.update_average(
            averaged_param.data, online_param.data
        )

def set_requires_grad(model, val):
    """Set ``requires_grad`` to ``val`` on every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad = val


class Predictor(nn.Module):
    """Stack of ``nn.Linear`` layers with ReLU between them.

    Args:
        input_dim: input size of the first linear layer.
        output_dim: output size of every linear layer.
        num_layers: number of linear layers; one layer is always created,
            even when ``num_layers`` is less than 1.
    """

    def __init__(self, input_dim, output_dim, num_layers):
        super().__init__()
        # The first layer maps input_dim -> output_dim; the remaining
        # (num_layers - 1) layers map output_dim -> output_dim.
        input_sizes = [input_dim] + [output_dim] * (num_layers - 1)
        self.linears = torch.nn.ModuleList(
            [nn.Linear(size, output_dim) for size in input_sizes]
        )
        self.num_layers = num_layers

    def forward(self, embedding):
        """Apply the layers to ``embedding``; the last layer has no ReLU."""
        last_idx = self.num_layers - 1
        out = embedding
        for idx in range(last_idx):
            linear = self.linears[idx]
            out = F.relu(linear(out))
        return self.linears[last_idx](out)

def udf_u_add_log_e(edges):
    """DGL message function: ``log(neg_sim[dst] + sim[edge])`` as message ``'m'``.

    Reads ``'neg_sim'`` from the destination-node features and ``'sim'`` from
    the edge features of the given edge batch.
    """
    dst_neg_sim = edges.dst['neg_sim']
    edge_sim = edges.data['sim']
    return {'m': torch.log(dst_neg_sim + edge_sim)}


class GraphACL(nn.Module):
    """Contrastive model built from an online encoder and a frozen target encoder.

    Args:
        encoder: online encoder, called as ``encoder(graph, feat)``.
        encoder_target: target encoder with the same call signature; its
            parameters are frozen here and only changed through
            ``update_moving_average``.
        hid_dim: embedding size, used as input and output size of the projector.
        temp: temperature dividing all similarity scores.
        moving_average_decay: decay ``beta`` of the EMA update of the target.
        num_MLP: number of linear layers in the projector.
    """

    def __init__(self, encoder, encoder_target, hid_dim, temp, moving_average_decay=1.0, num_MLP=1):
        super().__init__()
        self.encoder = encoder
        self.encoder_target = encoder_target
        # The target encoder never receives gradients.
        set_requires_grad(self.encoder_target, False)

        self.temp = temp
        self.target_ema_updater = EMA(moving_average_decay)
        self.num_MLP = num_MLP
        self.projector = Predictor(hid_dim, hid_dim, num_MLP)

    def get_embedding(self, graph, feat):
        """Return the online encoder's output, detached from autograd."""
        embedding = self.encoder(graph, feat)
        return embedding.detach()

    def pos_score(self, graph, v, u):
        """Compute the per-node positive score and return it with ``graph``.

        Side effects on ``graph``: writes node features ``'q'`` (normalised
        projection of ``v``), ``'u'`` (normalised ``u``) and ``'pos'``, and the
        edge feature ``'sim'`` (temperature-scaled source-``u`` / destination-``q``
        dot product), which ``neg_score`` reads afterwards.
        """
        projected = self.projector(v)
        graph.ndata['q'] = F.normalize(projected)
        graph.ndata['u'] = F.normalize(u, dim=-1)

        # Element-wise product of source 'u' and destination 'q' per edge,
        # summed over the feature axis to give a dot product.
        graph.apply_edges(fn.u_mul_v('u', 'q', 'sim'))
        edge_products = graph.edata['sim']
        graph.edata['sim'] = edge_products.sum(1) / self.temp

        # Mean of the incoming edge similarities for every node.
        graph.update_all(fn.copy_e('sim', 'm'), fn.mean('m', 'pos'))
        return graph.ndata['pos'], graph

    def neg_score(self, h, graph, rff_dim=None):
        """Compute the per-node negative score.

        Expects ``graph.edata['sim']`` as left by ``pos_score`` and overwrites
        it with its exponential; also writes node features ``'neg_sim'`` and
        ``'neg'``. ``rff_dim`` is accepted but not used.
        """
        graph.edata['sim'] = torch.exp(graph.edata['sim'])

        # Row sums of the exponentiated all-pairs similarity matrix.
        z = F.normalize(h, dim=-1)
        pairwise = torch.exp(torch.mm(z, z.t()) / self.temp)
        graph.ndata['neg_sim'] = pairwise.sum(1)

        # Per node: mean over incoming edges of log(neg_sim[dst] + sim[edge]).
        graph.update_all(udf_u_add_log_e, fn.mean('m', 'neg'))
        return graph.ndata['neg']

    def update_moving_average(self):
        """EMA-update the target encoder's parameters from the online encoder."""
        assert self.encoder_target is not None, 'target encoder has not been created yet'
        update_moving_average(self.target_ema_updater, self.encoder_target, self.encoder)

    def forward(self, graph, feat):
        """Return the scalar loss: mean over nodes of ``-pos_score + neg_score``."""
        online = self.encoder(graph, feat)
        target = self.encoder_target(graph, feat)
        positive, graph = self.pos_score(graph, online, target)
        negative = self.neg_score(online, graph)
        return (- positive + negative).mean()


def parse_args(argv=None):
    """Parse the command-line options of the GraphACL training script.

    Args:
        argv: optional list of argument strings. When ``None`` (the default)
            the arguments are read from ``sys.argv``, as before; passing a
            list lets the parser be used programmatically.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()

    # general setup
    parser.add_argument('--seed', type=int, default=42, help='seed.')
    parser.add_argument('--dataset', type=str, default='Cora')
    parser.add_argument('--device', type=int, default=0, help='GPU device.')
    parser.add_argument('--runs', type=int, default=10, help='number of runs.')
    parser.add_argument('--net', type=str, default='GCN')

    # encoder architecture
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--hidden', type=int, default=512, help='hidden units.')
    parser.add_argument('--dropout', type=float, default=0.0, help='dropout for neural networks.')

    # data split
    parser.add_argument('--train_rate', type=float, default=0.6, help='train set rate.')
    parser.add_argument('--val_rate', type=float, default=0.2, help='val set rate.')

    # unsupervised learning
    parser.add_argument("--patience", type=int, default=50, help="Patient epochs to wait before early stopping.")
    parser.add_argument("--unsup_epochs", type=int, default=1000, help="Unsupervised training epochs.")
    parser.add_argument("--lr1", type=float, default=5e-4, help="Learning rate of the unsupervised model.")
    parser.add_argument("--lr2", type=float, default=0.01, help="Learning rate of linear evaluator.")
    parser.add_argument("--wd1", type=float, default=0.0, help="Weight decay of the unsupervised model.")
    parser.add_argument("--wd2", type=float, default=0.0, help="Weight decay of linear evaluator.")

    # contrastive objective
    parser.add_argument('--temp', type=float, default=0.5, help='Temperature hyperparameter.')
    parser.add_argument('--moving_average_decay', type=float, default=0.9)
    parser.add_argument('--num_MLP', type=int, default=1)

    return parser.parse_args(argv)


def unsupervised_learning(graph, args):
    """Train the module-level ``model`` on ``graph`` and return node embeddings.

    Uses the module-level ``model`` and ``optimizer`` objects. Training runs
    for at most ``args.unsup_epochs`` epochs and stops early once the loss
    has not improved for ``args.patience`` consecutive epochs. The weights
    from the best epoch are kept in a temporary checkpoint file under
    ``unsup_pkl/``, restored before the embeddings are computed, and the
    file is removed afterwards.

    Args:
        graph: DGL graph whose ``ndata['feat']`` holds the node features.
        args: namespace providing ``unsup_epochs``, ``patience``, ``net`` and
            ``dataset``.

    Returns:
        The tensor returned by ``model.get_embedding(graph, feat)``.
    """
    feat = graph.ndata["feat"]

    # Build the checkpoint path once; create the directory so that
    # ``torch.save`` does not fail on a fresh checkout.
    ckpt_dir = 'unsup_pkl'
    os.makedirs(ckpt_dir, exist_ok=True)
    unsup_tag = str(int(time.time()))
    ckpt_path = os.path.join(
        ckpt_dir,
        'graphacl_' + args.net + '_best_model_' + args.dataset + unsup_tag + '.pkl',
    )

    best = float("inf")
    cnt_wait = 0
    checkpoint_saved = False
    try:
        for epoch in range(args.unsup_epochs):
            model.train()
            optimizer.zero_grad()
            loss = model(graph, feat)
            loss.backward()
            optimizer.step()
            model.update_moving_average()

            # Track the best loss as a plain float: keeping the loss tensor
            # itself would keep its autograd graph alive between epochs.
            loss_value = loss.item()
            if loss_value < best:
                best = loss_value
                cnt_wait = 0
                torch.save(model.state_dict(), ckpt_path)
                checkpoint_saved = True
            else:
                cnt_wait += 1

            if cnt_wait == args.patience:
                break

        # No checkpoint exists when no epoch ran or the loss never compared
        # below infinity (e.g. NaN from the start); in that case the current
        # weights are used instead of failing on a missing file.
        if checkpoint_saved:
            model.load_state_dict(torch.load(ckpt_path))
        model.eval()
        embeds = model.get_embedding(graph, feat)
    finally:
        # Always clean up the temporary checkpoint, even if training raised.
        if os.path.exists(ckpt_path):
            os.remove(ckpt_path)
    return embeds



if __name__ == '__main__':
    # Script entry point: train GraphACL without labels once, then evaluate
    # the frozen embeddings with a linear classifier over several fixed
    # random splits and report the mean test accuracy.
    args = parse_args()
    print(args)
    print("---------------------------------------------")

    #10 fixed seeds for random splits from BernNet
    SEEDS=[1941488137,4198936517,983997847,4023022221,4019585660,2108550661,1648766618,629014539,3212139042,2424918363]

    # Fail fast: every evaluation run consumes one fixed split seed, so an
    # out-of-range --runs would otherwise only surface as an IndexError after
    # the whole unsupervised training has already finished.
    if not 1 <= args.runs <= len(SEEDS):
        raise ValueError(f'--runs must be between 1 and {len(SEEDS)}, got {args.runs}')

    set_seed(args.seed)
    device = torch.device('cuda:'+str(args.device) if torch.cuda.is_available() else 'cpu')

    dataset = DataLoader(args.dataset)
    data = dataset[0]
    graph = pyg_to_dgl(data).to(device)

    # Split sizes: training nodes per class and total validation nodes.
    percls_trn = int(round(args.train_rate * len(data.y) / dataset.num_classes))
    val_lb = int(round(args.val_rate * len(data.y)))

    # ``model`` and ``optimizer`` must stay module-level names:
    # ``unsupervised_learning`` reads them as globals.
    encoder = GCN(dataset.num_node_features, args.hidden, args.num_layers, args.dropout)
    encoder_target = copy.deepcopy(encoder)
    model = GraphACL(encoder, encoder_target, args.hidden, args.temp, args.moving_average_decay, args.num_MLP).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr1, weight_decay=args.wd1)
    # Ensure every node has exactly one self-loop.
    graph = graph.remove_self_loop().add_self_loop()

    embeds = unsupervised_learning(graph=graph, args=args)

    unsup_results = []
    for RP in range(args.runs):
        # ``args.seed`` is overwritten per run because the evaluator receives ``args``.
        args.seed = SEEDS[RP]
        data = random_splits(data, dataset.num_classes, percls_trn, val_lb, args.seed).to(device)
        eval_acc = unsupervised_test_linear(data=data, embeds=embeds, n_classes=dataset.num_classes, device=device, args=args)
        unsup_results.append(eval_acc)

    test_acc_mean = np.mean(unsup_results) * 100
    values = np.asarray(unsup_results, dtype=object)
    # Half-width of the 95% bootstrap confidence interval of the mean accuracy.
    uncertainty = np.max(np.abs(sns.utils.ci(sns.algorithms.bootstrap(values, func=np.mean, n_boot=1000), 95) - values.mean()))
    print(f'test acc mean = {test_acc_mean:.4f} ± {uncertainty * 100:.4f}')



