""" The main file to train a JKNet model using a full graph """

import argparse
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from model import JKNet
from sklearn.model_selection import train_test_split
from tqdm import trange

from dgl.data import CiteseerGraphDataset, CoraGraphDataset


def _accuracy(logits, labels, idx):
    """Return the fraction of rows selected by ``idx`` whose arg-max logit equals the label."""
    correct = torch.sum(logits[idx].argmax(dim=1) == labels[idx]).item()
    return correct / len(idx)


def main(args):
    """Train a JKNet model on a full graph and return its test accuracy.

    The nodes are randomly split into train / validation / test sets
    (6 : 2 : 2). The model is trained for ``args.epochs`` epochs, and the
    snapshot with the highest validation accuracy is evaluated on the test
    set.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``dataset`` ("Cora" or "Citeseer"), ``gpu``, ``hid_dim``,
        ``num_layers``, ``mode``, ``dropout``, ``lr``, ``lamb`` (weight decay)
        and ``epochs``.

    Returns
    -------
    float
        Accuracy of the best-on-validation model on the test nodes.

    Raises
    ------
    ValueError
        If ``args.dataset`` is not one of the supported dataset names.
    """
    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
    # Load from DGL dataset
    if args.dataset == "Cora":
        dataset = CoraGraphDataset()
    elif args.dataset == "Citeseer":
        dataset = CiteseerGraphDataset()
    else:
        raise ValueError("Dataset {} is invalid.".format(args.dataset))

    graph = dataset[0]

    # Use the requested GPU only when CUDA is actually available.
    device = (
        f"cuda:{args.gpu}"
        if args.gpu >= 0 and torch.cuda.is_available()
        else "cpu"
    )

    # retrieve the number of classes
    n_classes = dataset.num_classes

    # retrieve labels of ground truth
    labels = graph.ndata.pop("label").to(device).long()

    # Extract node features
    feats = graph.ndata.pop("feat").to(device)
    n_features = feats.shape[-1]

    # create index sets for train / validation / test
    # train : val : test = 6 : 2 : 2
    # (20% held out for test, then 25% of the remaining 80% for validation)
    n_nodes = graph.num_nodes()
    idx = torch.arange(n_nodes).to(device)
    train_idx, test_idx = train_test_split(idx, test_size=0.2)
    train_idx, val_idx = train_test_split(train_idx, test_size=0.25)

    graph = graph.to(device)

    # Step 2: Create model =================================================================== #
    model = JKNet(
        in_dim=n_features,
        hid_dim=args.hid_dim,
        out_dim=n_classes,
        num_layers=args.num_layers,
        mode=args.mode,
        dropout=args.dropout,
    ).to(device)

    # Fallback snapshot in case validation accuracy never rises above 0.
    best_model = copy.deepcopy(model)

    # Step 3: Create training components ===================================================== #
    loss_fn = nn.CrossEntropyLoss()
    opt = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.lamb)

    # Step 4: training epochs =============================================================== #
    acc = 0
    epochs = trange(args.epochs, desc="Accuracy & Loss")

    for _ in epochs:
        # Training using a full graph
        model.train()

        logits = model(graph, feats)

        # compute loss
        train_loss = loss_fn(logits[train_idx], labels[train_idx])
        train_acc = _accuracy(logits, labels, train_idx)

        # backward
        opt.zero_grad()
        train_loss.backward()
        opt.step()

        # Validation using a full graph
        model.eval()

        with torch.no_grad():
            # Re-run the forward pass: the training logits above were produced
            # in train mode and with the pre-update weights, so they do not
            # describe the model that may be snapshotted below.
            logits = model(graph, feats)
            valid_loss = loss_fn(logits[val_idx], labels[val_idx])
            valid_acc = _accuracy(logits, labels, val_idx)

        # Print out performance
        epochs.set_description(
            "Train Acc {:.4f} | Train Loss {:.4f} | Val Acc {:.4f} | Val loss {:.4f}".format(
                train_acc, train_loss.item(), valid_acc, valid_loss.item()
            )
        )

        # Keep a copy of the weights that achieved the best validation accuracy.
        if valid_acc > acc:
            acc = valid_acc
            best_model = copy.deepcopy(model)

    # Step 5: Evaluate the selected model on the held-out test nodes ========================= #
    best_model.eval()
    with torch.no_grad():  # inference only; no autograd graph is needed
        logits = best_model(graph, feats)
        test_acc = _accuracy(logits, labels, test_idx)

    print("Test Acc {:.4f}".format(test_acc))
    return test_acc


if __name__ == "__main__":
    # Command-line entry point: parse the JKNet hyperparameters, repeat the
    # experiment `--run` times, and report the mean / std of the test accuracy.
    parser = argparse.ArgumentParser(description="JKNet")
    add = parser.add_argument

    # data source params
    add("--dataset", type=str, default="Cora", help="Name of dataset.")

    # cuda params
    add("--gpu", type=int, default=-1, help="GPU index. Default: -1, using CPU.")

    # training params
    add("--run", type=int, default=10, help="Running times.")
    add("--epochs", type=int, default=500, help="Training epochs.")
    add("--lr", type=float, default=0.005, help="Learning rate.")
    add("--lamb", type=float, default=0.0005, help="L2 reg.")

    # model params
    add("--hid-dim", type=int, default=32, help="Hidden layer dimensionalities.")
    add("--num-layers", type=int, default=5, help="Number of GCN layers.")
    add(
        "--mode",
        type=str,
        default="cat",
        help="Type of aggregation.",
        choices=["cat", "max", "lstm"],
    )
    add(
        "--dropout",
        type=float,
        default=0.5,
        help="Dropout applied at all layers.",
    )

    args = parser.parse_args()
    print(args)

    # One independent training run (fresh split and fresh model) per repetition.
    acc_lists = [main(args) for _ in range(args.run)]

    # Summary statistics over the runs, rounded to three decimals.
    raw_mean = np.mean(acc_lists, axis=0)
    raw_std = np.std(acc_lists, axis=0)
    mean = np.around(raw_mean, decimals=3)
    std = np.around(raw_std, decimals=3)

    print("total acc: ", acc_lists)
    print("mean", mean)
    print("std", std)
