import logging
import importlib
import os
import time

import torch

from args import parse_args

from utils.training_utils import get_test_acc, analyze_acc, get_optimizer_scheduler
from utils.general_utils import parse_configs_file, setup_seed, save_checkpoint, get_exp_name, get_data_per_epoch, plot_trajectory
import datasets
import models


def _setup_logger(log_dir):
    """Configure the root logger and attach an appending ``setup.log`` file handler.

    Args:
        log_dir: Existing directory in which ``setup.log`` is created/appended.

    Returns:
        The root logger. Records from child loggers that propagate to root
        are therefore also written to the file.
    """
    logging.basicConfig(level=logging.INFO, format="%(message)s")
    logger = logging.getLogger()
    file_handler = logging.FileHandler(
        os.path.join(log_dir, "setup.log"), "a", encoding="utf-8"
    )
    # basicConfig only sets the format of its own console handler; make the
    # file handler's (message-only) format explicit so both outputs match.
    file_handler.setFormatter(logging.Formatter("%(message)s"))
    logger.addHandler(file_handler)
    return logger


def _select_device(args):
    """Return the torch device to run on.

    Uses ``cuda:<first id in args.gpu>`` when CUDA is available and not
    disabled via ``args.no_cuda``; otherwise returns the CPU device.
    ``args.gpu`` is expected to be a comma-separated list of integer ids.
    """
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    if not use_cuda:
        # Don't parse args.gpu on CPU runs: an empty or placeholder value
        # would otherwise raise ValueError even though no GPU is used.
        return torch.device("cpu")
    gpu_list = [int(i) for i in args.gpu.strip().split(",")]
    return torch.device(f"cuda:{gpu_list[0]}")


def main():
    """Run a full training experiment configured from the CLI / config file.

    For every epoch this runs the selected trainer, evaluates on the test
    loader, logs the epoch duration and the best epoch so far, saves the
    trajectory (on a new best), the per-epoch data and plots, and optionally
    a checkpoint. The "best" epoch is the one with the smallest
    ``acc[0] - acc[1]`` as returned by ``analyze_acc``.
    """
    args = parse_args()
    parse_configs_file(args)

    setup_seed(args.seed)

    # Output locations: a per-experiment result folder plus a shared graph folder.
    exp_name = get_exp_name(args)
    result_sub_dir = os.path.join(args.result_dir, exp_name)
    graph_dir = os.path.join(args.result_dir, "graphs")
    os.makedirs(result_sub_dir, exist_ok=True)
    os.makedirs(graph_dir, exist_ok=True)

    logger = _setup_logger(result_sub_dir)
    logger.info(args)

    device = _select_device(args)

    # The "BLO" trainer gets one env per entry of args.training_env;
    # every other trainer uses a single env.
    env_num = len(args.training_env) if args.trainer == "BLO" else 1

    model = models.__dict__[args.arch](
        env_num=env_num,
        use_color=len(args.training_color_env) > 0,
    )
    model.load_device(device)

    # Dataloaders (the validation loader is currently unused).
    dataset = datasets.__dict__[args.dataset](args)
    train_loader, _val_loader, test_loader = dataset.data_loaders()

    # The trainer module is chosen by name, e.g. trainer/<args.trainer>.py.
    trainer = importlib.import_module(f"trainer.{args.trainer}").train

    optimizer, scheduler = get_optimizer_scheduler(model, args)

    best_epoch = 0
    # +inf guarantees the first evaluated epoch becomes the best one (and
    # trajectory.pt gets written), whatever scale analyze_acc reports in.
    best_diff = float("inf")
    best_acc = 0.0
    # No resume logic here: training always starts from epoch 0.
    args.start_epoch = 0
    # Filled by get_data_per_epoch; presumably one series per tracked
    # metric — TODO confirm why there are exactly 10.
    data_per_epoch = [[] for _ in range(10)]
    plot_paths = [
        os.path.join(graph_dir, f"{exp_name}.png"),
        os.path.join(result_sub_dir, f"{exp_name}.png"),
    ]

    for epoch in range(args.start_epoch, args.epochs):
        # perf_counter is monotonic, unlike time.time(), so it suits durations.
        start_time = time.perf_counter()

        train_stat = trainer(
            model, args, device, train_loader, optimizer, scheduler, epoch
        )

        test_accuracy = get_test_acc(model, test_loader, device)
        acc = analyze_acc(test_accuracy)

        # Duration covers both training and test evaluation.
        logger.info("This epoch duration :%ss", time.perf_counter() - start_time)

        diff = acc[0] - acc[1]
        is_best = diff < best_diff
        if is_best:
            best_acc = acc[-1]
            best_diff = diff
            best_epoch = epoch

        logger.info(
            "For epoch %d, the best acc is %.4f and the diff acc is %.4f at epoch %d",
            epoch, best_acc, best_diff, best_epoch,
        )

        if is_best:
            torch.save(
                {
                    "training_stat": train_stat,
                    "acc": acc,
                    "all_res": test_accuracy,
                },
                os.path.join(result_sub_dir, "trajectory.pt"),
            )

        get_data_per_epoch(data_per_epoch, train_stat, acc)

        plot_trajectory(data_per_epoch, plot_paths)

        # Overwritten every epoch so the latest history is always on disk.
        torch.save(data_per_epoch, os.path.join(result_sub_dir, "data.pth"))

        if args.save:
            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": args.arch,
                    "state_dict": model.state_dict(),
                    "best_acc": best_acc,
                    "diff_acc": best_diff,
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                },
                is_best,
                result_dir=os.path.join(result_sub_dir, "checkpoint"),
            )


if __name__ == "__main__":
    # Run the training experiment only when executed as a script, not on import.
    main()

