import argparse
import os
import random
import time

import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from data_utils import pre_process
from model.encoder import DiffPool

import dgl
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import tu

# Wall-clock duration (seconds) of each training epoch; train() appends one
# entry per epoch and main() reports their average.
global_train_time_per_epoch = []


def arg_parse(args=None):
    """
    Build the DiffPool command-line parser and parse arguments.

    :param args: optional list of argument strings to parse instead of
        ``sys.argv[1:]`` (which argparse uses when this is ``None``).
    :return: ``argparse.Namespace`` with every option populated; defaults
        come from the ``parser.set_defaults`` call below.
    """
    parser = argparse.ArgumentParser(description="DiffPool arguments")
    parser.add_argument("--dataset", dest="dataset", help="Input Dataset")
    parser.add_argument(
        "--pool_ratio", dest="pool_ratio", type=float, help="pooling ratio"
    )
    parser.add_argument(
        "--num_pool", dest="num_pool", type=int, help="number of pooling layers"
    )
    parser.add_argument(
        "--no_link_pred",
        dest="linkpred",
        action="store_false",
        help="disable the link prediction objective",
    )
    parser.add_argument(
        "--cuda", dest="cuda", type=int, help="use CUDA when nonzero"
    )
    parser.add_argument("--lr", dest="lr", type=float, help="learning rate")
    parser.add_argument(
        "--clip", dest="clip", type=float, help="gradient clipping"
    )
    parser.add_argument(
        "--batch-size", dest="batch_size", type=int, help="batch size"
    )
    parser.add_argument("--epochs", dest="epoch", type=int, help="num-of-epoch")
    parser.add_argument(
        "--train-ratio",
        dest="train_ratio",
        type=float,
        help="ratio of training dataset split",
    )
    parser.add_argument(
        "--test-ratio",
        dest="test_ratio",
        type=float,
        help="ratio of testing dataset split",
    )
    parser.add_argument(
        "--num_workers",
        dest="n_worker",
        type=int,
        help="number of workers when dataloading",
    )
    parser.add_argument(
        "--gc-per-block",
        dest="gc_per_block",
        type=int,
        help="number of graph conv layer per block",
    )
    # NOTE: const and default are both True, so --bn (and --bias below) has
    # no effect and cannot switch the option off from the command line.
    parser.add_argument(
        "--bn",
        dest="bn",
        action="store_const",
        const=True,
        default=True,
        help="switch for bn",
    )
    parser.add_argument(
        "--dropout", dest="dropout", type=float, help="dropout rate"
    )
    parser.add_argument(
        "--bias",
        dest="bias",
        action="store_const",
        const=True,
        default=True,
        help="switch for bias",
    )
    parser.add_argument(
        "--save_dir",
        dest="save_dir",
        help="model saving directory: SAVE_DIR/DATASET",
    )
    parser.add_argument(
        "--load_epoch",
        dest="load_epoch",
        type=int,
        help="load trained model params from "
        "SAVE_DIR/DATASET/model.iter-LOAD_EPOCH",
    )
    parser.add_argument(
        "--data_mode",
        dest="data_mode",
        help="data preprocessing mode: default, id, deg (degree), or "
        "deg_num (one-hot vector of degree number)",
        choices=["default", "id", "deg", "deg_num"],
    )

    parser.set_defaults(
        dataset="ENZYMES",
        pool_ratio=0.15,
        num_pool=1,
        cuda=1,
        lr=1e-3,
        clip=2.0,
        batch_size=20,
        epoch=4000,
        train_ratio=0.7,
        test_ratio=0.1,
        n_worker=1,
        gc_per_block=3,
        dropout=0.0,
        method="diffpool",
        bn=True,
        bias=True,
        save_dir="./model_param",
        load_epoch=-1,
        data_mode="default",
    )
    return parser.parse_args(args)


