#!/usr/bin/env python
# -*- coding: utf-8 -*-
# https://github.com/dmlc/dgl/blob/master/examples/pytorch/ogb/ogbn-arxiv/gcn.py

import argparse
import math
import time

import numpy as np
import torch as th
import torch.nn.functional as F
import torch.optim as optim
from matplotlib import pyplot as plt
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from models import GCN, GCN_sequential
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
import utils
import dgl
import os
import wandb
# wandb.init(mode="disabled") # This is for testing
# Module-level state filled in by main(): the torch device everything is moved
# to, and the feature width / class count read from the loaded graph.
device = None
in_feats, n_classes = None, None
# Offset used by the log-transformed loss in cross_entropy().
epsilon = 1 - math.log(2)


def gen_model(args):
    """Build a fresh, untrained GCN variant for the dataset named in ``args``.

    Reads the module-level ``in_feats`` and ``n_classes`` set by ``main``.

    Args:
        args: Parsed command-line namespace. Uses ``dataset``, ``use_labels``,
            ``n_hidden``, ``n_layers``, ``dropout`` and ``use_linear``.

    Returns:
        A ``GCN`` instance for ``ogbn-arxiv`` or a ``GCN_sequential`` instance
        for ``ogbn-products``.

    Raises:
        ValueError: If ``args.dataset`` has no model defined here. Previously
            this surfaced as an ``UnboundLocalError`` on the ``return``.
    """
    model_by_dataset = {
        "ogbn-arxiv": lambda: GCN,
        "ogbn-products": lambda: GCN_sequential,
    }
    if args.dataset not in model_by_dataset:
        raise ValueError(
            f"No model is defined for dataset {args.dataset!r}; "
            f"expected one of {sorted(model_by_dataset)}."
        )
    # Resolved lazily so that the unsupported-dataset check above does not
    # depend on the model classes being importable.
    model_cls = model_by_dataset[args.dataset]()

    # train()/evaluate() concatenate a one-hot label block of width n_classes
    # onto the node features whenever --use_labels is set (see add_labels),
    # so the input layer has to be widened to match for every dataset.
    input_dim = in_feats + n_classes if args.use_labels else in_feats

    return model_cls(
        input_dim,
        args.n_hidden,
        n_classes,
        args.n_layers,
        F.relu,
        args.dropout,
        args.use_linear,
    )


def cross_entropy(x, labels):
    """Return the mean log-transformed cross-entropy of ``x`` against ``labels``.

    The per-row cross-entropy is passed through
    ``log(epsilon + ce) - log(epsilon)`` before averaging, where ``epsilon``
    is the module-level constant.

    Args:
        x: Logits, one row per node.
        labels: Label tensor; column 0 is used as the target class.

    Returns:
        A scalar tensor with the averaged transformed loss.
    """
    targets = labels[:, 0]
    per_node = F.cross_entropy(x, targets, reduction="none")
    transformed = th.log(epsilon + per_node) - math.log(epsilon)
    return transformed.mean()


def compute_acc(pred, labels, evaluator):
    """Score ``pred`` against ``labels`` and return the evaluator's accuracy.

    Args:
        pred: Per-class scores; the arg-max over the last dimension is taken
            (keeping that dimension) as the predicted class.
        labels: Ground-truth labels, passed to the evaluator unchanged.
        evaluator: Object whose ``eval`` accepts a dict with ``"y_pred"`` and
            ``"y_true"`` and returns a mapping containing ``"acc"``.

    Returns:
        The value stored under ``"acc"`` in the evaluator's result.
    """
    predicted_class = pred.argmax(dim=-1, keepdim=True)
    result = evaluator.eval({"y_pred": predicted_class, "y_true": labels})
    return result["acc"]


def add_labels(feat, labels, idx):
    """Append a one-hot label block to ``feat`` for the nodes in ``idx``.

    Rows not listed in ``idx`` get an all-zero block. Uses the module-level
    ``n_classes`` for the block width and ``device`` for its placement.

    Args:
        feat: Node feature matrix, one row per node.
        labels: Label tensor; column 0 holds the class index.
        idx: Indices of the nodes whose labels should be exposed.

    Returns:
        ``feat`` concatenated with the one-hot block along the last dimension.
    """
    n_nodes = feat.shape[0]
    label_block = th.zeros([n_nodes, n_classes]).to(device)
    known_classes = labels[idx, 0]
    label_block[idx, known_classes] = 1
    return th.cat((feat, label_block), dim=-1)


