import argparse
import random
import copy
import time
import os
import numpy as np
from sklearn import preprocessing as sk_prep

import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args
from dgl.nn import EdgeWeightNorm
from dgl.nn.pytorch import GraphConv
import dgl.function as fn
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator

import torch as th
from torch import nn
from torch.nn import init

from dgl import function as fn
from dgl.base import DGLError
from dgl.utils import expand_as_pair


class GraphConvAGGR(nn.Module):
    """Weight-free graph convolution: degree-normalised neighbourhood sum.

    The layer defines no linear projection of its own; it only sums
    (optionally edge-weighted) messages from source nodes, applies the
    requested degree normalisation and an optional activation.

    Args:
        in_feats: Input feature size. Stored on the module but not otherwise
            used by the code in this class.
        norm: One of ``"none"``, ``"both"``, ``"right"`` or ``"left"``.
            ``"left"`` divides source features by out-degree, ``"right"``
            divides the aggregated result by in-degree, and ``"both"``
            multiplies by ``degree ** -0.5`` on each side.
        activation: Optional callable applied to the aggregated output.
        allow_zero_in_degree: When ``False``, ``forward`` raises if the graph
            contains a node with in-degree 0.

    Raises:
        DGLError: If ``norm`` is not one of the accepted values.
    """

    def __init__(
        self,
        in_feats,
        norm="both",
        activation=None,
        allow_zero_in_degree=False,
    ):
        super().__init__()
        if norm not in ("none", "both", "right", "left"):
            raise DGLError(
                'Invalid norm value. Must be either "none", "both", "right" or "left".'
                ' But got "{}".'.format(norm)
            )
        self._in_feats = in_feats
        self._norm = norm
        self._allow_zero_in_degree = allow_zero_in_degree
        self._activation = activation

    def set_allow_zero_in_degree(self, set_value):
        """Enable or disable the zero-in-degree check performed in ``forward``."""
        self._allow_zero_in_degree = set_value

    def _degree_scale(self, degrees, like):
        """Return the per-node scaling factor, shaped to broadcast over ``like``.

        Degrees are cast to the dtype/device of ``like`` and clamped to at
        least 1 so isolated nodes never cause a division by zero.
        """
        degrees = degrees.to(like).clamp(min=1)
        if self._norm == "both":
            scale = th.pow(degrees, -0.5)
        else:
            scale = 1.0 / degrees
        # Append singleton axes so the factor broadcasts over feature dims.
        target_shape = scale.shape + (1,) * (like.dim() - 1)
        return th.reshape(scale, target_shape)

    def forward(self, graph, feat, weight=None, edge_weight=None):
        """Aggregate neighbour features over ``graph``.

        Args:
            graph: The DGL graph to propagate on.
            feat: Node features, or a (source, destination) pair as accepted
                by ``expand_as_pair``.
            weight: Accepted for signature compatibility; not used here.
            edge_weight: Optional per-edge multiplier; its first dimension
                must equal the number of edges.

        Returns:
            The aggregated (and optionally normalised/activated) features
            read from the destination nodes.

        Raises:
            DGLError: If zero-in-degree nodes exist and the check is enabled.
        """
        with graph.local_scope():
            if not self._allow_zero_in_degree and (graph.in_degrees() == 0).any():
                raise DGLError(
                    "There are 0-in-degree nodes in the graph, "
                    "output for those nodes will be invalid. "
                    "This is harmful for some applications, "
                    "causing silent performance regression. "
                    "Adding self-loop on the input graph by "
                    "calling `g = dgl.add_self_loop(g)` will resolve "
                    "the issue. Setting ``allow_zero_in_degree`` "
                    "to be `True` when constructing this module will "
                    "suppress the check and let the code run."
                )

            if edge_weight is None:
                message_fn = fn.copy_u("h", "m")
            else:
                assert edge_weight.shape[0] == graph.num_edges()
                graph.edata["_edge_weight"] = edge_weight
                message_fn = fn.u_mul_e("h", "_edge_weight", "m")

            # (BarclayII) For RGCN on heterogeneous graphs we need to support GCN on bipartite.
            feat_src, feat_dst = expand_as_pair(feat, graph)
            if self._norm in ("left", "both"):
                feat_src = feat_src * self._degree_scale(graph.out_degrees(), feat_src)

            graph.srcdata["h"] = feat_src
            graph.update_all(message_fn, fn.sum(msg="m", out="h"))
            aggregated = graph.dstdata["h"]

            if self._norm in ("right", "both"):
                aggregated = aggregated * self._degree_scale(graph.in_degrees(), feat_dst)

            if self._activation is None:
                return aggregated
            return self._activation(aggregated)
        
        