def prepare_data(dataset, prog_args, train=False, pre_process=None):
    """
    Optionally preprocess a dataset and wrap it in a DGL graph dataloader.

    :param dataset: graph dataset to load (in this script, one split produced
        by ``random_split``).
    :param prog_args: parsed CLI arguments; ``batch_size`` and ``n_worker``
        configure the loader.
    :param train: when truthy the loader shuffles samples; otherwise order is
        preserved.
    :param pre_process: optional callable invoked as
        ``pre_process(dataset, prog_args)`` before the loader is built; it is
        skipped when falsy. (This parameter shadows the module-level
        ``pre_process`` import inside this function.)
    :return: a ``dgl.dataloading.GraphDataLoader`` over ``dataset``.
    """
    if pre_process:
        pre_process(dataset, prog_args)

    loader_options = dict(
        batch_size=prog_args.batch_size,
        shuffle=True if train else False,
        num_workers=prog_args.n_worker,
    )
    return dgl.dataloading.GraphDataLoader(dataset, **loader_options)


def graph_classify_task(prog_args):
    """
    Run the full DiffPool graph-classification pipeline.

    Loads the TU dataset named by ``prog_args.dataset`` and splits it randomly
    into train/validation/test subsets using ``train_ratio`` and
    ``test_ratio``; validation receives the remainder. It then builds the
    DiffPool model, optionally restoring a saved checkpoint. The model is
    trained with validation-based checkpointing, and the test accuracy of the
    selected checkpoint is printed.

    :param prog_args: namespace returned by ``arg_parse``.
    """
    dataset = tu.LegacyTUDataset(name=prog_args.dataset)
    train_size = int(prog_args.train_ratio * len(dataset))
    test_size = int(prog_args.test_ratio * len(dataset))
    # Validation takes whatever is left so the three sizes sum to
    # len(dataset), as random_split requires.
    val_size = len(dataset) - train_size - test_size

    dataset_train, dataset_val, dataset_test = torch.utils.data.random_split(
        dataset, (train_size, val_size, test_size)
    )
    train_dataloader = prepare_data(
        dataset_train, prog_args, train=True, pre_process=pre_process
    )
    val_dataloader = prepare_data(
        dataset_val, prog_args, train=False, pre_process=pre_process
    )
    test_dataloader = prepare_data(
        dataset_test, prog_args, train=False, pre_process=pre_process
    )
    input_dim, label_dim, max_num_node = dataset.statistics()
    print("++++++++++STATISTICS ABOUT THE DATASET")
    print("dataset feature dimension is", input_dim)
    print("dataset label dimension is", label_dim)
    print("the max num node is", max_num_node)
    print("number of graphs is", len(dataset))

    hidden_dim = 64
    embedding_dim = 64

    # Assignment dimension: pool_ratio times the node count of the largest
    # graph in the dataset.
    assign_dim = int(max_num_node * prog_args.pool_ratio)
    print("++++++++++MODEL STATISTICS++++++++")
    print("model hidden dim is", hidden_dim)
    print("model embedding dim for graph instance embedding", embedding_dim)
    print("initial batched pool graph dim is", assign_dim)
    activation = F.relu

    model = DiffPool(
        input_dim,
        hidden_dim,
        embedding_dim,
        label_dim,
        activation,
        prog_args.gc_per_block,
        prog_args.dropout,
        prog_args.num_pool,
        prog_args.linkpred,
        prog_args.batch_size,
        "meanpool",
        assign_dim,
        prog_args.pool_ratio,
    )

    if prog_args.load_epoch >= 0 and prog_args.save_dir is not None:
        ckpt_path = (
            prog_args.save_dir
            + "/"
            + prog_args.dataset
            + "/model.iter-"
            + str(prog_args.load_epoch)
        )
        # Map to CPU so a checkpoint saved from a GPU run can be restored on
        # any machine; load_state_dict copies into the model's own tensors.
        model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))

    print("model init finished")
    print("MODEL:::::::", prog_args.method)
    # Only move to GPU when CUDA actually exists; model.cuda() raises on
    # CPU-only builds even though --cuda defaults to 1.
    use_cuda = bool(prog_args.cuda) and torch.cuda.is_available()
    if prog_args.cuda and not use_cuda:
        print("CUDA requested but not available; running on CPU")
    if use_cuda:
        model = model.cuda()

    logger = train(
        train_dataloader, model, prog_args, val_dataset=val_dataloader
    )
    result = evaluate(test_dataloader, model, prog_args, logger)
    print("test  accuracy {:.2f}%".format(result * 100))


