import argparse
import random
import copy
import time
import math
import os
import numpy as np
from sklearn import preprocessing as sk_prep

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import dgl
from dgl.data import register_data_args
from dgl.nn import EdgeWeightNorm
from dgl.nn.pytorch import GraphConv
from dgl import function as fn
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator

from gcn_aggr import GraphConvAGGR


class GCNAGGR(nn.Module):
    """Encoder made of ``GraphConvAGGR`` layers that all use the width ``in_feats``.

    Args:
        in_feats: Size handed to every ``GraphConvAGGR`` layer and to every
            ``BatchNorm1d`` module.
        n_layers: Requested depth. One layer is always created, so values
            below 1 still produce a single layer.
        activation: Passed as ``activation=`` to every layer; the same object
            is shared by all of them.
        dropout: Drop probability applied to the input of every layer except
            the first one.
    """

    def __init__(self, in_feats, n_layers, activation, dropout):
        super().__init__()
        self.layers = nn.ModuleList()
        # One BatchNorm1d per layer is registered, but forward() never calls them.
        self.bns = nn.ModuleList()
        for _ in range(max(n_layers, 1)):
            self.layers.append(GraphConvAGGR(in_feats, activation=activation))
            self.bns.append(nn.BatchNorm1d(in_feats, momentum=0.01))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, g, features):
        """Run ``features`` through every layer on graph ``g`` and return the result."""
        first_layer, *later_layers = self.layers
        hidden = first_layer(g, features)
        for conv in later_layers:
            # Dropout sits between layers, never in front of the first one.
            hidden = conv(g, self.dropout(hidden))
        return hidden
    
    
    
class GCN(nn.Module):
    """Multi-layer encoder built from DGL ``GraphConv`` layers.

    Args:
        in_feats: Input size of the first layer.
        n_hidden: Output size of every layer (and input size of all layers
            after the first).
        n_layers: Requested depth. One layer is always created, so values
            below 1 still produce a single layer.
        activation: Passed as ``activation=`` to every ``GraphConv``; the same
            object is shared by all layers.
        dropout: Drop probability applied to the input of every layer except
            the first one.
        bias: Forwarded to ``GraphConv(bias=...)``.
        weight: Forwarded to ``GraphConv(weight=...)``.
    """

    def __init__(self, in_feats, n_hidden, n_layers, activation, dropout, bias = True, weight=True):
        super().__init__()
        self.layers = nn.ModuleList()
        # One BatchNorm1d per layer is registered, but forward() never calls them.
        self.bns = nn.ModuleList()
        input_widths = [in_feats] + [n_hidden] * max(n_layers - 1, 0)
        for width in input_widths:
            conv = GraphConv(width, n_hidden, weight=weight, bias=bias, activation=activation)
            self.layers.append(conv)
            self.bns.append(nn.BatchNorm1d(n_hidden, momentum=0.01))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, g, features):
        """Run ``features`` through every layer on graph ``g`` and return the result."""
        first_layer, *later_layers = self.layers
        hidden = first_layer(g, features)
        for conv in later_layers:
            # Dropout sits between layers, never in front of the first one.
            hidden = conv(g, self.dropout(hidden))
        return hidden


class ChebNetII_V2(nn.Module):
    """Parameter-free encoder: ReLU of a sum of Chebyshev-style recurrence terms.

    Args:
        K: Highest recurrence order that is summed.
    """

    def __init__(self, K):
        super().__init__()
        self.K = K

    def chebyshev_polynomials(self, g, x):
        """Return ``T_0 + T_1 + ... + T_K`` using ``T_k = 2 * P(T_{k-1}) - T_{k-2}``.

        ``P`` is :meth:`propagate`. ``T_0`` is ``x`` and ``T_1`` is ``P(x)``;
        both are always included, even when ``K`` is smaller than 1.
        """
        older, newer = x, self.propagate(g, x)
        total = older + newer
        # Orders 2..K; the range is empty when K < 2.
        for _ in range(self.K - 1):
            older, newer = newer, 2 * self.propagate(g, newer) - older
            total = total + newer
        return total

    def propagate(self, g, x):
        """Sum ``x`` over incoming edges of ``g`` (unnormalised).

        Writes ``x`` to ``g.ndata['h']`` and leaves it there; the temporary
        ``'h_new'`` field is removed again before returning.
        """
        g.ndata['h'] = x
        send = dgl.function.copy_u('h', 'm')
        gather = dgl.function.sum('m', 'h_new')
        g.update_all(send, gather)
        return g.ndata.pop('h_new')

    def forward(self, g, x):
        """Apply the polynomial sum to ``x`` on graph ``g`` followed by ReLU."""
        return F.relu(self.chebyshev_polynomials(g, x))



