#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import os
import random
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
from matplotlib.ticker import AutoMinorLocator, MultipleLocator
from models import GAT
from ogb.nodeproppred import DglNodePropPredDataset, Evaluator
from torch import nn

import dgl
import dgl.function as fn
from dgl.dataloading import (
    DataLoader,
    MultiLayerFullNeighborSampler,
    MultiLayerNeighborSampler,
)

# Torch device for the model and label/index tensors; assigned in main()
# from the --cpu / --gpu options.
device = None
# OGB dataset name handed to load_data() in main().
dataset = "ogbn-proteins"
# n_node_feats: input node-feature width, overwritten by preprocess() once the
#   edge-feature sums have been computed.
# n_edge_feats: edge-feature width passed to the GAT model.
# n_classes: number of label columns, i.e. the model's output width.
n_node_feats, n_edge_feats, n_classes = 0, 8, 112


def seed(seed=0):
    """Seed the Python, NumPy, PyTorch (CPU and all CUDA devices) and DGL RNGs.

    cuDNN is also switched to deterministic, non-benchmarking mode to make
    repeated runs with the same seed more reproducible.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    dgl.random.seed(seed)


def load_data(dataset):
    """Load an OGB node-property-prediction dataset together with its evaluator.

    The label tensor is also attached to the graph as ``graph.ndata["labels"]``;
    train() reads it back from the sampled blocks via ``dstdata["labels"]``.

    Returns:
        ``(graph, labels, train_idx, val_idx, test_idx, evaluator)``.
    """
    data = DglNodePropPredDataset(name=dataset)
    evaluator = Evaluator(name=dataset)

    split = data.get_idx_split()
    graph, labels = data[0]
    graph.ndata["labels"] = labels

    return (
        graph,
        labels,
        split["train"],
        split["valid"],
        split["test"],
        evaluator,
    )


def preprocess(graph, labels, train_idx):
    """Build node features and label features on *graph* in place.

    - ``graph.ndata["feat"]``: per-node sum of the incident edge features;
      the module-level ``n_node_feats`` is updated to its width.
    - ``graph.ndata["train_labels_onehot"]``: the full label vector for
      training nodes, all zeros for every other node. The key name is kept
      for compatibility even though the rows are multi-hot: the labels have
      one binary column per task (``labels.shape[-1] == n_classes``).
    - ``graph.ndata["deg"]``: out-degree as float, clamped to at least 1.

    Returns:
        ``(graph, labels)``, the same objects that were passed in.
    """
    global n_node_feats

    # The sum of the weights of adjacent edges is used as node features.
    graph.update_all(
        fn.copy_e("feat", "feat_copy"), fn.sum("feat_copy", "feat")
    )
    n_node_feats = graph.ndata["feat"].shape[-1]

    # Only the labels in the training set are used as features; the other
    # nodes get zero rows so validation/test labels never leak into the
    # input. The labels are multi-label (one binary column per task), so the
    # whole label vector is copied. The previous code used
    # ``labels[train_idx, 0]`` as a column index, which only encoded task 0.
    train_labels = torch.zeros(graph.num_nodes(), n_classes)
    train_labels[train_idx] = labels[train_idx].float()
    graph.ndata["train_labels_onehot"] = train_labels
    graph.ndata["deg"] = graph.out_degrees().float().clamp(min=1)

    graph.create_formats_()

    return graph, labels


def gen_model(args):
    """Instantiate the GAT model described by the command-line *args*.

    With ``--use-labels`` the input width grows by ``n_classes``, because
    add_labels() appends the label features to the node features.
    """
    extra_feats = n_classes if args.use_labels else 0
    in_feats = n_node_feats + extra_feats

    return GAT(
        in_feats,
        n_edge_feats,
        n_classes,
        n_layers=args.n_layers,
        n_heads=args.n_heads,
        n_hidden=args.n_hidden,
        edge_emb=16,
        activation=F.relu,
        dropout=args.dropout,
        input_drop=args.input_drop,
        attn_drop=args.attn_drop,
        edge_drop=args.edge_drop,
        use_attn_dst=not args.no_attn_dst,
    )


def add_labels(graph, idx):
    """Append stored training-label features to the source-node features.

    ``graph.srcdata["feat"]`` is replaced by its concatenation with a label
    block that has the same width as ``graph.srcdata["train_labels_onehot"]``.
    Rows listed in *idx* receive their stored labels; all other rows are zero.

    Args:
        graph: a graph/block whose ``srcdata`` holds ``"feat"`` and
            ``"train_labels_onehot"``.
        idx: row indices (tensor or list) of the source nodes whose labels
            are exposed.
    """
    feat = graph.srcdata["feat"]
    stored_labels = graph.srcdata["train_labels_onehot"]
    # Derive device, dtype and width from the stored labels instead of the
    # module-level ``device``/``n_classes``. The helper then works wherever
    # the graph lives, including before main() has set ``device``.
    label_feat = stored_labels.new_zeros(
        (feat.shape[0], stored_labels.shape[-1])
    )
    label_feat[idx] = stored_labels[idx]
    graph.srcdata["feat"] = torch.cat([feat, label_feat], dim=-1)


def train(
    args,
    model,
    dataloader,
    _labels,
    _train_idx,
    criterion,
    optimizer,
    _evaluator,
):
    """Run one training epoch and return the loss averaged over seed nodes.

    Args:
        args: parsed CLI arguments; only ``args.use_labels`` is read here.
        model: the GAT model, called on the list of sampled blocks.
        dataloader: yields ``(input_nodes, output_nodes, blocks)`` batches.
        _labels, _train_idx, _evaluator: unused; kept so existing callers
            keep working with the same signature.
        criterion: loss applied to ``(predictions, float labels)``.
        optimizer: stepped once per mini-batch.

    Returns:
        The per-node mean loss over the epoch, or ``nan`` if the dataloader
        produced no batches.
    """
    model.train()

    loss_sum, total = 0.0, 0

    for input_nodes, output_nodes, subgraphs in dataloader:
        subgraphs = [b.to(device) for b in subgraphs]
        n_out = len(output_nodes)
        # The seed (output) nodes are the first n_out input nodes, and the
        # model's first n_out prediction rows belong to them.
        pred_idx = torch.arange(n_out, device=device)

        if args.use_labels:
            # Expose labels only for the non-seed input nodes, so the model
            # is never given the labels it is being trained to predict.
            label_idx = torch.arange(n_out, len(input_nodes), device=device)
            add_labels(subgraphs[0], label_idx)

        pred = model(subgraphs)
        loss = criterion(
            pred[pred_idx],
            subgraphs[-1].dstdata["labels"][pred_idx].float(),
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item() * n_out
        total += n_out

    # Without this guard an empty dataloader would raise ZeroDivisionError.
    return loss_sum / total if total else float("nan")


@torch.no_grad()
def evaluate(
    args,
    model,
    dataloader,
    labels,
    train_idx,
    val_idx,
    test_idx,
    criterion,
    evaluator,
    eval_times=1,
):
    """Predict every node the dataloader yields and score the three splits.

    Because of memory capacity constraints, inference uses neighbour
    sampling. The predictions are averaged over ``eval_times`` passes over
    the dataloader.

    Args:
        args: parsed CLI arguments; only ``args.use_labels`` is read here.
        model: the model to evaluate (switched to eval mode).
        dataloader: yields ``(input_nodes, output_nodes, blocks)`` batches.
        labels: label tensor for all nodes; ``preds`` gets the same shape.
        train_idx, val_idx, test_idx: node indices of the three splits.
        criterion: loss applied to ``(predictions, float labels)``.
        evaluator: callable ``(predictions, labels) -> score``.
        eval_times: number of sampled inference passes to average (>= 1).

    Returns:
        ``(train_score, val_score, test_score, train_loss, val_loss,
        test_loss, preds)``. ``preds`` holds the averaged model outputs for
        all nodes; rows never yielded by the dataloader stay zero.

    Raises:
        ValueError: if ``eval_times`` is smaller than 1.
    """
    if eval_times < 1:
        raise ValueError(f"eval_times must be >= 1, got {eval_times}")

    model.eval()

    # Allocate directly on the target device instead of building the tensor
    # on the CPU and copying it over.
    preds = torch.zeros(labels.shape, device=device)

    for _ in range(eval_times):
        for input_nodes, output_nodes, subgraphs in dataloader:
            subgraphs = [b.to(device) for b in subgraphs]

            if args.use_labels:
                # At inference every input node exposes its stored training
                # labels; preprocess() left non-training rows all zero.
                all_idx = torch.arange(len(input_nodes), device=device)
                add_labels(subgraphs[0], all_idx)

            preds[output_nodes] += model(subgraphs)

    preds /= eval_times

    train_loss = criterion(preds[train_idx], labels[train_idx].float()).item()
    val_loss = criterion(preds[val_idx], labels[val_idx].float()).item()
    test_loss = criterion(preds[test_idx], labels[test_idx].float()).item()

    return (
        evaluator(preds[train_idx], labels[train_idx]),
        evaluator(preds[val_idx], labels[val_idx]),
        evaluator(preds[test_idx], labels[test_idx]),
        train_loss,
        val_loss,
        test_loss,
        preds,
    )


def _plot_curves(
    path, epochs, curves, n_epochs, y_major, y_minor, y_ticks=None
):
    """Plot each ``(label, values)`` pair of *curves* against *epochs*; save to *path*.

    The figure is closed after saving so repeated runs do not accumulate
    open matplotlib figures.
    """
    fig = plt.figure(figsize=(24, 24))
    ax = fig.gca()
    ax.set_xticks(np.arange(0, n_epochs, 100))
    if y_ticks is not None:
        ax.set_yticks(y_ticks)
    ax.tick_params(labeltop=True, labelright=True)
    for label, values in curves:
        ax.plot(epochs, values, label=label, linewidth=1)
    ax.xaxis.set_major_locator(MultipleLocator(100))
    ax.xaxis.set_minor_locator(AutoMinorLocator(1))
    ax.yaxis.set_major_locator(MultipleLocator(y_major))
    ax.yaxis.set_minor_locator(AutoMinorLocator(y_minor))
    ax.grid(which="major", color="red", linestyle="dotted")
    ax.grid(which="minor", color="orange", linestyle="dotted")
    ax.legend()
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)


def run(
    args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running
):
    """Train one model from scratch; return its best validation score.

    Training uses neighbour-sampled mini-batches (32 neighbours per layer)
    over the training nodes. Evaluation samples up to 100 neighbours per
    layer for every train/val/test node. The model is evaluated on the last
    epoch and on every epoch divisible by ``args.eval_every`` or
    ``args.log_every``. The reported test score is the one from the epoch
    with the best validation score.

    Args:
        args: parsed command-line arguments (see ``main``).
        graph: the preprocessed graph; dataloaders are built on a CPU copy.
        labels, train_idx, val_idx, test_idx: labels and split indices.
        evaluator: an OGB ``Evaluator`` whose ``eval`` result has ``"rocauc"``.
        n_running: 1-based run number, used in logs and output file names.

    Returns:
        ``(best_val_score, final_test_score)``.
    """

    def evaluator_wrapper(pred, y_true):
        return evaluator.eval({"y_pred": pred, "y_true": y_true})["rocauc"]

    # Each training epoch is split into roughly 10 mini-batches of seed nodes.
    train_batch_size = (len(train_idx) + 9) // 10
    train_sampler = MultiLayerNeighborSampler([32] * args.n_layers)
    train_dataloader = DataLoader(
        graph.cpu(),
        train_idx.cpu(),
        train_sampler,
        batch_size=train_batch_size,
        num_workers=10,
    )

    # Inference is also sampled (see evaluate()) to bound memory use.
    eval_sampler = MultiLayerNeighborSampler([100] * args.n_layers)
    eval_dataloader = DataLoader(
        graph.cpu(),
        torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx.cpu()]),
        eval_sampler,
        batch_size=65536,
        num_workers=10,
    )

    criterion = nn.BCEWithLogitsLoss()

    model = gen_model(args).to(device)

    optimizer = optim.AdamW(
        model.parameters(), lr=args.lr, weight_decay=args.wd
    )
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="max", factor=0.75, patience=50, verbose=True
    )

    total_time = 0
    val_score, best_val_score, final_test_score = 0, 0, 0

    # One entry per evaluated epoch. The epoch numbers are kept so curves are
    # plotted at the right x positions whatever eval_every/log_every are.
    eval_epochs = []
    train_scores, val_scores, test_scores = [], [], []
    losses, train_losses, val_losses, test_losses = [], [], [], []
    final_pred = None

    for epoch in range(1, args.n_epochs + 1):
        tic = time.time()

        loss = train(
            args,
            model,
            train_dataloader,
            labels,
            train_idx,
            criterion,
            optimizer,
            evaluator_wrapper,
        )

        total_time += time.time() - tic

        if (
            epoch == args.n_epochs
            or epoch % args.eval_every == 0
            or epoch % args.log_every == 0
        ):
            (
                train_score,
                val_score,
                test_score,
                train_loss,
                val_loss,
                test_loss,
                pred,
            ) = evaluate(
                args,
                model,
                eval_dataloader,
                labels,
                train_idx,
                val_idx,
                test_idx,
                criterion,
                evaluator_wrapper,
            )

            if val_score > best_val_score:
                best_val_score = val_score
                final_test_score = test_score
                final_pred = pred

            if epoch % args.log_every == 0:
                print(
                    f"Run: {n_running}/{args.n_runs}, Epoch: {epoch}/{args.n_epochs}, Average epoch time: {total_time / epoch:.2f}s"
                )
                print(
                    f"Loss: {loss:.4f}\n"
                    f"Train/Val/Test loss: {train_loss:.4f}/{val_loss:.4f}/{test_loss:.4f}\n"
                    f"Train/Val/Test/Best val/Final test score: {train_score:.4f}/{val_score:.4f}/{test_score:.4f}/{best_val_score:.4f}/{final_test_score:.4f}"
                )

            eval_epochs.append(epoch)
            for history, value in zip(
                [
                    train_scores,
                    val_scores,
                    test_scores,
                    losses,
                    train_losses,
                    val_losses,
                    test_losses,
                ],
                [
                    train_score,
                    val_score,
                    test_score,
                    loss,
                    train_loss,
                    val_loss,
                    test_loss,
                ],
            ):
                history.append(value)

        # On epochs without evaluation the scheduler is given the most
        # recently evaluated validation score again.
        lr_scheduler.step(val_score)

    print("*" * 50)
    print(
        f"Best val score: {best_val_score}, Final test score: {final_test_score}"
    )
    print("*" * 50)

    if args.plot:
        _plot_curves(
            f"gat_score_{n_running}.png",
            eval_epochs,
            [
                ("train score", train_scores),
                ("val score", val_scores),
                ("test score", test_scores),
            ],
            args.n_epochs,
            y_major=0.01,
            y_minor=2,
            y_ticks=np.linspace(0, 1.0, 101),
        )
        _plot_curves(
            f"gat_loss_{n_running}.png",
            eval_epochs,
            [
                ("loss", losses),
                ("train loss", train_losses),
                ("val loss", val_losses),
                ("test loss", test_losses),
            ],
            args.n_epochs,
            y_major=0.1,
            y_minor=5,
        )

    if args.save_pred:
        if final_pred is None:
            # Happens only when no evaluation ran (e.g. n_epochs < 1).
            print("No predictions to save: the model was never evaluated.")
        else:
            os.makedirs("./output", exist_ok=True)
            # Each output column is an independent binary task trained with
            # BCEWithLogitsLoss, so per-task probabilities come from a sigmoid.
            # A softmax across columns would couple the tasks.
            torch.save(torch.sigmoid(final_pred), f"./output/{n_running}.pt")

    return best_val_score, final_test_score


def count_parameters(args):
    """Return the number of trainable parameters of the model built from *args*.

    ``Tensor.numel()`` always yields an exact int. The previous
    ``np.prod(p.size())`` returns the float ``1.0`` for 0-dim parameters,
    which would turn the whole sum into a float.
    """
    model = gen_model(args)
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def _build_arg_parser():
    """Return the command-line parser for this training script."""
    parser = argparse.ArgumentParser(
        "GAT implementation on ogbn-proteins",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add = parser.add_argument
    add(
        "--cpu",
        action="store_true",
        help="CPU mode. This option overrides '--gpu'.",
    )
    add("--gpu", type=int, default=0, help="GPU device ID")
    add("--seed", type=int, default=0, help="random seed")
    add("--n-runs", type=int, default=10, help="running times")
    add("--n-epochs", type=int, default=1200, help="number of epochs")
    add(
        "--use-labels",
        action="store_true",
        help="Use labels in the training set as input features.",
    )
    add("--no-attn-dst", action="store_true", help="Don't use attn_dst.")
    add("--n-heads", type=int, default=6, help="number of heads")
    add("--lr", type=float, default=0.01, help="learning rate")
    add("--n-layers", type=int, default=6, help="number of layers")
    add("--n-hidden", type=int, default=80, help="number of hidden units")
    add("--dropout", type=float, default=0.25, help="dropout rate")
    add("--input-drop", type=float, default=0.1, help="input drop rate")
    add("--attn-drop", type=float, default=0.0, help="attention dropout rate")
    add("--edge-drop", type=float, default=0.1, help="edge drop rate")
    add("--wd", type=float, default=0, help="weight decay")
    add(
        "--eval-every",
        type=int,
        default=5,
        help="evaluate every EVAL_EVERY epochs",
    )
    add("--log-every", type=int, default=5, help="log every LOG_EVERY epochs")
    add("--plot", action="store_true", help="plot learning curves")
    add("--save-pred", action="store_true", help="save final predictions")
    return parser


def main():
    """Script entry point: parse arguments, prepare the data and train.

    Trains ``args.n_runs`` independent models, seeding run ``i`` with
    ``args.seed + i``. Prints per-run and aggregate validation/test scores
    and the trainable-parameter count.
    """
    global device

    args = _build_arg_parser().parse_args()
    # --cpu takes precedence over --gpu.
    device = torch.device("cpu" if args.cpu else f"cuda:{args.gpu}")

    # load data & preprocess
    print("Loading data")
    graph, labels, train_idx, val_idx, test_idx, evaluator = load_data(dataset)
    print("Preprocessing")
    graph, labels = preprocess(graph, labels, train_idx)

    # Only labels and split indices move to the device here; the graph is
    # passed on as-is and run() builds its dataloaders from graph.cpu().
    labels, train_idx, val_idx, test_idx = (
        tensor.to(device) for tensor in (labels, train_idx, val_idx, test_idx)
    )

    run_results = []
    for i in range(args.n_runs):
        print("Running", i)
        seed(args.seed + i)
        run_results.append(
            run(
                args,
                graph,
                labels,
                train_idx,
                val_idx,
                test_idx,
                evaluator,
                i + 1,
            )
        )
    val_scores = [val for val, _ in run_results]
    test_scores = [test for _, test in run_results]

    print(" ".join(sys.argv))
    print(args)
    print(f"Runned {args.n_runs} times")
    print("Val scores:", val_scores)
    print("Test scores:", test_scores)
    print(f"Average val score: {np.mean(val_scores)} ± {np.std(val_scores)}")
    print(f"Average test score: {np.mean(test_scores)} ± {np.std(test_scores)}")
    print(f"Number of params: {count_parameters(args)}")


# Run the training script only when executed directly, not when imported.
if __name__ == "__main__":
    main()

# Namespace(attn_drop=0.0, cpu=False, dropout=0.25, edge_drop=0.1, eval_every=5, gpu=6, input_drop=0.1, log_every=5, lr=0.01, n_epochs=1200, n_heads=6, n_hidden=80, n_layers=6, n_runs=10, no_attn_dst=False, plot=True, save_pred=False, seed=0, use_labels=False, wd=0)
# Runned 10 times
# Val scores: [0.927741031859485, 0.9272113161947824, 0.9271363901359605, 0.9275579074100136, 0.9264291968462317, 0.9275278541203443, 0.9286381790529751, 0.9288245051991526, 0.9269289529175155, 0.9278177920224489]
# Test scores: [0.8754403567694566, 0.8749781870941457, 0.8735933245353141, 0.8759835445000637, 0.8745950242855286, 0.8742530369108132, 0.8784892022402326, 0.873345314887444, 0.8724393129004984, 0.874077975765639]
# Average val score: 0.927581312575891 ± 0.0006953509986591492
# Average test score: 0.8747195279889135 ± 0.001593598488797452
# Number of params: 2475232

# Namespace(attn_drop=0.0, cpu=False, dropout=0.25, edge_drop=0.1, eval_every=5, gpu=7, input_drop=0.1, log_every=5, lr=0.01, n_epochs=1200, n_heads=6, n_hidden=80, n_layers=6, n_runs=10, no_attn_dst=False, plot=True, save_pred=False, seed=0, use_labels=True, wd=0)
# Runned 10 times
# Val scores: [0.9293776332568928, 0.9281066322254939, 0.9286775378440911, 0.9270252685136046, 0.9267937838323375, 0.9277731792338011, 0.9285615428437761, 0.9270819730221879, 0.9276822010553241, 0.9287115722177839]
# Test scores: [0.8761623033485811, 0.8773002619440896, 0.8756680817047869, 0.8751873860287073, 0.875781797307807, 0.8764533839446703, 0.8771202308989311, 0.8765888651476396, 0.8773581283481205, 0.8777751912293709]
# Average val score: 0.9279791324045293 ± 0.0008115348697502517
# Average test score: 0.8765395629902706 ± 0.0008016806017700173
# Number of params: 2484192