def adjust_learning_rate(optimizer, lr, epoch):
    """Linearly warm up the learning rate over the first 50 epochs.

    For ``epoch <= 50`` every parameter group's ``"lr"`` is set to
    ``lr * epoch / 50``; later epochs leave the optimizer untouched.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts).
        lr: Target learning rate reached at epoch 50.
        epoch: Current epoch number.
    """
    if epoch > 50:
        # Warm-up is over; whatever rate is currently set stays in effect.
        return
    for group in optimizer.param_groups:
        group["lr"] = lr * epoch / 50


def train(model, graph, labels, train_idx, optimizer, use_labels, lip_lambda):
    """Run one full-graph optimisation step and return its loss and output.

    A random boolean mask (probability 0.5 per training node) is drawn each
    call. With ``use_labels`` the masked nodes have their labels appended to
    the input features and the loss is taken on the remaining training nodes;
    without it the loss is taken on the masked nodes only.

    Args:
        model: Network called as ``model(graph, feat)`` that also provides
            ``regularization_loss(lip_lambda)``.
        graph: Graph whose ``ndata["feat"]`` holds the node features.
        labels: Label tensor for all nodes.
        train_idx: Indices of the training nodes.
        optimizer: Optimizer stepping the model parameters.
        use_labels: Whether to feed part of the training labels as input.
        lip_lambda: Weight passed to ``model.regularization_loss``.

    Returns:
        ``(loss, pred)``: the total loss tensor (data term plus the model's
        regularisation term) and the model output for every node.
    """
    model.train()

    feat = graph.ndata["feat"]

    # Both modes draw the same kind of mask; only its use differs.
    mask_rate = 0.5
    mask = th.rand(train_idx.shape) < mask_rate

    if use_labels:
        train_pred_idx = train_idx[~mask]
        feat = add_labels(feat, labels, train_idx[mask])
    else:
        train_pred_idx = train_idx[mask]

    optimizer.zero_grad()
    pred = model(graph, feat)
    data_loss = cross_entropy(pred[train_pred_idx], labels[train_pred_idx])
    loss = data_loss + model.regularization_loss(lip_lambda)
    loss.backward()
    optimizer.step()

    return loss, pred


@th.no_grad()
def evaluate(
    model, graph, labels, train_idx, val_idx, test_idx, use_labels, evaluator
):
    """Evaluate the model on the train/validation/test splits without gradients.

    With ``use_labels`` the labels of all training nodes are appended to the
    input features before the forward pass.

    Args:
        model: Network called as ``model(graph, feat)``.
        graph: Graph whose ``ndata["feat"]`` holds the node features.
        labels: Label tensor for all nodes.
        train_idx: Training node indices.
        val_idx: Validation node indices.
        test_idx: Test node indices.
        use_labels: Whether training labels are fed as input features.
        evaluator: Evaluator forwarded to ``compute_acc``.

    Returns:
        A 6-tuple ``(train_acc, val_acc, test_acc, train_loss, val_loss,
        test_loss)``.
    """
    model.eval()

    feat = graph.ndata["feat"]
    if use_labels:
        feat = add_labels(feat, labels, train_idx)

    pred = model(graph, feat)

    splits = (train_idx, val_idx, test_idx)
    split_losses = [cross_entropy(pred[ids], labels[ids]) for ids in splits]
    split_accs = [
        compute_acc(pred[ids], labels[ids], evaluator) for ids in splits
    ]

    return (*split_accs, *split_losses)