class GPRGNN_V2(nn.Module):
    """Weighted sum of repeatedly propagated features with learnable weights.

    Computes ``sum_k alpha[k] * P^k(x)`` for ``k = 0..K``, where ``P`` is
    :meth:`propagate`.

    Args:
        K: Number of propagation steps; ``alpha`` holds ``K + 1`` weights.
    """

    def __init__(self, K):
        super().__init__()
        self.K = K
        # One learnable coefficient per hop, including hop 0 (the raw input).
        self.alpha = nn.Parameter(torch.empty(K + 1))
        self.reset_parameters()

    def reset_parameters(self):
        """Set every coefficient in ``alpha`` to 1."""
        nn.init.ones_(self.alpha)

    def propagate(self, g, x):
        """Sum ``x`` over incoming edges of ``g`` (unnormalised).

        Writes ``x`` to ``g.ndata['h']`` and leaves it there; the temporary
        ``'h_new'`` field is removed again before returning.
        """
        g.ndata['h'] = x
        send = dgl.function.copy_u('h', 'm')
        gather = dgl.function.sum('m', 'h_new')
        g.update_all(send, gather)
        return g.ndata.pop('h_new')

    def forward(self, g, x):
        """Return the ``alpha``-weighted sum of ``x`` and its ``K`` propagations."""
        hop_signal = x
        mixed = self.alpha[0] * hop_signal
        hop = 1
        while hop <= self.K:
            hop_signal = self.propagate(g, hop_signal)
            mixed = mixed + self.alpha[hop] * hop_signal
            hop += 1
        return mixed


class BernNet_V2(nn.Module):
    """ReLU of a learnable combination of Bernstein basis polynomials of the input.

    The basis is evaluated element-wise on the feature values themselves; the
    graph argument is accepted but not used by the polynomial, and
    :meth:`propagate` is not called by :meth:`forward`.

    Args:
        K: Polynomial degree; ``beta`` holds ``K + 1`` coefficients.
    """

    def __init__(self, K):
        super().__init__()
        self.K = K
        # One learnable coefficient per basis polynomial.
        self.beta = nn.Parameter(torch.empty(K + 1))
        self.reset_parameters()

    def reset_parameters(self):
        """Set every coefficient in ``beta`` to 1."""
        nn.init.ones_(self.beta)

    def bernstein_polynomials(self, g, x):
        """Return ``sum_k beta[k] * C(K, k) * x**k * (1 - x)**(K - k)`` element-wise.

        ``g`` is ignored.
        """
        degree = self.K
        basis = [
            (math.comb(degree, k) * (x ** k) * ((1 - x) ** (degree - k)))
            for k in range(degree + 1)
        ]
        return sum(self.beta[k] * term for k, term in enumerate(basis))

    def propagate(self, g, x):
        """Sum ``x`` over incoming edges of ``g`` (unnormalised).

        Writes ``x`` to ``g.ndata['h']`` and leaves it there; the temporary
        ``'h_new'`` field is removed again before returning.
        """
        g.ndata['h'] = x
        send = dgl.function.copy_u('h', 'm')
        gather = dgl.function.sum('m', 'h_new')
        g.update_all(send, gather)
        return g.ndata.pop('h_new')

    def forward(self, g, x):
        """Evaluate the Bernstein combination on ``x`` and apply ReLU."""
        return F.relu(self.bernstein_polynomials(g, x))




