import argparse
import time

import torch
import torch.nn.functional as F
from dataset import EllipticDataset
from model import EvolveGCNH, EvolveGCNO
from utils import Measure

import dgl


def _labeled_predictions(
    model, cached_subgraph, cached_labeled_node_mask, time_index, time_window_size
):
    """Run ``model`` on the snapshot window ending at ``time_index``.

    The model receives snapshots ``time_index - time_window_size`` through
    ``time_index`` (inclusive). Its output is then filtered with the
    labeled-node mask of the last snapshot in that window.

    Returns:
        ``(predictions, labels)``: predictions for the labeled nodes of
        snapshot ``time_index``, and those nodes' labels cast to ``long``.
        (Assumes the model emits one row per node of the last snapshot —
        confirm against the model implementation.)
    """
    g_list = cached_subgraph[time_index - time_window_size : time_index + 1]
    predictions = model(g_list)
    mask = cached_labeled_node_mask[time_index]
    labels = cached_subgraph[time_index].ndata["label"][mask].long()
    return predictions[mask], labels


def _collect_epoch_measure(measure, epoch):
    """Aggregate ``measure`` for the epoch, record the best F1, then reset it.

    Returns:
        ``(precision, recall, f1)`` as reported by
        ``measure.get_total_measure()``.
    """
    precision, recall, f1 = measure.get_total_measure()
    measure.update_best_f1(f1, epoch)
    # reset accumulated predictions so the next epoch starts clean
    measure.reset_info()
    return precision, recall, f1


def train(args, device):
    """Train an EvolveGCN model on Elliptic with early stopping on valid F1.

    Snapshots 0-30 are used for training, 31-35 for validation and 36-48 for
    testing. The prediction for snapshot ``i`` uses the window of snapshots
    ``i - args.n_hist_steps`` .. ``i``. Whenever the validation measure
    reports the current epoch as its best, the test split is evaluated, and
    per-timestamp and aggregate metrics for ``args.eval_class_id`` are
    printed.

    Args:
        args: parsed command-line arguments.
        device: torch device holding the graphs, the model and the loss
            weights.

    Raises:
        NotImplementedError: if ``args.model`` is neither ``"EvolveGCN-O"``
            nor ``"EvolveGCN-H"``.
    """
    elliptic_dataset = EllipticDataset(
        raw_dir=args.raw_dir,
        processed_dir=args.processed_dir,
        self_loop=True,
        reverse_edge=True,
    )

    g, node_mask_by_time = elliptic_dataset.process()
    num_classes = elliptic_dataset.num_classes

    cached_subgraph = []
    cached_labeled_node_mask = []
    for node_mask in node_mask_by_time:
        # we add self loop edge when we construct full graph, not here
        node_subgraph = dgl.node_subgraph(graph=g, nodes=node_mask).to(device)
        cached_subgraph.append(node_subgraph)
        # Nodes with a negative label are excluded from loss and metrics.
        # Build the mask from the on-device graph so it lives next to the
        # predictions it will index.
        cached_labeled_node_mask.append(node_subgraph.ndata["label"] >= 0)

    in_feats = int(g.ndata["feat"].shape[1])
    if args.model == "EvolveGCN-O":
        model = EvolveGCNO(
            in_feats=in_feats,
            n_hidden=args.n_hidden,
            num_layers=args.n_layers,
        )
    elif args.model == "EvolveGCN-H":
        model = EvolveGCNH(in_feats=in_feats, num_layers=args.n_layers)
    else:
        # Raise (not return) so a misconfigured run fails loudly.
        raise NotImplementedError("Unsupported model {}".format(args.model))
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # split train, valid, test(0-30,31-35,36-48)
    # train/valid/test split follow the paper.
    train_max_index = 30
    valid_max_index = 35
    test_max_index = 48
    time_window_size = args.n_hist_steps
    loss_class_weight = torch.tensor(
        [float(w) for w in args.loss_class_weight.split(",")], device=device
    )

    train_measure, valid_measure, test_measure = (
        Measure(num_classes=num_classes, target_class=args.eval_class_id)
        for _ in range(3)
    )

    # f1 of the most recent test run, i.e. the one at the best-valid epoch
    test_res_f1 = 0
    for epoch in range(args.num_epochs):
        model.train()
        for i in range(time_window_size, train_max_index + 1):
            predictions, labels = _labeled_predictions(
                model, cached_subgraph, cached_labeled_node_mask, i, time_window_size
            )
            loss = F.cross_entropy(predictions, labels, weight=loss_class_weight)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # detach so a stored prediction cannot keep the autograd graph alive
            train_measure.append_measures(predictions.detach(), labels)

        # get each epoch measures during training.
        cl_precision, cl_recall, cl_f1 = _collect_epoch_measure(train_measure, epoch)
        print(
            "Train Epoch {} | class {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}".format(
                epoch, args.eval_class_id, cl_precision, cl_recall, cl_f1
            )
        )

        # eval: inference only, so skip building autograd graphs
        model.eval()
        with torch.no_grad():
            for i in range(train_max_index + 1, valid_max_index + 1):
                predictions, labels = _labeled_predictions(
                    model, cached_subgraph, cached_labeled_node_mask, i, time_window_size
                )
                valid_measure.append_measures(predictions, labels)

        # get each epoch measure during eval.
        cl_precision, cl_recall, cl_f1 = _collect_epoch_measure(valid_measure, epoch)
        print(
            "Eval Epoch {} | class {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}".format(
                epoch, args.eval_class_id, cl_precision, cl_recall, cl_f1
            )
        )

        # early stop
        if epoch - valid_measure.target_best_f1_epoch >= args.patience:
            print(
                "Best eval Epoch {}, Cur Epoch {}".format(
                    valid_measure.target_best_f1_epoch, epoch
                )
            )
            break

        # if cur valid f1 score is best, do test
        if epoch == valid_measure.target_best_f1_epoch:
            print(
                "###################Epoch {} Test###################".format(
                    epoch
                )
            )
            with torch.no_grad():
                for i in range(valid_max_index + 1, test_max_index + 1):
                    predictions, labels = _labeled_predictions(
                        model,
                        cached_subgraph,
                        cached_labeled_node_mask,
                        i,
                        time_window_size,
                    )
                    test_measure.append_measures(predictions, labels)

            # we get each subgraph measure when testing to match fig 4 in EvolveGCN paper.
            (
                cl_precisions,
                cl_recalls,
                cl_f1s,
            ) = test_measure.get_each_timestamp_measure()
            for index, (sub_p, sub_r, sub_f1) in enumerate(
                zip(cl_precisions, cl_recalls, cl_f1s)
            ):
                # snapshot index is valid_max_index + 1 + index; the extra +1
                # presumably reports 1-based time steps as in the paper.
                print(
                    "  Test | Time {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}".format(
                        valid_max_index + index + 2, sub_p, sub_r, sub_f1
                    )
                )

            # get each epoch measure during test.
            cl_precision, cl_recall, cl_f1 = _collect_epoch_measure(test_measure, epoch)
            test_res_f1 = cl_f1

            print(
                "  Test | Epoch {} | class {} | precision:{:.4f} | recall: {:.4f} | f1: {:.4f}".format(
                    epoch, args.eval_class_id, cl_precision, cl_recall, cl_f1
                )
            )

    print(
        "Best test f1 is {}, in Epoch {}".format(
            test_measure.target_best_f1, test_measure.target_best_f1_epoch
        )
    )
    if test_measure.target_best_f1_epoch != valid_measure.target_best_f1_epoch:
        print(
            "The Epoch get best Valid measure not get the best Test measure, "
            "please checkout the test result in Epoch {}, which f1 is {}".format(
                valid_measure.target_best_f1_epoch, test_res_f1
            )
        )