class AGGR(nn.Module):
    """Stack of ``GraphConvAGGR`` layers plus a graph-power embedding helper.

    Args:
        in_feats: Feature dimension passed to every layer.
        n_layers: Number of aggregation layers; ``0`` makes ``forward`` the
            identity on ``features``.
        activation: Activation handed to every layer. The same object is
            passed to each layer, so a module activation is shared.
        dropout: Dropout probability applied to the input of every layer
            except the first.
    """

    def __init__(self, in_feats, n_layers, activation, dropout):
        super().__init__()
        self.n_layers = n_layers

        self.layers = nn.ModuleList()
        # NOTE: ``bns`` is never applied in ``forward``; it is kept so the
        # module's ``state_dict`` keys stay the same.
        self.bns = torch.nn.ModuleList()
        for _ in range(n_layers):
            self.layers.append(GraphConvAGGR(in_feats, activation=activation))
            self.bns.append(torch.nn.BatchNorm1d(in_feats, momentum=0.01))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, g, features):
        """Run ``features`` through the aggregation layers on graph ``g``."""
        if self.n_layers == 0:
            return features
        h = features
        for i, layer in enumerate(self.layers):
            if i != 0:
                h = self.dropout(h)
            h = layer(g, h)
        return h

    def embed(self, g, features, n_power_iters=10):
        """Compute the layer output and its multi-hop propagated version.

        Args:
            g: The DGL graph to propagate on.
            features: Node feature tensor fed to ``forward``.
            n_power_iters: Number of symmetric-normalised propagation steps
                applied to the layer output (defaults to 10).

        Returns:
            A pair ``(h_1, h_2)`` of detached tensors: ``h_1`` is the output
            of ``forward`` and ``h_2`` is that output after
            ``n_power_iters`` propagation steps, with an extra leading
            dimension added by ``unsqueeze(0)``.
        """
        # Both results are returned detached, so skip autograd bookkeeping
        # for the forward pass and every propagation hop.
        with torch.no_grad():
            h_1 = self(g, features)
            feat = h_1.clone().squeeze(0)

            # Per-node in-degree ** -0.5, clamped so isolated nodes are safe.
            degs = g.in_degrees().float().clamp(min=1)
            norm = torch.pow(degs, -0.5).to(h_1.device).unsqueeze(1)

            for _ in range(n_power_iters):
                # One hop: scale, sum over incoming edges, scale again.
                feat = feat * norm
                g.ndata['h2'] = feat
                g.update_all(fn.copy_u('h2', 'm'), fn.sum('m', 'h2'))
                feat = g.ndata.pop('h2')
                feat = feat * norm

            h_2 = feat.unsqueeze(0)
        return h_1.detach(), h_2.detach()
    
    
class Classifier(nn.Module):
    """Linear probe: one fully connected layer followed by log-softmax.

    Args:
        n_hidden: Size of the input embedding.
        n_classes: Number of output classes.
    """

    def __init__(self, n_hidden, n_classes):
        super().__init__()
        self.fc = nn.Linear(in_features=n_hidden, out_features=n_classes)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear layer's weight and bias."""
        self.fc.reset_parameters()

    def forward(self, features):
        """Return per-class log-probabilities over the last dimension."""
        logits = self.fc(features)
        return F.log_softmax(logits, dim=-1)
    

def aug_feature_dropout(input_feat, drop_percent=0.2):
    """Return a copy of ``input_feat`` with a random subset of columns zeroed.

    Args:
        input_feat: 2-D feature matrix; columns (axis 1) are candidates for
            dropping. The input itself is left untouched.
        drop_percent: Fraction of columns to zero; the count is
            ``int(num_columns * drop_percent)``.

    Returns:
        A deep copy of ``input_feat`` in which the sampled columns are 0.
    """
    num_columns = input_feat.shape[1]
    num_to_drop = int(num_columns * drop_percent)
    # Columns are drawn without replacement from Python's global RNG.
    dropped_columns = random.sample(list(range(num_columns)), num_to_drop)

    augmented = copy.deepcopy(input_feat)
    augmented[:, dropped_columns] = 0
    return augmented