class DGI(nn.Module):
    """Contrastive wrapper that scores real vs. row-shuffled node embeddings.

    Args:
        encoder: Callable ``encoder(g, features)`` producing node embeddings.
        hidden_dim: Embedding width; used for both inputs of the bilinear
            scoring layer.
    """

    def __init__(self, encoder, hidden_dim):
        super().__init__()
        self.encoder = encoder
        self.fn = nn.Bilinear(hidden_dim, hidden_dim, 1)
        self.act_fn = nn.ReLU()

    def corruption(self, features):
        """Return ``features`` with its rows in a random order."""
        return features[torch.randperm(features.shape[0])]

    def forward(self, g, features, labels, loss_func):
        """Return ``loss_func(logits, labels)`` for the contrastive objective.

        ``logits`` holds the bilinear scores of the real embeddings followed
        by the scores of the corrupted embeddings, each scored against the
        ReLU of the mean real embedding.
        """
        real = self.encoder(g, features)
        fake = self.encoder(g, self.corruption(features))

        # Graph-level summary, repeated so it pairs with every node row.
        summary = self.act_fn(real.mean(dim=0))
        summary = summary.expand_as(real).contiguous()

        scores = [self.fn(view, summary).squeeze(1) for view in (real, fake)]
        return loss_func(torch.cat(scores, dim=0), labels)

    def embed(self, g, features):
        """Return the encoder output and a 10-step smoothed copy, both detached.

        The smoothed copy is produced by ten rounds of: scale by
        ``deg^-0.5``, sum over incoming edges, scale by ``deg^-0.5`` again,
        where ``deg`` is the in-degree clamped to at least 1. It is returned
        with an extra leading dimension of size 1.
        """
        local = self.encoder(g, features)
        smoothed = local.clone().squeeze(0)

        inv_sqrt_deg = g.in_degrees().float().clamp(min=1).pow(-0.5)
        inv_sqrt_deg = inv_sqrt_deg.to(local.device).unsqueeze(1)

        send = dgl.function.copy_u('h2', 'm')
        gather = dgl.function.sum('m', 'h2')
        for _ in range(10):
            g.ndata['h2'] = smoothed * inv_sqrt_deg
            g.update_all(send, gather)
            smoothed = g.ndata.pop('h2') * inv_sqrt_deg

        return local.detach(), smoothed.unsqueeze(0).detach()


class Classifier(nn.Module):
    """Single linear layer followed by log-softmax.

    Args:
        n_hidden: Width of the input features.
        n_classes: Number of output classes.
    """

    def __init__(self, n_hidden, n_classes):
        super().__init__()
        self.fc = nn.Linear(n_hidden, n_classes)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the linear layer's weights using its own initialiser."""
        self.fc.reset_parameters()

    def forward(self, features):
        """Return per-class log-probabilities (log-softmax over the last dim)."""
        scores = self.fc(features)
        return scores.log_softmax(dim=-1)
    

def aug_feature_dropout(input_feat, drop_percent=0.2):
    """Return a copy of ``input_feat`` with a random subset of columns zeroed.

    Args:
        input_feat: 2-D feature matrix; it is deep-copied, not modified.
        drop_percent: Fraction of columns to zero. The count is
            ``int(n_columns * drop_percent)``, i.e. rounded towards zero.

    Returns:
        The augmented copy. Columns are chosen with Python's ``random``
        module, so the choice follows ``random.seed``.
    """
    augmented = copy.deepcopy(input_feat)
    n_columns = augmented.shape[1]
    n_dropped = int(n_columns * drop_percent)
    dropped_columns = random.sample(range(n_columns), n_dropped)
    augmented[:, dropped_columns] = 0
    return augmented


def evaluate(model, features, labels, mask):
    """Return the accuracy of ``model(features)`` on the rows selected by ``mask``.

    The model is switched to eval mode (and left there) and run without
    gradient tracking. The predicted class is the arg-max along dim 1.

    Returns:
        Fraction of selected rows whose prediction equals the label, as a
        Python float.
    """
    model.eval()
    with torch.no_grad():
        selected_logits = model(features)[mask]
        selected_labels = labels[mask]
        predicted = torch.max(selected_logits, dim=1).indices
        n_correct = (predicted == selected_labels).sum().item()
        return n_correct * 1.0 / len(selected_labels)