if __name__ == "__main__":
    argparser = argparse.ArgumentParser("EvolveGCN")
    argparser.add_argument(
        "--model",
        type=str,
        default="EvolveGCN-O",
        help="We can choose EvolveGCN-O or EvolveGCN-H,"
        "but the EvolveGCN-H performance on Elliptic dataset is not good.",
    )
    argparser.add_argument(
        "--raw-dir",
        type=str,
        default="/home/Elliptic/elliptic_bitcoin_dataset/",
        help="Dir after unzip downloaded dataset, which contains 3 csv files.",
    )
    argparser.add_argument(
        "--processed-dir",
        type=str,
        default="/home/Elliptic/processed/",
        help="Dir to store processed raw data.",
    )
    argparser.add_argument(
        "--gpu",
        type=int,
        default=0,
        help="GPU device ID. Use -1 for CPU training.",
    )
    argparser.add_argument("--num-epochs", type=int, default=1000)
    argparser.add_argument("--n-hidden", type=int, default=256)
    argparser.add_argument("--n-layers", type=int, default=2)
    argparser.add_argument(
        "--n-hist-steps",
        type=int,
        default=5,
        help="If it is set to 5, it means in the first batch,"
        "we use historical data of 0-4 to predict the data of time 5.",
    )
    argparser.add_argument("--lr", type=float, default=0.001)
    argparser.add_argument(
        "--loss-class-weight",
        type=str,
        default="0.35,0.65",
        help="Weight for loss function. Follow the official code,"
        "we need to change it to 0.25, 0.75 when use EvolveGCN-H",
    )
    argparser.add_argument(
        "--eval-class-id",
        type=int,
        default=1,
        help="Class type to eval. On Elliptic, type 1(illicit) is the main interest.",
    )
    argparser.add_argument(
        "--patience", type=int, default=100, help="Patience for early stopping."
    )

    args = argparser.parse_args()

    if args.gpu >= 0:
        device = torch.device("cuda:%d" % args.gpu)
    else:
        device = torch.device("cpu")

    start_time = time.perf_counter()
    train(args, device)
    print("train time is: {}".format(time.perf_counter() - start_time))