def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
    """
    Train ``model`` with Adam, checkpointing on validation accuracy.

    :param dataset: training dataloader yielding ``(batch_graph, graph_labels)``
        pairs (despite the name, this is an iterable loader).
    :param model: graph classifier called as ``model(batch_graph)`` and
        providing ``model.loss(ypred, labels)``.
    :param prog_args: parsed CLI arguments; reads ``save_dir``, ``dataset``,
        ``lr``, ``cuda``, ``epoch`` and ``clip``, and is forwarded to
        ``evaluate``.
    :param same_feat: unused; kept for backward compatibility.
    :param val_dataset: optional validation dataloader. When given, the model
        is evaluated after every epoch. Its state dict is saved to
        ``SAVE_DIR/DATASET/model.iter-EPOCH`` when validation accuracy is at
        least the best so far and no higher than that epoch's training
        accuracy.
    :return: dict with ``best_epoch`` (-1 if no epoch was selected) and
        ``val_acc`` of the selected epoch.
    """
    save_root = None
    if prog_args.save_dir is not None:
        save_root = prog_args.save_dir + "/" + prog_args.dataset
        os.makedirs(save_root, exist_ok=True)
    dataloader = dataset
    # Honour --lr (previously hard-coded to 0.001, which is also its default).
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()), lr=prog_args.lr
    )
    early_stopping_logger = {"best_epoch": -1, "val_acc": -1}

    if prog_args.cuda > 0 and torch.cuda.is_available():
        torch.cuda.set_device(0)
    # Send every batch to the device the model lives on, so data placement
    # can never disagree with where --cuda put the model.
    first_param = next(model.parameters(), None)
    device = (
        first_param.device if first_param is not None else torch.device("cpu")
    )

    for epoch in range(prog_args.epoch):
        begin_time = time.time()
        model.train()
        accum_correct = 0
        total = 0
        loss = None
        print("\nEPOCH ###### {} ######".format(epoch))
        computation_time = 0.0
        for batch_graph, graph_labels in dataloader:
            # Cast every node-feature tensor to float before the forward pass.
            for key, value in list(batch_graph.ndata.items()):
                batch_graph.ndata[key] = value.float()
            batch_graph = batch_graph.to(device)
            graph_labels = graph_labels.long().to(device)

            model.zero_grad()
            compute_start = time.time()
            ypred = model(batch_graph)
            indi = torch.argmax(ypred, dim=1)
            accum_correct += torch.sum(indi == graph_labels).item()
            total += graph_labels.size(0)
            loss = model.loss(ypred, graph_labels)
            loss.backward()
            # Timed span covers forward + loss + backward only.
            computation_time += time.time() - compute_start
            nn.utils.clip_grad_norm_(model.parameters(), prog_args.clip)
            optimizer.step()

        # An empty loader yields no batches: avoid ZeroDivisionError and an
        # unbound ``loss``.
        train_accu = accum_correct / total if total else 0.0
        print(
            "train accuracy for this epoch {} is {:.2f}%".format(
                epoch, train_accu * 100
            )
        )
        elapsed_time = time.time() - begin_time
        last_loss = loss.item() if loss is not None else float("nan")
        print(
            "loss {:.4f} with epoch time {:.4f} s & computation time {:.4f} s ".format(
                last_loss, elapsed_time, computation_time
            )
        )
        global_train_time_per_epoch.append(elapsed_time)
        if val_dataset is not None:
            result = evaluate(val_dataset, model, prog_args)
            print("validation  accuracy {:.2f}%".format(result * 100))
            # Select the latest epoch whose validation accuracy matches or
            # beats the best so far, ignoring epochs where validation exceeds
            # training accuracy.
            if (
                result >= early_stopping_logger["val_acc"]
                and result <= train_accu
            ):
                early_stopping_logger.update(best_epoch=epoch, val_acc=result)
                if save_root is not None:
                    torch.save(
                        model.state_dict(),
                        save_root + "/model.iter-" + str(epoch),
                    )
            print(
                "best epoch is EPOCH {}, val_acc is {:.2f}%".format(
                    early_stopping_logger["best_epoch"],
                    early_stopping_logger["val_acc"] * 100,
                )
            )
        torch.cuda.empty_cache()
    return early_stopping_logger