def evaluate(model, features, labels, mask):
    """Return the classification accuracy of ``model`` on the selected rows.

    The model is switched to eval mode (and left there) and run without
    gradient tracking.

    Args:
        model: Callable module mapping ``features`` to per-class scores.
        features: Input passed to ``model``.
        labels: Ground-truth class indices, one per row.
        mask: Boolean mask or index tensor selecting the rows to score.

    Returns:
        Fraction of selected rows whose highest-scoring class equals the label.
    """
    model.eval()
    with torch.no_grad():
        selected_logits = model(features)[mask]
        selected_labels = labels[mask]
        predicted = torch.max(selected_logits, dim=1).indices
        num_correct = (predicted == selected_labels).sum()
        return float(num_correct.item()) / len(selected_labels)


def load_data_ogb(dataset, args):
    """Load an OGB node-property-prediction dataset and its official split.

    Also sets the module-level globals ``n_node_feats`` and ``n_classes``.

    Args:
        dataset: Dataset name handed to ``DglNodePropPredDataset`` and
            ``Evaluator``.
        args: Parsed CLI namespace; ``args.data_root_dir`` selects the
            storage root unless it equals ``'default'``.

    Returns:
        ``(graph, labels, train_idx, val_idx, test_idx, evaluator)``.
    """
    global n_node_feats, n_classes

    # Only pass ``root`` when the user overrode the library's default location.
    dataset_kwargs = {"name": dataset}
    if args.data_root_dir != 'default':
        dataset_kwargs["root"] = args.data_root_dir
    data = DglNodePropPredDataset(**dataset_kwargs)

    evaluator = Evaluator(name=dataset)

    split = data.get_idx_split()
    graph, labels = data[0]

    n_node_feats = graph.ndata["feat"].shape[1]
    n_classes = (labels.max() + 1).item()
    return graph, labels, split["train"], split["valid"], split["test"], evaluator


def preprocess(graph):
    """Make ``graph`` bidirected with exactly one self-loop per node.

    The ``"feat"`` node data is saved before the conversion and re-attached
    to the bidirected graph. Edge counts before and after the self-loop step
    are printed.

    Args:
        graph: DGL graph carrying a ``"feat"`` entry in ``ndata``.

    Returns:
        The processed graph, after ``create_formats_()`` has been called on it.
    """
    # make bidirected, carrying the node features across
    node_feat = graph.ndata["feat"]
    bidirected = dgl.to_bidirected(graph)
    bidirected.ndata["feat"] = node_feat

    # drop any existing self-loops first so each node ends up with exactly one
    print(f"Total edges before adding self-loop {bidirected.number_of_edges()}")
    with_loops = bidirected.remove_self_loop().add_self_loop()
    print(f"Total edges after adding self-loop {with_loops.number_of_edges()}")

    with_loops.create_formats_()
    return with_loops


def main(args):
    """Run one trial: build graph-power embeddings and fit a linear classifier.

    The dataset named by ``args.dataset_name`` is loaded and preprocessed,
    embeddings are produced by an ``AGGR`` model (layer output plus its
    multi-hop propagated version, L2-normalised per row), and a
    ``Classifier`` is trained on the fixed embeddings.

    Requires a CUDA device; ``args.gpu`` selects which one.

    Args:
        args: Parsed CLI namespace. Reads ``gpu``, ``dataset_name``,
            ``data_root_dir`` (via ``load_data_ogb``), ``n_layers``,
            ``dropout``, ``classifier_lr``, ``weight_decay`` and
            ``n_classifier_epochs``.

    Returns:
        The highest test accuracy observed at an epoch where validation
        accuracy improved (0 if no such epoch occurred).
    """
    device = torch.device("cuda", int(args.gpu))
    torch.cuda.set_device(device)

    # load and preprocess dataset
    g, labels, train_idx, val_idx, test_idx, _evaluator = load_data_ogb(args.dataset_name, args)
    g = preprocess(g)
    # Flatten the (N, 1) label column to a 1-D vector of length N.
    labels = labels.T.squeeze(0)

    # Move everything used inside the training loop to the device once, and
    # keep using these device-resident tensors for indexing below.
    g, labels, train_idx, val_idx, test_idx = (
        x.to(device) for x in (g, labels, train_idx, val_idx, test_idx)
    )
    # Read the features from the moved graph rather than copying them again.
    features = g.ndata['feat']
    in_feats = features.shape[1]
    n_classes = labels.max().item() + 1

    aggr = AGGR(in_feats, n_layers=args.n_layers, activation=nn.PReLU(in_feats), dropout=args.dropout)
    aggr.to(device)

    # graph power embedding reinforcement
    l_embeds, g_embeds = aggr.embed(g, features)
    embeds = (l_embeds + g_embeds).squeeze(0)
    embeds = sk_prep.normalize(X=embeds.cpu().numpy(), norm="l2")
    embeds = torch.FloatTensor(embeds).to(device)

    # create classifier model
    classifier = Classifier(embeds.shape[1], n_classes)
    classifier.to(device)
    classifier_optimizer = torch.optim.AdamW(
        classifier.parameters(), lr=args.classifier_lr, weight_decay=args.weight_decay
    )

    # Accuracy is only tracked once this many epochs have passed; runs with
    # fewer epochs therefore report 0 for both metrics.
    eval_start_epoch = 1000

    best_acc, best_val_acc = 0, 0
    print('Testing Phase ==== Please Wait.')
    for epoch in range(args.n_classifier_epochs):
        classifier.train()

        classifier_optimizer.zero_grad()
        preds = classifier(embeds)
        loss = F.nll_loss(preds[train_idx], labels[train_idx])
        loss.backward()
        classifier_optimizer.step()

        # Skip evaluation entirely during warm-up: its result was unused.
        if epoch > eval_start_epoch:
            val_acc = evaluate(classifier, embeds, labels, val_idx)
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                test_acc = evaluate(classifier, embeds, labels, test_idx)
                # NOTE(review): this keeps the maximum test accuracy over all
                # validation-improving epochs, not the test accuracy at the
                # best validation epoch. Kept as-is to preserve reported
                # numbers; confirm this selection rule is intended.
                if test_acc > best_acc:
                    best_acc = test_acc

    print("Valid Accuracy {:.4f}".format(best_val_acc))
    print("Test Accuracy {:.4f}".format(best_acc))
    return best_acc