def _save_curves(
    curves, curve_labels, n_epochs, y_major, y_minor_divisions, path, y_ticks=None
):
    """Plot per-epoch curves on one large figure, save it to ``path``, close it."""
    fig = plt.figure(figsize=(24, 24))
    ax = fig.gca()
    ax.set_xticks(np.arange(0, n_epochs, 100))
    if y_ticks is not None:
        ax.set_yticks(y_ticks)
    ax.tick_params(labeltop=True, labelright=True)
    for y, label in zip(curves, curve_labels):
        ax.plot(range(n_epochs), y, label=label)
    ax.xaxis.set_major_locator(MultipleLocator(100))
    ax.xaxis.set_minor_locator(AutoMinorLocator(1))
    ax.yaxis.set_major_locator(MultipleLocator(y_major))
    ax.yaxis.set_minor_locator(AutoMinorLocator(y_minor_divisions))
    ax.grid(which="major", color="red", linestyle="dotted")
    ax.grid(which="minor", color="orange", linestyle="dotted")
    ax.legend()
    fig.tight_layout()
    fig.savefig(path)
    # Figures are 24x24 inches and one pair is produced per run; release them
    # explicitly so they do not pile up in pyplot's figure registry.
    plt.close(fig)


def run(
    args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running
):
    """Train one model from scratch and report its best validation result.

    Each epoch applies the warm-up learning-rate schedule, performs one
    training step, evaluates on all three splits, and feeds the training loss
    to a ``ReduceLROnPlateau`` scheduler. The epoch with the lowest validation
    loss determines the reported validation and test accuracy.

    Args:
        args: Parsed command-line namespace. Uses ``lr``, ``wd``, ``n_epochs``,
            ``use_labels``, ``lip_lambda``, ``log_every``, ``n_runs`` and
            ``plot_curves`` (plus whatever ``gen_model`` reads).
        graph: Graph passed to the model.
        labels: Label tensor for all nodes.
        train_idx: Training node indices.
        val_idx: Validation node indices.
        test_idx: Test node indices.
        evaluator: Evaluator forwarded to ``compute_acc`` / ``evaluate``.
        n_running: Identifier of this run, used in log lines and in the names
            of the saved plot files.

    Returns:
        ``(best_val_acc, final_test_acc)`` taken at the epoch with the lowest
        validation loss.
    """
    # define model and optimizer
    model = gen_model(args)
    model = model.to(device)

    optimizer = optim.AdamW(
        model.parameters(), lr=args.lr, weight_decay=args.wd
    )
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.5,
        patience=100,
        verbose=True,
        min_lr=1e-3,
    )

    # training loop
    total_time = 0
    best_val_acc, final_test_acc, best_val_loss = 0, 0, float("inf")

    accs, train_accs, val_accs, test_accs = [], [], [], []
    losses, train_losses, val_losses, test_losses = [], [], [], []

    for epoch in range(1, args.n_epochs + 1):
        tic = time.time()

        adjust_learning_rate(optimizer, args.lr, epoch)

        loss, pred = train(
            model,
            graph,
            labels,
            train_idx,
            optimizer,
            args.use_labels,
            args.lip_lambda,
        )
        acc = compute_acc(pred[train_idx], labels[train_idx], evaluator)

        (
            train_acc,
            val_acc,
            test_acc,
            train_loss,
            val_loss,
            test_loss,
        ) = evaluate(
            model,
            graph,
            labels,
            train_idx,
            val_idx,
            test_idx,
            args.use_labels,
            evaluator,
        )

        # Reduce the losses to plain Python floats. Keeping the tensors in the
        # history lists would hold on to device memory and the autograd node
        # of the training loss, and matplotlib cannot plot tensors that
        # require grad or live on a GPU.
        loss = float(loss)
        train_loss = float(train_loss)
        val_loss = float(val_loss)
        test_loss = float(test_loss)

        lr_scheduler.step(loss)

        toc = time.time()
        total_time += toc - tic

        # Model selection is by validation loss, not validation accuracy.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_acc = val_acc
            final_test_acc = test_acc

        if epoch % args.log_every == 0:
            print(
                f"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, Average epoch time: {total_time / epoch:.2f}\n"
                f"Loss: {loss:.4f}, Acc: {acc:.4f}\n"
                f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n"
                f"Train/Val/Test/Best val/Final test acc: {train_acc:.4f}/{val_acc:.4f}/{test_acc:.4f}/{best_val_acc:.4f}/{final_test_acc:.4f}"
            )
            wandb.log(
                {
                    "Train Accuracy": train_acc,
                    "Validation Accuracy": val_acc,
                    "Test Accuracy": test_acc,
                    "Train Loss": train_loss,
                    "Validation Loss": val_loss,
                    "Test Loss": test_loss,
                }
            )

        accs.append(acc)
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        test_accs.append(test_acc)
        losses.append(loss)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        test_losses.append(test_loss)

    print("*" * 50)
    print(f"Best val acc: {best_val_acc}, Final test acc: {final_test_acc}")
    print("*" * 50)

    if args.plot_curves:
        _save_curves(
            [accs, train_accs, val_accs, test_accs],
            ["acc", "train acc", "val acc", "test acc"],
            args.n_epochs,
            y_major=0.01,
            y_minor_divisions=2,
            path=f"gcn_acc_{n_running}.png",
            y_ticks=np.linspace(0, 1.0, 101),
        )
        _save_curves(
            [losses, train_losses, val_losses, test_losses],
            ["loss", "train loss", "val loss", "test loss"],
            args.n_epochs,
            y_major=0.1,
            y_minor_divisions=5,
            path=f"gcn_loss_{n_running}.png",
        )

    return best_val_acc, final_test_acc


