import argparse
import random
import copy
import time
import os
import numpy as np
from sklearn import preprocessing as sk_prep

import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args
from dgl.nn import EdgeWeightNorm
from dgl.nn.pytorch import GraphConv
import dgl.function as fn
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator

from gcn_aggr import GraphConvAGGR


class GCNAGGR(nn.Module):
    """Sequential stack of ``GraphConvAGGR`` layers.

    Every layer (and its companion ``BatchNorm1d``) is constructed with
    ``in_feats``.  At least one layer is always created, even when
    ``n_layers`` is smaller than 1.

    Args:
        in_feats: feature size handed to every ``GraphConvAGGR`` / ``BatchNorm1d``.
        n_layers: requested number of layers (values below 1 still give one layer).
        activation: activation object shared by all layers.
        dropout: dropout probability applied between consecutive layers.
    """

    def __init__(self, in_feats, n_layers, activation, dropout):
        super(GCNAGGR, self).__init__()
        self.layers = nn.ModuleList()
        # NOTE: the batch-norm modules are registered (so they show up in the
        # state dict) but ``forward`` never applies them.
        self.bns = torch.nn.ModuleList()
        for _ in range(max(n_layers, 1)):
            self.layers.append(GraphConvAGGR(in_feats, activation=activation))
            self.bns.append(torch.nn.BatchNorm1d(in_feats, momentum=0.01))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, g, features):
        """Run ``features`` through all layers on graph ``g``.

        Dropout is applied to the input of every layer except the first one.
        """
        hidden = self.layers[0](g, features)
        for conv in self.layers[1:]:
            hidden = conv(g, self.dropout(hidden))
        return hidden
    

class GCN(nn.Module):
    """Sequential stack of DGL ``GraphConv`` layers.

    The first layer maps ``in_feats`` to ``n_hidden``; every further layer maps
    ``n_hidden`` to ``n_hidden``.  At least one layer is always created.

    Args:
        in_feats: input feature size of the first layer.
        n_hidden: output size of every layer.
        n_classes: accepted for interface compatibility; not used here.
        n_layers: requested number of layers (values below 1 still give one layer).
        activation: activation passed to every ``GraphConv``.
        dropout: dropout probability applied between consecutive layers.
        bias: forwarded to ``GraphConv``.
        weight: forwarded to ``GraphConv``.
    """

    def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation, dropout, bias = True, weight=True):
        super(GCN, self).__init__()
        self.layers = nn.ModuleList()
        # NOTE: the batch-norm modules are registered (so they show up in the
        # state dict) but ``forward`` never applies them.
        self.bns = torch.nn.ModuleList()
        input_sizes = [in_feats] + [n_hidden] * max(n_layers - 1, 0)
        for size_in in input_sizes:
            self.layers.append(
                GraphConv(size_in, n_hidden, weight=weight, bias=bias, activation=activation)
            )
            self.bns.append(torch.nn.BatchNorm1d(n_hidden, momentum=0.01))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, g, features):
        """Run ``features`` through all layers on graph ``g``.

        Dropout is applied to the input of every layer except the first one.
        """
        hidden = self.layers[0](g, features)
        for conv in self.layers[1:]:
            hidden = conv(g, self.dropout(hidden))
        return hidden


class ChebNetII(nn.Module):
    """Parameter-free propagation module built on the Chebyshev recurrence.

    The output is ``relu(T_0 + T_1 + ... + T_K)`` where ``T_0 = x``,
    ``T_1 = propagate(x)`` and ``T_k = 2 * propagate(T_{k-1}) - T_{k-2}``.
    ``T_1`` is always included, so ``K`` values below 1 behave like ``K == 1``.

    Args:
        K: highest order of the recurrence.
    """

    def __init__(self, K):
        super(ChebNetII, self).__init__()
        self.K = K

    def chebyshev_polynomials(self, g, x):
        """Return the sum of the recurrence terms ``T_0 .. T_K`` for input ``x``."""
        previous = x
        current = self.propagate(g, x)
        total = previous + current
        order = 2
        while order <= self.K:
            following = 2 * self.propagate(g, current) - previous
            total = total + following
            previous, current = current, following
            order += 1
        return total

    def propagate(self, g, x):
        """Sum the features of each node's in-neighbours (one message-passing step).

        The input is stored under ``g.ndata['h']`` and stays there after the
        call; the temporary result field ``'h_new'`` is removed again.
        """
        g.ndata['h'] = x
        message = dgl.function.copy_u('h', 'm')
        reducer = dgl.function.sum('m', 'h_new')
        g.update_all(message, reducer)
        return g.ndata.pop('h_new')

    def forward(self, g, x):
        """Apply the polynomial propagation followed by ReLU."""
        return F.relu(self.chebyshev_polynomials(g, x))

    