def load_data_ogb(dataset, args):
    """Load an OGB node-property dataset through its DGL loader.

    Args:
        dataset: Dataset name handed to ``DglNodePropPredDataset`` and
            ``Evaluator``.
        args: Namespace providing ``data_root_dir``; the literal value
            ``'default'`` means "let the loader pick its own root".

    Returns:
        ``(graph, labels, train_idx, val_idx, test_idx, evaluator)``.

    Side effects:
        Sets the module-level globals ``n_node_feats`` and ``n_classes``.
    """
    global n_node_feats, n_classes

    loader_kwargs = {"name": dataset}
    if args.data_root_dir != 'default':
        loader_kwargs["root"] = args.data_root_dir
    data = DglNodePropPredDataset(**loader_kwargs)

    evaluator = Evaluator(name=dataset)

    split = data.get_idx_split()
    graph, labels = data[0]

    n_node_feats = graph.ndata["feat"].shape[1]
    n_classes = (labels.max() + 1).item()
    return graph, labels, split["train"], split["valid"], split["test"], evaluator


def preprocess(graph):
    """Return a bidirected copy of ``graph`` with exactly one self-loop per node.

    The ``"feat"`` node data is carried over to the new graph, the edge count
    is printed before and after the self-loop step, and all sparse formats
    are materialised via ``create_formats_()``.
    """
    node_feat = graph.ndata["feat"]

    # to_bidirected() is not asked to copy node data, so re-attach it.
    graph = dgl.to_bidirected(graph)
    graph.ndata["feat"] = node_feat

    print(f"Total edges before adding self-loop {graph.number_of_edges()}")
    # Drop existing self-loops first so no node ends up with two.
    graph = graph.remove_self_loop()
    graph = graph.add_self_loop()
    print(f"Total edges after adding self-loop {graph.number_of_edges()}")

    graph.create_formats_()
    return graph