def count_parameters(args):
    """Return the number of trainable scalar parameters of a freshly built model.

    A new model is created through ``gen_model(args)``; only parameters with
    ``requires_grad`` set contribute to the total.
    """
    model = gen_model(args)
    trainable = (p for p in model.parameters() if p.requires_grad)
    return sum(np.prod(p.size()) for p in trainable)


# Datasets main() knows how to load.
# NOTE(review): "ogbn-mag" is accepted here as before, but gen_model only has
# branches for ogbn-arxiv and ogbn-products - confirm whether it is meant to work.
_SUPPORTED_DATASETS = ("ogbn-arxiv", "ogbn-products", "ogbn-mag")


def _build_arg_parser():
    """Create the command-line parser for this script."""
    argparser = argparse.ArgumentParser(
        # Passed by keyword: the first positional parameter of ArgumentParser
        # is ``prog``, which would replace the program name in the usage line.
        description="GCN on OGBN-Arxiv",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    argparser.add_argument(
        "--cpu",
        action="store_true",
        help="CPU mode. This option overrides --gpu.",
    )
    argparser.add_argument("--gpu", type=int, default=0, help="GPU device ID.")
    argparser.add_argument(
        "--n_runs", type=int, default=10, help="running times"
    )
    argparser.add_argument(
        "--n_epochs", type=int, default=1000, help="number of epochs"
    )
    argparser.add_argument(
        "--use_labels",
        action="store_true",
        help="Use labels in the training set as input features.",
    )
    argparser.add_argument(
        "--use_linear", action="store_true", help="Use linear layer."
    )
    argparser.add_argument(
        "--lr", type=float, default=0.005, help="learning rate"
    )
    argparser.add_argument(
        "--n_layers", type=int, default=3, help="number of layers"
    )
    argparser.add_argument(
        "--n_hidden", type=int, default=256, help="number of hidden units"
    )
    argparser.add_argument(
        "--dropout", type=float, default=0.5, help="dropout rate"
    )
    argparser.add_argument(
        "--lip_lambda", type=float, default=0.5, help="lip constant"
    )
    argparser.add_argument("--wd", type=float, default=0, help="weight decay")
    argparser.add_argument(
        "--log_every", type=int, default=20, help="log every LOG_EVERY epochs"
    )
    argparser.add_argument(
        "--plot_curves", action="store_true", help="plot learning curves"
    )
    argparser.add_argument(
        "--n_partitions", type=int, default=2, help="Number of Partitions."
    )
    argparser.add_argument(
        "--seed", type=int, default=42, help="Random Seed"
    )
    argparser.add_argument(
        "--partition_method",
        type=str,
        default="random",
        choices=["random", "metis"],
        help="Partition method.",
    )
    argparser.add_argument(
        "--dataset", type=str, default="ogbn-arxiv", help="Dataset."
    )
    argparser.add_argument(
        "--partition-folder",
        type=str,
        default="partitions",
        help="Partition Folder.",
    )
    return argparser


def main():
    """Parse arguments, load the dataset, train ``n_runs`` models, print a summary.

    Sets the module-level ``device``, ``in_feats`` and ``n_classes`` that the
    other functions in this file read, and opens a wandb run.

    Raises:
        ValueError: If ``--dataset`` is not one of the supported names.
    """
    global device, in_feats, n_classes

    args = _build_arg_parser().parse_args()

    if args.cpu:
        device = th.device("cpu")
    else:
        device = th.device("cuda:%d" % args.gpu)

    print(args)

    # Apply --seed so the training-subset draw, label masking and weight
    # initialisation are reproducible. The seed is set once, so the individual
    # runs still differ from each other.
    np.random.seed(args.seed)
    th.manual_seed(args.seed)

    # load data
    if args.dataset not in _SUPPORTED_DATASETS:
        raise ValueError(
            f"Dataset {args.dataset!r} not supported yet; "
            f"choose one of {list(_SUPPORTED_DATASETS)}."
        )
    dataset = DglNodePropPredDataset(name=args.dataset)

    # The evaluator has to be built for the dataset actually being scored.
    evaluator = Evaluator(name=args.dataset)

    splitted_idx = dataset.get_idx_split()
    train_idx, val_idx, test_idx = (
        splitted_idx["train"],
        splitted_idx["valid"],
        splitted_idx["test"],
    )

    if args.n_partitions > 1:
        if args.partition_method == "random":
            # Keep a uniformly random 1/n_partitions share of the training nodes.
            n_nodes_train = len(train_idx)
            n_keep = n_nodes_train // args.n_partitions
            train_idx = train_idx[np.random.permutation(n_nodes_train)[:n_keep]]
            print(len(train_idx), n_nodes_train)
        else:
            # No other partitioning scheme is implemented in this script; say
            # so instead of silently training on the full training set.
            print(
                f"Warning: partition method {args.partition_method!r} is not "
                "implemented here; using the full training set."
            )

    wandb.init(
        # set the wandb project where this run will be logged
        project="Arxiv GCN Partition test",
        # set the name of the run
        name = f"{args.dataset}_{args.partition_method}_part_{args.n_partitions}_nodes_{len(train_idx)}_lr_{args.lr}_layers_{args.n_layers}_hidden_{args.n_hidden}_epoch_{args.n_epochs}_runs_{args.n_runs}_seed_{args.seed}_lip_lambda_{args.lip_lambda}",
        # track hyperparameters and run metadata
        config=args
    )

    graph, labels = dataset[0]

    # add reverse edges
    srcs, dsts = graph.all_edges()
    graph.add_edges(dsts, srcs)

    # add self-loop
    print(f"Total edges before adding self-loop {graph.num_edges()}")
    graph = graph.remove_self_loop().add_self_loop()
    print(f"Total edges after adding self-loop {graph.num_edges()}")

    in_feats = graph.ndata["feat"].shape[1]
    n_classes = (labels.max() + 1).item()
    graph.create_formats_()

    train_idx = train_idx.to(device)
    val_idx = val_idx.to(device)
    test_idx = test_idx.to(device)
    labels = labels.to(device)
    graph = graph.to(device)

    # run
    val_accs = []
    test_accs = []

    for i in range(args.n_runs):
        val_acc, test_acc = run(
            args, graph, labels, train_idx, val_idx, test_idx, evaluator, i
        )
        val_accs.append(val_acc)
        test_accs.append(test_acc)

    print(f"Runned {args.n_runs} times")
    print("Val Accs:", val_accs)
    print("Test Accs:", test_accs)
    print(f"Average val accuracy: {np.mean(val_accs)} ± {np.std(val_accs)}")
    print(f"Average test accuracy: {np.mean(test_accs)} ± {np.std(test_accs)}")
    print(f"Number of params: {count_parameters(args)}")

# Script entry point: run the full experiment, then close the wandb run that
# main() opened. finish() is only reached when main() returns normally.
if __name__ == "__main__":
    main()
    wandb.finish()