class Encoder(nn.Module):
    """GNN encoder that can also encode a row-shuffled ("corrupted") input.

    Args:
        g: graph the encoder operates on; stored and reused on every call.
        in_feats: input feature size.
        n_hidden: hidden size (used by the ``'gcn'`` encoder).
        n_layers: number of layers for the ``'gcn'`` / ``'gcn_aggr'`` encoders.
        dropout: dropout probability for the ``'gcn'`` / ``'gcn_aggr'`` encoders.
        gnn_encoder: one of ``'gcn'``, ``'gcn_aggr'`` or ``'chebnet'``.
        k: polynomial order for the ``'chebnet'`` encoder.

    Raises:
        ValueError: if ``gnn_encoder`` is not one of the supported names.
    """

    def __init__(self, g, in_feats, n_hidden, n_layers, dropout, gnn_encoder, k = 1):
        super(Encoder, self).__init__()
        self.g = g
        self.gnn_encoder = gnn_encoder
        if gnn_encoder == 'gcn':
            # Size the PReLU from the constructor argument; reading the
            # script-level ``args`` here broke any use outside ``__main__``.
            activation = nn.PReLU(n_hidden)
            self.conv = GCN(in_feats, n_hidden, n_hidden, n_layers, activation, dropout)
        elif gnn_encoder == 'gcn_aggr':
            activation = nn.PReLU(in_feats)
            self.conv = GCNAGGR(in_feats, n_layers, activation, dropout)
        elif gnn_encoder == 'chebnet':
            self.conv = ChebNetII(k)
        else:
            # Fail at construction time instead of with a missing ``self.conv``
            # attribute on the first forward pass.
            raise ValueError(
                "Unknown gnn_encoder {!r}; expected 'gcn', 'gcn_aggr' or 'chebnet'".format(gnn_encoder)
            )

    def forward(self, features, corrupt=False):
        """Encode ``features``; with ``corrupt=True`` the rows are shuffled first.

        The convolution is applied in both cases.  Only the random row
        permutation is specific to the corrupted pass.
        """
        if corrupt:
            perm = torch.randperm(self.g.number_of_nodes())
            features = features[perm]
        return self.conv(self.g, features)


class GGD(nn.Module):
    """Group-discrimination model: an encoder followed by a linear projection head.

    Args:
        g: graph handed to the encoder.
        in_feats: input feature size.
        n_hidden: hidden / projection size.
        n_layers: number of encoder layers.
        dropout: encoder dropout probability.
        proj_layers: number of linear layers in the projection head.
        gnn_encoder: encoder name, forwarded to ``Encoder``.
        num_hop: forwarded to ``Encoder`` as its ``k`` argument.
    """

    def __init__(self, g, in_feats, n_hidden, n_layers, dropout, proj_layers, gnn_encoder, num_hop):
        super(GGD, self).__init__()
        self.encoder = Encoder(g, in_feats, n_hidden, n_layers, dropout, gnn_encoder, num_hop)
        # 'gcn_aggr' layers are built with ``in_feats`` only, and 'chebnet' has
        # no weights at all, so for both the projection head receives
        # ``in_feats``-wide input. Only 'gcn' maps the features to ``n_hidden``.
        first_in = in_feats if gnn_encoder in ('gcn_aggr', 'chebnet') else n_hidden
        self.mlp = torch.nn.ModuleList()
        for i in range(proj_layers):
            self.mlp.append(nn.Linear(first_in if i == 0 else n_hidden, n_hidden))
        self.loss = nn.BCEWithLogitsLoss()

    def forward(self, features, labels, loss_func):
        """Return the discrimination loss for one pass.

        ``features`` is encoded twice (once as-is, once with shuffled rows),
        both results go through the projection head, are summed over the
        feature axis and concatenated into a single row of logits that
        ``loss_func`` compares with ``labels``.
        """
        h_1 = self.encoder(features, corrupt=False)
        h_2 = self.encoder(features, corrupt=True)

        sc_1 = h_1.squeeze(0)
        sc_2 = h_2.squeeze(0)
        for lin in self.mlp:
            sc_1 = lin(sc_1)
            sc_2 = lin(sc_2)

        sc_1 = sc_1.sum(1).unsqueeze(0)
        sc_2 = sc_2.sum(1).unsqueeze(0)

        logits = torch.cat((sc_1, sc_2), 1)
        return loss_func(logits, labels)

    def embed(self, features, g, n_power=10):
        """Return the encoder output and a graph-smoothed copy of it.

        Args:
            features: node features to encode.
            g: graph used for the smoothing propagation.
            n_power: number of symmetric-normalised propagation rounds
                (defaults to 10, the previously hard-coded value).

        Returns:
            ``(h_1, h_2)``: the detached encoder output, and the propagated
            embedding with an extra leading dimension of size 1.
        """
        # Nothing returned from here carries gradients, so do not record an
        # autograd graph through the encoder and the propagation rounds.
        with torch.no_grad():
            h_1 = self.encoder(features, corrupt=False)
            feat = h_1.clone().squeeze(0)

            # D^-1/2 scaling; degrees are clamped so isolated nodes do not divide by zero.
            degs = g.in_degrees().float().clamp(min=1)
            norm = torch.pow(degs, -0.5)
            norm = norm.to(h_1.device).unsqueeze(1)
            for _ in range(n_power):
                feat = feat * norm
                g.ndata['h2'] = feat
                g.update_all(fn.copy_u('h2', 'm'), fn.sum('m', 'h2'))
                feat = g.ndata.pop('h2')
                feat = feat * norm

            h_2 = feat.unsqueeze(0)
        return h_1.detach(), h_2.detach()


