import argparse
import time

import networkx as nx
import numpy as np
import tensorflow as tf
from dgi import DGI, Classifier
from tensorflow.keras import layers

import dgl
from dgl.data import (
    CiteseerGraphDataset,
    CoraGraphDataset,
    PubmedGraphDataset,
    register_data_args,
)


def evaluate(model, features, labels, mask):
    """Compute the accuracy of ``model`` restricted to the nodes selected by ``mask``.

    ``model`` is run in inference mode (``training=False``) on ``features``.
    The predicted class of each masked row is the argmax of its logits, and it
    is compared against the matching entry of ``labels``.

    Returns:
        The fraction of correctly classified masked rows, as a Python float.
    """
    masked_logits = model(features, training=False)[mask]
    predicted = tf.math.argmax(masked_logits, axis=1)
    hits = tf.cast(predicted == labels[mask], dtype=tf.float32)
    return tf.reduce_mean(hits).numpy().item()


def main(args):
    """Train DGI embeddings unsupervised, then fit and evaluate a linear classifier.

    Steps:
      1. Load the citation dataset named by ``args.dataset``.
      2. Optionally normalize self-loops.
      3. Train the DGI model with early stopping, checkpointing the best weights.
      4. Reload the best checkpoint and compute frozen node embeddings.
      5. Train ``Classifier`` on those embeddings and report test accuracy.

    Raises:
        ValueError: if ``args.dataset`` is not "cora", "citeseer" or "pubmed".
    """
    checkpoint = "best_dgi.pkl"

    # load and preprocess dataset
    data = _load_dataset(args.dataset)
    g = data[0]
    if args.gpu < 0:
        device = "/cpu:0"
    else:
        device = "/gpu:{}".format(args.gpu)
        g = g.to(device)

    with tf.device(device):
        features = g.ndata["feat"]
        labels = g.ndata["label"]
        train_mask = g.ndata["train_mask"]
        val_mask = g.ndata["val_mask"]
        test_mask = g.ndata["test_mask"]
        in_feats = features.shape[1]
        n_classes = data.num_labels

        # add self loop (remove first so existing self-loops are not duplicated)
        if args.self_loop:
            g = dgl.remove_self_loop(g)
            g = dgl.add_self_loop(g)
        # Edge count of the graph actually used for training (after the
        # optional self-loop change); only used for throughput reporting.
        n_edges = g.number_of_edges()

        # create DGI model
        dgi = DGI(
            g,
            in_feats,
            args.n_hidden,
            args.n_layers,
            tf.keras.layers.PReLU(
                alpha_initializer=tf.constant_initializer(0.25)
            ),
            args.dropout,
        )

        # train deep graph infomax
        best_t = _train_dgi(dgi, features, n_edges, args, checkpoint)

        # create classifier model
        classifier = Classifier(args.n_hidden, n_classes)

        # Reload the best DGI weights and freeze the resulting embeddings so
        # the classifier cannot back-propagate into the encoder.
        print("Loading {}th epoch".format(best_t))
        dgi.load_weights(checkpoint)
        embeds = tf.stop_gradient(dgi.encoder(features, corrupt=False))

        _train_classifier(
            classifier, embeds, labels, train_mask, val_mask, n_edges, args
        )

        print()
        acc = evaluate(classifier, embeds, labels, test_mask)
        print("Test Accuracy {:.4f}".format(acc))


def _load_dataset(name):
    """Instantiate the DGL citation dataset selected by ``name``.

    Raises:
        ValueError: if ``name`` is not one of "cora", "citeseer", "pubmed".
    """
    dataset_classes = {
        "cora": CoraGraphDataset,
        "citeseer": CiteseerGraphDataset,
        "pubmed": PubmedGraphDataset,
    }
    if name not in dataset_classes:
        raise ValueError("Unknown dataset: {}".format(name))
    return dataset_classes[name]()


def _mean_or_nan(values):
    """Return the mean of ``values``, or NaN when ``values`` is empty.

    The timing lists are empty during the first three (warm-up) epochs;
    ``np.mean([])`` would also give NaN but emits a RuntimeWarning.
    """
    return np.mean(values) if values else np.nan