def evaluate(dataloader, model, prog_args, logger=None):
    """
    Compute the classification accuracy of ``model`` over ``dataloader``.

    :param dataloader: iterable yielding ``(batch_graph, graph_labels)`` pairs.
    :param model: graph classifier; ``model(batch_graph)`` must return
        per-graph class scores with the class dimension on axis 1.
    :param prog_args: parsed CLI arguments; ``save_dir`` and ``dataset`` are
        read when restoring a checkpoint.
    :param logger: optional early-stopping record returned by ``train``. When
        given, ``save_dir`` is set and ``best_epoch >= 0``, the checkpoint
        ``SAVE_DIR/DATASET/model.iter-BEST_EPOCH`` is loaded into ``model``
        first. No checkpoint exists when ``best_epoch`` is -1, so loading is
        skipped in that case.
    :return: fraction of correctly classified graphs (0.0 for an empty loader).
    """
    # Evaluate on whatever device the model lives on.
    first_param = next(model.parameters(), None)
    device = (
        first_param.device if first_param is not None else torch.device("cpu")
    )

    if (
        logger is not None
        and prog_args.save_dir is not None
        and logger.get("best_epoch", -1) >= 0
    ):
        ckpt_path = (
            prog_args.save_dir
            + "/"
            + prog_args.dataset
            + "/model.iter-"
            + str(logger["best_epoch"])
        )
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.eval()
    correct_label = 0
    num_graphs = 0
    with torch.no_grad():
        for batch_graph, graph_labels in dataloader:
            for key, value in list(batch_graph.ndata.items()):
                batch_graph.ndata[key] = value.float()
            batch_graph = batch_graph.to(device)
            graph_labels = graph_labels.long().to(device)
            ypred = model(batch_graph)
            indi = torch.argmax(ypred, dim=1)
            correct_label += torch.sum(indi == graph_labels).item()
            num_graphs += graph_labels.size(0)
    # Divide by the number of graphs actually seen: the final batch may be
    # smaller than batch_size, so len(dataloader) * batch_size over-counts.
    return correct_label / num_graphs if num_graphs else 0.0


def main():
    """
    Script entry point.

    Parses CLI arguments and runs the graph-classification task. It then
    prints the average training time per epoch and the peak CUDA memory
    allocated on device 0, in MiB.
    """
    prog_args = arg_parse()
    print(prog_args)
    graph_classify_task(prog_args)

    # No epochs may have been recorded (e.g. --epochs 0); avoid dividing by 0.
    if global_train_time_per_epoch:
        print(
            "Train time per epoch: {:.4f}".format(
                sum(global_train_time_per_epoch)
                / len(global_train_time_per_epoch)
            )
        )
    else:
        print("Train time per epoch: n/a")
    print(
        "Max memory usage: {:.4f}".format(
            torch.cuda.max_memory_allocated(0) / (1024 * 1024)
        )
    )


# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