def main(args):
    """Run one trial: pre-train the contrastive encoder, then fit a linear probe.

    Steps:
        1. Load the OGB dataset named by ``args.dataset_name`` and preprocess
           the graph.
        2. Train a :class:`DGI` model around the encoder selected by
           ``args.gnn_encoder``, checkpointing the lowest-loss epoch under
           ``pkl/`` and stopping early after ``args.patience`` epochs without
           improvement.
        3. Restore the best checkpoint, extract embeddings in eval mode
           without gradient tracking, and L2-normalise them.
        4. Train a :class:`Classifier` on the frozen embeddings.

    Args:
        args: Parsed command-line namespace. Fields read here: ``gpu``,
            ``dataset_name``, ``gnn_encoder``, ``n_hidden``, ``n_layers``,
            ``dropout``, ``K``, ``ggd_lr``, ``classifier_lr``,
            ``weight_decay``, ``drop_feat``, ``n_ggd_epochs``,
            ``n_classifier_epochs`` and ``patience`` (``data_root_dir`` is
            read by ``load_data_ogb``).

    Returns:
        The best test accuracy observed at epochs where validation accuracy
        improved (0 if no such epoch was reached).

    Raises:
        ValueError: If ``args.gnn_encoder`` is not a supported encoder name.
    """
    device = int(args.gpu)
    torch.cuda.set_device(device)

    # Load and preprocess the dataset.
    g, labels, train_idx, val_idx, test_idx, _ = load_data_ogb(args.dataset_name, args)
    g = preprocess(g)
    labels = labels.T.squeeze(0)

    # Move everything to the GPU once and use these device copies throughout.
    g, labels, train_idx, val_idx, test_idx = (
        item.to(device) for item in (g, labels, train_idx, val_idx, test_idx)
    )
    features = g.ndata['feat']
    in_feats = features.shape[1]
    n_classes = labels.max().item() + 1

    # Pick the encoder; every variant except GCN keeps the input width.
    encoder_name = args.gnn_encoder
    if encoder_name == 'GCN':
        encoder = GCN(in_feats, args.n_hidden, args.n_layers, nn.PReLU(args.n_hidden), args.dropout)
        embed_dim = args.n_hidden
    elif encoder_name == 'GCN_V2':
        encoder = GCNAGGR(in_feats, args.n_layers, nn.PReLU(in_feats), args.dropout)
        embed_dim = in_feats
    elif encoder_name == 'ChebNetII_V2':
        encoder = ChebNetII_V2(K=args.K)
        embed_dim = in_feats
    elif encoder_name == 'GPRGNN_V2':
        encoder = GPRGNN_V2(K=args.K)
        embed_dim = in_feats
    elif encoder_name == 'BernNet_V2':
        encoder = BernNet_V2(K=args.K)
        embed_dim = in_feats
    else:
        raise ValueError(
            "Unknown gnn_encoder {!r}; expected one of 'GCN', 'GCN_V2', "
            "'ChebNetII_V2', 'GPRGNN_V2', 'BernNet_V2'".format(encoder_name)
        )
    dgi = DGI(encoder, embed_dim)
    dgi.cuda()

    ggd_optimizer = torch.optim.AdamW(dgi.parameters(), lr=args.ggd_lr, weight_decay=args.weight_decay)
    b_xent = nn.BCEWithLogitsLoss()

    # The checkpoint directory may not exist on a fresh checkout.
    os.makedirs('pkl', exist_ok=True)
    tag = str(int(np.random.random() * 10000000000))
    ckpt_path = 'pkl/best_ggd' + tag + '.pkl'

    # Targets: 1 for real node embeddings, 0 for corrupted ones.
    n_nodes = features.shape[0]
    lbl = torch.cat((torch.ones(n_nodes), torch.zeros(n_nodes)), dim=0).cuda()

    # Train the contrastive model with early stopping on the training loss.
    cnt_wait = 0
    best = math.inf
    best_t = 0
    saved_checkpoint = False
    for epoch in range(args.n_ggd_epochs):
        dgi.train()

        ggd_optimizer.zero_grad()
        aug_feat = aug_feature_dropout(features, args.drop_feat)
        loss = dgi(g, aug_feat, lbl, b_xent)
        loss.backward()
        ggd_optimizer.step()

        # Compare plain floats so the best-loss record does not keep the
        # autograd graph of an old epoch alive.
        loss_value = loss.item()
        if loss_value < best:
            best = loss_value
            best_t = epoch
            cnt_wait = 0
            torch.save(dgi.state_dict(), ckpt_path)
            saved_checkpoint = True
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print('Early stopping!')
            break

    # Restore the best weights BEFORE extracting embeddings; otherwise the
    # embeddings would come from the last epoch and the checkpoint is unused.
    # No checkpoint exists if no epoch ran or the loss was never finite.
    if saved_checkpoint:
        print('Loading {}th epoch'.format(best_t))
        dgi.load_state_dict(torch.load(ckpt_path))

    # Graph power embedding reinforcement, in eval mode (no dropout) and
    # without building an autograd graph for the frozen features.
    dgi.eval()
    with torch.no_grad():
        l_embeds, g_embeds = dgi.embed(g, features)
    embeds = (l_embeds + g_embeds).squeeze(0)
    embeds = sk_prep.normalize(X=embeds.cpu().numpy(), norm="l2")
    embeds = torch.FloatTensor(embeds).cuda()

    # Create the linear probe.
    classifier = Classifier(embeds.shape[1], n_classes)
    classifier.cuda()
    classifier_optimizer = torch.optim.AdamW(classifier.parameters(), lr=args.classifier_lr, weight_decay=args.weight_decay)

    # NOTE(review): accuracies are only tracked after this many epochs, so a
    # run with n_classifier_epochs <= 1001 reports 0 — confirm this is intended.
    eval_start_epoch = 1000
    best_acc, best_val_acc = 0, 0
    for epoch in range(args.n_classifier_epochs):
        classifier.train()

        classifier_optimizer.zero_grad()
        preds = classifier(embeds)
        loss = F.nll_loss(preds[train_idx], labels[train_idx])
        loss.backward()
        classifier_optimizer.step()

        # Validation accuracy is only consumed after the warm-up, so skip
        # the evaluation pass before that.
        if epoch > eval_start_epoch:
            val_acc = evaluate(classifier, embeds, labels, val_idx)
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                test_acc = evaluate(classifier, embeds, labels, test_idx)
                if test_acc > best_acc:
                    best_acc = test_acc

    print("Valid Accuracy {:.4f}".format(best_val_acc))
    print("Test Accuracy {:.4f}".format(best_acc))
    return best_acc