def _train_dgi(dgi, features, n_edges, args, checkpoint):
    """Train ``dgi`` with Adam and loss-based early stopping.

    The weights of the epoch with the lowest (weight-decayed) loss are saved
    to ``checkpoint``. Training stops after ``args.patience`` consecutive
    epochs without improvement.

    Returns:
        The index of the epoch whose weights were saved last.
    """
    optimizer = tf.keras.optimizers.Adam(learning_rate=args.dgi_lr)
    cnt_wait = 0
    # +inf (rather than a large finite sentinel) guarantees the first epoch
    # always writes a checkpoint, so ``load_weights`` never reads a missing
    # or stale file.
    best = float("inf")
    best_t = 0
    dur = []
    for epoch in range(args.n_dgi_epochs):
        if epoch >= 3:
            t0 = time.time()

        with tf.GradientTape() as tape:
            loss = dgi(features)
            # Manual weight decay: TensorFlow's Adam(W) weight decay differs
            # from PyTorch's and gave worse results, so an L2 penalty is
            # added to the loss explicitly instead.
            for weight in dgi.trainable_weights:
                loss = loss + args.weight_decay * tf.nn.l2_loss(weight)
        # Gradients are taken outside the tape context so the gradient
        # computation itself is not recorded on the tape.
        grads = tape.gradient(loss, dgi.trainable_weights)
        optimizer.apply_gradients(zip(grads, dgi.trainable_weights))

        loss_value = loss.numpy().item()
        if loss_value < best:
            best = loss_value
            best_t = epoch
            cnt_wait = 0
            dgi.save_weights(checkpoint)
        else:
            cnt_wait += 1

        if cnt_wait == args.patience:
            print("Early stopping!")
            break

        if epoch >= 3:
            dur.append(time.time() - t0)

        mean_dur = _mean_or_nan(dur)
        print(
            "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | "
            "ETputs(KTEPS) {:.2f}".format(
                epoch,
                mean_dur,
                loss_value,
                n_edges / mean_dur / 1000,
            )
        )
    return best_t


def _train_classifier(
    classifier, embeds, labels, train_mask, val_mask, n_edges, args
):
    """Fit ``classifier`` on the frozen ``embeds`` of the training nodes.

    Validation accuracy is reported after every epoch. No weight decay is
    applied here, matching the reference implementation:
    https://github.com/PetarV-/DGI/blob/master/execute.py#L121
    """
    optimizer = tf.keras.optimizers.Adam(learning_rate=args.classifier_lr)
    loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    # Loop-invariant: the training labels never change between epochs.
    train_labels = labels[train_mask]
    dur = []
    for epoch in range(args.n_classifier_epochs):
        if epoch >= 3:
            t0 = time.time()
        with tf.GradientTape() as tape:
            preds = classifier(embeds)
            loss = loss_fcn(train_labels, preds[train_mask])
        grads = tape.gradient(loss, classifier.trainable_weights)
        optimizer.apply_gradients(zip(grads, classifier.trainable_weights))
        if epoch >= 3:
            dur.append(time.time() - t0)

        acc = evaluate(classifier, embeds, labels, val_mask)
        mean_dur = _mean_or_nan(dur)
        print(
            "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
            "ETputs(KTEPS) {:.2f}".format(
                epoch,
                mean_dur,
                loss.numpy().item(),
                acc,
                n_edges / mean_dur / 1000,
            )
        )


if __name__ == "__main__":
    # Command-line entry point: parse hyper-parameters and run training.
    parser = argparse.ArgumentParser(description="DGI")
    register_data_args(parser)
    parser.add_argument(
        "--dropout", type=float, default=0.0, help="dropout probability"
    )
    parser.add_argument(
        "--gpu", type=int, default=-1, help="GPU index to use; -1 for CPU"
    )
    parser.add_argument(
        "--dgi-lr", type=float, default=1e-3, help="dgi learning rate"
    )
    parser.add_argument(
        "--classifier-lr",
        type=float,
        default=1e-2,
        help="classifier learning rate",
    )
    parser.add_argument(
        "--n-dgi-epochs",
        type=int,
        default=300,
        help="maximum number of unsupervised DGI training epochs",
    )
    parser.add_argument(
        "--n-classifier-epochs",
        type=int,
        default=300,
        help="number of classifier training epochs",
    )
    parser.add_argument(
        "--n-hidden", type=int, default=512, help="number of hidden gcn units"
    )
    parser.add_argument(
        "--n-layers", type=int, default=1, help="number of hidden gcn layers"
    )
    parser.add_argument(
        "--weight-decay",
        type=float,
        default=0.0,
        help="coefficient of the L2 penalty added to the DGI loss",
    )
    parser.add_argument(
        "--patience", type=int, default=20, help="early stop patience condition"
    )
    # store_true already defaults to False; no set_defaults call is needed.
    parser.add_argument(
        "--self-loop",
        action="store_true",
        help="graph self-loop (default=False)",
    )
    args = parser.parse_args()
    print(args)

    main(args)