class Classifier(nn.Module):
    """Single linear layer followed by log-softmax, trained on frozen embeddings.

    Args:
        n_hidden: size of the input embeddings.
        n_classes: number of output classes.
    """

    def __init__(self, n_hidden, n_classes):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(n_hidden, n_classes)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear layer with its default scheme."""
        self.fc.reset_parameters()

    def forward(self, features):
        """Return per-class log-probabilities for ``features``."""
        scores = self.fc(features)
        return F.log_softmax(scores, dim=-1)
    

def aug_feature_dropout(input_feat, drop_percent=0.2):
    """Return a copy of ``input_feat`` with a random subset of columns zeroed.

    ``int(n_columns * drop_percent)`` distinct columns are chosen with
    ``random.sample``; the input itself is left untouched.
    """
    augmented = copy.deepcopy(input_feat)
    n_columns = augmented.shape[1]
    n_dropped = int(n_columns * drop_percent)
    dropped_columns = random.sample(range(n_columns), n_dropped)
    augmented[:, dropped_columns] = 0
    return augmented


def evaluate(model, features, labels, mask):
    """Return the accuracy of ``model`` on the entries selected by ``mask``.

    The model is switched to eval mode and run on all ``features`` without
    gradient tracking; predictions are the arg-max over dimension 1.
    """
    model.eval()
    with torch.no_grad():
        selected_logits = model(features)[mask]
        selected_labels = labels[mask]
        predictions = torch.max(selected_logits, dim=1)[1]
        n_correct = (predictions == selected_labels).sum().item()
        return n_correct * 1.0 / len(selected_labels)


def load_data_ogb(dataset, args):
    """Load an OGB node-property dataset together with its split and evaluator.

    The dataset is read from ``args.data_root_dir`` unless that is the string
    ``'default'``, in which case the loader's own default location is used.
    As a side effect the module-level globals ``n_node_feats`` and
    ``n_classes`` are set from the loaded graph and labels.

    Returns:
        ``(graph, labels, train_idx, val_idx, test_idx, evaluator)``.
    """
    global n_node_feats, n_classes

    dataset_kwargs = {"name": dataset}
    if args.data_root_dir != 'default':
        dataset_kwargs["root"] = args.data_root_dir
    data = DglNodePropPredDataset(**dataset_kwargs)

    evaluator = Evaluator(name=dataset)

    split = data.get_idx_split()
    graph, labels = data[0]

    n_node_feats = graph.ndata["feat"].shape[1]
    n_classes = (labels.max() + 1).item()
    return graph, labels, split["train"], split["valid"], split["test"], evaluator


def preprocess(graph):
    """Make ``graph`` bidirected with exactly one self-loop per node.

    The node features under ``'feat'`` are saved before the conversion and
    attached to the converted graph afterwards.  Edge counts before and after
    the self-loop step are printed, and all sparse formats are materialised.
    """
    node_feat = graph.ndata["feat"]
    bidirected = dgl.to_bidirected(graph)
    bidirected.ndata["feat"] = node_feat

    print(f"Total edges before adding self-loop {bidirected.number_of_edges()}")
    # Drop existing self-loops first so every node ends up with a single one.
    without_loops = bidirected.remove_self_loop()
    result = without_loops.add_self_loop()
    print(f"Total edges after adding self-loop {result.number_of_edges()}")

    result.create_formats_()
    return result


def main(args):
    """Run one full GGD trial and return the selected test accuracy.

    Steps: load and preprocess the OGB graph, train the GGD model with early
    stopping on the training loss, restore the best checkpoint, build the
    final embeddings, then fit a linear classifier on them.

    Args:
        args: parsed command-line namespace (see the ``__main__`` section).

    Returns:
        The best test accuracy recorded at the epochs where validation
        accuracy improved.
    """
    device = int(args.gpu)
    torch.cuda.set_device(device)

    # load and preprocess dataset
    g, labels, train_idx, val_idx, test_idx, _evaluator = load_data_ogb(args.dataset_name, args)
    g = preprocess(g)
    features = g.ndata['feat']
    labels = labels.T.squeeze(0)

    # Move everything to the GPU once; the split indices are used on-device
    # below so no host index tensor is re-uploaded on every epoch.
    g, labels, train_idx, val_idx, test_idx, features = (
        x.to(device) for x in (g, labels, train_idx, val_idx, test_idx, features)
    )
    in_feats = features.shape[1]
    n_classes = labels.max().item() + 1
    n_edges = g.num_edges()
    n_nodes = g.num_nodes()

    # create GGD model
    ggd = GGD(g,
              in_feats,
              args.n_hidden,
              args.n_layers,
              args.dropout,
              args.proj_layers,
              args.gnn_encoder,
              args.num_hop)
    ggd.to(device)

    ggd_optimizer = torch.optim.AdamW(ggd.parameters(), lr=args.ggd_lr, weight_decay=args.weight_decay)
    b_xent = nn.BCEWithLogitsLoss()

    # The targets never change: ones for the clean half, zeros for the
    # corrupted half.  Build them once instead of once per epoch.
    lbl = torch.cat((torch.ones(1, n_nodes), torch.zeros(1, n_nodes)), 1).to(device)

    # The checkpoint directory is not guaranteed to exist.
    os.makedirs('pkl', exist_ok=True)
    tag = str(int(np.random.random() * 10000000000))
    ckpt_path = 'pkl/best_ggd' + tag + '.pkl'

    # train GGD
    cnt_wait = 0
    best = 1e9
    best_t = 0
    has_checkpoint = False
    dur = []
    for epoch in range(args.n_ggd_epochs):
        ggd.train()
        t0 = time.time()

        ggd_optimizer.zero_grad()
        aug_feat = aug_feature_dropout(features, args.drop_feat)
        loss = ggd(aug_feat, lbl, b_xent)
        loss.backward()
        ggd_optimizer.step()

        # Compare plain floats: keeping the loss tensor as ``best`` would
        # keep its whole autograd graph alive.
        loss_val = loss.item()
        if loss_val < best:
            best = loss_val
            best_t = epoch
            cnt_wait = 0
            torch.save(ggd.state_dict(), ckpt_path)
            has_checkpoint = True
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print('Early stopping!')
            break

        # The first three epochs are treated as warm-up and not timed.
        if epoch >= 3:
            dur.append(time.time() - t0)

        # No timings yet during warm-up: report nan without calling np.mean on an empty list.
        mean_dur = np.mean(dur) if dur else float('nan')
        print(f"Epoch {epoch:05d} | Time(s) {mean_dur:.4f} | Loss {loss_val:.4f} | ETputs(KTEPS) {n_edges / mean_dur / 1000:.2f}")

    print('Training Completed.')

    # train classifier
    print('Loading {}th epoch'.format(best_t))
    # Nothing is saved when zero epochs ran or the loss was never below the
    # initial bound; fall back to the current weights in that case.
    if has_checkpoint:
        ggd.load_state_dict(torch.load(ckpt_path))

    # graph power embedding reinforcement
    # Switch to eval mode so dropout is disabled while the final embeddings are computed.
    ggd.eval()
    l_embeds, g_embeds = ggd.embed(features, g)
    embeds = (l_embeds + g_embeds).squeeze(0)
    embeds = sk_prep.normalize(X=embeds.cpu().numpy(), norm="l2")
    embeds = torch.FloatTensor(embeds).to(device)

    # create classifier model
    classifier = Classifier(embeds.shape[1], n_classes)
    classifier.to(device)
    classifier_optimizer = torch.optim.AdamW(classifier.parameters(), lr=args.classifier_lr, weight_decay=args.weight_decay)

    # NOTE(review): model selection only starts after this many epochs, so a
    # run with n_classifier_epochs <= 1001 reports 0 for both accuracies.
    # The reported number is also the maximum test accuracy over the epochs
    # where validation improved, not the test accuracy at the best validation
    # epoch.  Both rules are kept as they were.
    select_after_epoch = 1000

    best_acc, best_val_acc = 0, 0
    print('Testing Phase ==== Please Wait.')
    for epoch in range(args.n_classifier_epochs):
        classifier.train()

        classifier_optimizer.zero_grad()
        preds = classifier(embeds)
        loss = F.nll_loss(preds[train_idx], labels[train_idx])
        loss.backward()
        classifier_optimizer.step()

        # Evaluation results are ignored before the selection window opens.
        if epoch <= select_after_epoch:
            continue

        val_acc = evaluate(classifier, embeds, labels, val_idx)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            test_acc = evaluate(classifier, embeds, labels, test_idx)
            if test_acc > best_acc:
                best_acc = test_acc

    print("Valid Accuracy {:.4f}".format(best_val_acc))
    print("Test Accuracy {:.4f}".format(best_acc))
    return best_acc


def get_free_gpu():
    """Return the index of the entry with the most free memory reported by ``nvidia-smi``.

    The ``nvidia-smi`` output is filtered by the same shell pipeline as
    before, but it is read through a pipe instead of a shared ``tmp`` file in
    the working directory, so concurrent runs cannot overwrite each other's
    result and no file handle is left open.

    Raises:
        ValueError: if the pipeline produces no "Free" lines (for example
            when ``nvidia-smi`` is not available).
    """
    import subprocess

    # The command is a fixed string with no external input, so running it
    # through the shell to keep the grep pipeline is safe.
    result = subprocess.run(
        'nvidia-smi -q -d Memory |grep -A4 GPU|grep Free',
        shell=True,
        stdout=subprocess.PIPE,
        universal_newlines=True,
    )
    # Each kept line is split on whitespace and the third token is parsed as the free amount.
    memory_available = [int(line.split()[2]) for line in result.stdout.splitlines() if line.strip()]
    if not memory_available:
        raise ValueError("nvidia-smi reported no free-memory lines; cannot pick a GPU")
    return np.argmax(memory_available)


if __name__ == '__main__':
    import warnings

    warnings.filterwarnings("ignore")

    parser = argparse.ArgumentParser(description='GGD')
    register_data_args(parser)

    # Command-line options as (flag, add_argument keyword arguments).  They
    # are registered in this exact order so the --help listing stays the same.
    cli_options = (
        ("--dropout", dict(type=float, default=0.,
                           help="dropout probability")),
        ("--gpu", dict(type=int, default=0,
                       help="gpu")),
        ("--ggd-lr", dict(type=float, default=0.001,
                          help="ggd learning rate")),
        ("--drop_feat", dict(type=float, default=0.1,
                             help="feature dropout rate")),
        ("--classifier-lr", dict(type=float, default=0.05,
                                 help="classifier learning rate")),
        ("--n-ggd-epochs", dict(type=int, default=500,
                                help="number of training epochs")),
        ("--n-classifier-epochs", dict(type=int, default=6000,
                                       help="number of training epochs")),
        ("--n-hidden", dict(type=int, default=512,
                            help="number of hidden gcn units")),
        ("--proj_layers", dict(type=int, default=1,
                               help="number of project linear layers")),
        ("--n-layers", dict(type=int, default=1,
                            help="number of hidden gcn layers")),
        ("--weight-decay", dict(type=float, default=0.,
                                help="Weight for L2 loss")),
        ("--patience", dict(type=int, default=500,
                            help="early stop patience condition")),
        ("--self-loop", dict(action='store_true',
                             help="graph self-loop (default=False)")),
        ("--n_trails", dict(type=int, default=5,
                            help="number of trails")),
        ("--gnn_encoder", dict(type=str, default='gcn',
                               help="choice of gnn encoder")),
        ("--num_hop", dict(type=int, default=10,
                           help="number of k for sgc")),
        ("--data_root_dir", dict(type=str, default='default',
                                 help="dir_path for saving graph data. Note that this model use DGL loader so do not mix up with the dir_path for the Pyg one. Use 'default' to save datasets at current folder.")),
        ("--dataset_name", dict(type=str, default='cora',
                                help='Dataset name: cora, citeseer, pubmed, cs, phy')),
    )
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    parser.set_defaults(self_loop=False)

    args = parser.parse_args()
    print(args)

    # Repeat the whole experiment and report the average test accuracy.
    accs = [main(args) for _ in range(args.n_trails)]
    mean_acc = str(np.array(accs).mean())
    print('mean accuracy:' + mean_acc)

    # file_name = str(args.dataset_name)
    # f = open('result/' + 'result_' + file_name + '.txt', 'a')
    # f.write(str(args) + '\n')
    # f.write(mean_acc + '\n')