def get_free_gpu():
    """Return the index of the entry with the most free memory per ``nvidia-smi``.

    Runs ``nvidia-smi -q -d Memory`` through the same ``grep`` filters as
    before and parses the third whitespace-separated token of every
    ``Free`` line as an integer. The output is captured directly instead of
    being written to a ``tmp`` file in the working directory, so concurrent
    runs no longer overwrite each other's file and no file handle is left
    open.

    Returns:
        ``np.argmax`` over the parsed free-memory values.

    Raises:
        ValueError: If the command produced no ``Free`` lines (for example
            when ``nvidia-smi`` is unavailable).
    """
    import subprocess

    # The command is a fixed string, so using the shell for the pipeline
    # does not expose it to untrusted input.
    result = subprocess.run(
        'nvidia-smi -q -d Memory |grep -A4 GPU|grep Free',
        shell=True,
        capture_output=True,
        text=True,
    )
    memory_available = [
        int(line.split()[2]) for line in result.stdout.splitlines() if line.strip()
    ]
    if not memory_available:
        raise ValueError("nvidia-smi reported no 'Free' memory lines; cannot pick a GPU")
    return np.argmax(memory_available)


# Command-line entry point: parse the options, run main() for the requested
# number of trials and print the mean test accuracy.
if __name__ == '__main__':
    import warnings

    warnings.filterwarnings("ignore")

    parser = argparse.ArgumentParser(description='GGD')
    register_data_args(parser)
    parser.add_argument('--data_root_dir', type=str, default='default', help="dir_path for saving graph data. Note that this model use DGL loader so do not mix up with the dir_path for the Pyg one. Use 'default' to save datasets at current folder.")
    # The loader is OGB's DglNodePropPredDataset, which rejects non-OGB
    # names such as the previous default 'cora'.
    parser.add_argument('--dataset_name', type=str, default='ogbn-arxiv', help='OGB node property prediction dataset name, e.g. ogbn-arxiv, ogbn-products')
    parser.add_argument("--gpu", type=int, default=0, help="gpu")
    parser.add_argument("--self-loop", action='store_true', help="graph self-loop (default=False)")
    parser.add_argument("--n_trails", type=int, default=5, help="number of trails")

    # main() matches these names exactly; the previous default 'gcn' matched
    # none of them and made the script fail before training started.
    parser.add_argument("--gnn_encoder", type=str, default='GCN',
                        choices=['GCN', 'GCN_V2', 'ChebNetII_V2', 'GPRGNN_V2', 'BernNet_V2'],
                        help="choice of gnn encoder")
    parser.add_argument("--dropout", type=float, default=0., help="dropout probability")
    parser.add_argument("--n-hidden", type=int, default=512, help="number of hidden gcn units")
    parser.add_argument("--proj_layers", type=int, default=1, help="number of project linear layers")
    parser.add_argument("--n-layers", type=int, default=1, help="number of hidden gcn layers")
    parser.add_argument('--residual', action='store_true')
    parser.add_argument('--use_bn', action='store_true')
    parser.add_argument('--K', type=int, default=10, help='propagation steps.')
    parser.add_argument("--dprate", type=float, default=0., help="dropout probability")

    parser.add_argument("--ggd-lr", type=float, default=0.001, help="ggd learning rate")
    parser.add_argument("--drop_feat", type=float, default=0.1, help="feature dropout rate")
    parser.add_argument("--classifier-lr", type=float, default=0.05, help="classifier learning rate")
    parser.add_argument("--n-ggd-epochs", type=int, default=500, help="number of training epochs")
    parser.add_argument("--n-classifier-epochs", type=int, default=6000, help="number of training epochs")
    parser.add_argument("--weight-decay", type=float, default=0., help="Weight for L2 loss")
    parser.add_argument("--patience", type=int, default=500, help="early stop patience condition")

    parser.set_defaults(self_loop=False)
    args = parser.parse_args()
    print(args)

    # Each trial is an independent end-to-end run of main().
    accs = [main(args) for _ in range(args.n_trails)]
    mean_acc = str(np.array(accs).mean())
    print('mean accuracy:' + mean_acc)

    # file_name = str(args.dataset_name)
    # f = open('result/' + 'result_' + file_name + '.txt', 'a')
    # f.write(str(args) + '\n')
    # f.write(mean_acc + '\n')