def get_free_gpu():
    """Return the index of the GPU that reports the most free memory.

    Runs ``nvidia-smi`` through the shell and parses the third
    whitespace-separated token of every ``Free`` line as an integer. The
    output is captured through a pipe instead of a fixed ``tmp`` file in the
    working directory, so concurrent runs cannot clobber each other and no
    file handle is left open.

    Returns:
        The position (as returned by ``np.argmax``) of the largest free-memory
        value among the parsed lines.

    Raises:
        ValueError: If no free-memory line could be parsed, e.g. when
            ``nvidia-smi`` is missing or its output format differs.
    """
    import subprocess

    query = subprocess.run(
        'nvidia-smi -q -d Memory |grep -A4 GPU|grep Free',
        shell=True,
        stdout=subprocess.PIPE,
        universal_newlines=True,
    )
    memory_available = [int(line.split()[2]) for line in query.stdout.splitlines()]
    if not memory_available:
        raise ValueError("could not read free GPU memory from nvidia-smi output")
    return np.argmax(memory_available)


if __name__ == '__main__':
    import warnings

    # Suppress all library warnings for cleaner console output.
    warnings.filterwarnings("ignore")

    parser = argparse.ArgumentParser(description='GGD')
    register_data_args(parser)
    parser.add_argument("--n_trails", type=int, default=5, help="number of trials")
    parser.add_argument("--gpu", type=int, default=0, help="gpu")
    parser.add_argument('--data_root_dir', type=str, default='default')
    # The only loader in this script is OGB's DglNodePropPredDataset, so the
    # default and the help text must name an OGB node-property dataset.
    parser.add_argument('--dataset_name', type=str, default='ogbn-arxiv',
                        help='OGB node property prediction dataset name, e.g. ogbn-arxiv, ogbn-products')

    parser.add_argument("--dropout", type=float, default=0., help="dropout probability")
    parser.add_argument("--drop_feat", type=float, default=0.1, help="feature dropout rate")
    parser.add_argument("--n-layers", type=int, default=1, help="number of hidden gcn layers")

    parser.add_argument("--classifier-lr", type=float, default=0.05, help="classifier learning rate")
    parser.add_argument("--weight-decay", type=float, default=0., help="Weight for L2 loss")
    parser.add_argument("--n-classifier-epochs", type=int, default=6000, help="number of training epochs")
    parser.add_argument("--self-loop", action='store_true', help="graph self-loop (default=False)")
    parser.set_defaults(self_loop=False)
    args = parser.parse_args()
    print(args)

    # Repeat the full pipeline and report the mean of the per-trial results.
    accs = [main(args) for _ in range(args.n_trails)]
    mean_acc = str(np.array(accs).mean())
    print('mean accuracy:' + mean_acc)