import argparse

import torch as th
from evaluate_embedding import evaluate_embedding
from model import InfoGraph

import dgl
from dgl.data import GINDataset
from dgl.dataloading import GraphDataLoader


def argument():
    """Parse the command-line options of the InfoGraph training script.

    Returns
    -------
    argparse.Namespace
        The parsed options, plus a ``device`` attribute: ``"cuda:<gpu>"``
        when a GPU index other than -1 was given and CUDA is available,
        ``"cpu"`` otherwise.
    """
    # (flag, type, default, help) -- kept in the order shown by ``--help``
    option_table = (
        # data source params
        ("--dataname", str, "MUTAG", "Name of dataset."),
        # training params
        ("--gpu", int, -1, "GPU index, default:-1, using CPU."),
        ("--epochs", int, 20, "Training epochs."),
        ("--batch_size", int, 128, "Training batch size."),
        ("--lr", float, 0.01, "Learning rate."),
        ("--log_interval", int, 1, "Interval between two evaluations."),
        # model params
        (
            "--n_layers",
            int,
            3,
            "Number of graph convolution layers before each pooling.",
        ),
        ("--hid_dim", int, 32, "Hidden layer dimensionalities."),
    )

    cli = argparse.ArgumentParser(description="InfoGraph")
    for flag, value_type, fallback, description in option_table:
        cli.add_argument(
            flag, type=value_type, default=fallback, help=description
        )

    args = cli.parse_args()

    # fall back to CPU when no GPU was requested or CUDA is unavailable
    cuda_requested = args.gpu != -1 and th.cuda.is_available()
    args.device = f"cuda:{args.gpu}" if cuda_requested else "cpu"

    return args


def collate(samples):
    """Collate ``(graph, label)`` samples for the graph dataloader.

    The graphs are merged with ``dgl.batch`` and the labels are packed into
    one tensor. Each node of the merged graph additionally receives a
    ``"graph_id"`` node feature, obtained by broadcasting the per-graph
    indices ``0 .. len(samples) - 1`` onto the nodes.

    Returns
    -------
    tuple
        ``(batched_graph, batched_labels)``.
    """
    graph_seq, label_seq = zip(*samples)
    graph_list = list(graph_seq)

    merged = dgl.batch(graph_list)
    label_tensor = th.tensor(list(label_seq))

    # one index per input graph, spread over that graph's nodes
    per_graph_index = th.arange(len(graph_list))
    merged.ndata["graph_id"] = dgl.broadcast_nodes(merged, per_graph_index)

    return merged, label_tensor


def _evaluate(model, graph, feat, labels, device):
    """Embed ``graph`` with ``model`` and score the resulting embeddings.

    The model is put into eval mode and autograd is disabled for the
    embedding pass, so every evaluation (including the one made before
    training) runs under the same conditions.

    Returns the result of ``evaluate_embedding``; callers read index 0 as
    the logistic-regression score and index 1 as the SVC score.
    """
    model.eval()
    with th.no_grad():
        emb = model.get_embedding(graph, feat).cpu()
    return evaluate_embedding(emb, labels, device)


def main():
    """Train InfoGraph on a GINDataset and report embedding quality.

    Reads the options from ``argument()``, evaluates the untrained model
    once, then trains for ``args.epochs`` epochs, re-evaluating the
    embeddings of the whole dataset every ``args.log_interval`` epochs and
    tracking the best logistic-regression and SVC scores seen.
    """
    # Step 1: Prepare graph data   ===================================== #
    args = argument()
    print(args)

    # load dataset from dgl.data.GINDataset
    dataset = GINDataset(args.dataname, self_loop=False)

    # get graphs and labels
    graphs, labels = map(list, zip(*dataset))

    # generate a full-graph with all examples for evaluation
    wholegraph = dgl.batch(graphs)
    wholegraph.ndata["attr"] = wholegraph.ndata["attr"].to(th.float32)

    # create dataloader for batch training
    dataloader = GraphDataLoader(
        dataset,
        batch_size=args.batch_size,
        collate_fn=collate,
        drop_last=False,
        shuffle=True,
    )

    in_dim = wholegraph.ndata["attr"].shape[1]

    # Step 2: Create model =================================================================== #
    model = InfoGraph(in_dim, args.hid_dim, args.n_layers)
    model = model.to(args.device)

    # Step 3: Create training components ===================================================== #
    optimizer = th.optim.Adam(model.parameters(), lr=args.lr)

    print("===== Before training ======")

    wholegraph = wholegraph.to(args.device)
    wholefeat = wholegraph.ndata["attr"]

    # Evaluate the initialized embeddings
    # using logistic regression and SVM(non-linear)
    res = _evaluate(model, wholegraph, wholefeat, labels, args.device)
    print("logreg {:.4f}, svc {:.4f}".format(res[0], res[1]))

    best_logreg = 0
    best_logreg_epoch = 0
    best_svc = 0
    best_svc_epoch = 0

    # Step 4: training epochs =============================================================== #
    for epoch in range(args.epochs):
        loss_all = 0
        # evaluation leaves the model in eval mode; switch back every epoch
        model.train()

        # the label half of each batch is not used by the training loss
        for graph, _ in dataloader:
            graph = graph.to(args.device)
            # same dtype as the features used for evaluation
            feat = graph.ndata["attr"].to(th.float32)
            graph_id = graph.ndata["graph_id"]

            optimizer.zero_grad()
            loss = model(graph, feat, graph_id)
            loss.backward()
            optimizer.step()
            loss_all += loss.item()

        print("Epoch {}, Loss {:.4f}".format(epoch, loss_all))

        if epoch % args.log_interval == 0:
            # evaluate embeddings
            res = _evaluate(model, wholegraph, wholefeat, labels, args.device)

            if res[0] > best_logreg:
                best_logreg = res[0]
                best_logreg_epoch = epoch

            if res[1] > best_svc:
                best_svc = res[1]
                best_svc_epoch = epoch

            print(
                "best logreg {:.4f}, epoch {} | best svc: {:.4f}, epoch {}".format(
                    best_logreg, best_logreg_epoch, best_svc, best_svc_epoch
                )
            )

    print("Training End")
    print("best logreg {:.4f} ,best svc {:.4f}".format(best_logreg, best_svc))


if __name__ == "__main__":
    main()
