import argparse
import datetime
import json
import os
import subprocess
import sys
from collections import OrderedDict

# Put the current working directory on the import path (at index 1).
# NOTE(review): presumably so the project-local packages imported below
# (dataloaders, trainers, utils) resolve when launched from the repo root — confirm.
sys.path.insert(1, os.getcwd())
import datasets
import torch
from torch.utils.tensorboard import SummaryWriter

from dataloaders import P3CLDataModule, BBCLDataModule
from dataloaders.constants import P3_DATASET_CONFIGS, TAG2TASK_LIST, BB_DATASET_CONFIGS
from trainers.evaluator import Evaluator
from trainers.interface_mixin import InterfaceMixin
from utils.config import Config, ParseKwargs
from utils.get_model import hf_tokenizer, load_continual_learning_model
from utils.util import get_logger, setup_wandb_logger

# Turn off the `datasets` library's progress bars for the whole process
# (executed once, at import time of this script).
datasets.disable_progress_bar()


def main(config, loggers):
    """Evaluate a continual-learning model on the configured tasks and dump the metrics.

    The function loads the model, builds the matching data module (BigBench or
    P3), runs ``Evaluator.eval_all`` on ``config.eval_split`` and writes
    ``metrics.json`` (tasks, results, config and per-dataset configs) into
    ``config.run_output_dir``.

    Args:
        config: Run configuration. Attributes read here: ``lora_hub_eval``,
            ``moe_inference``, ``dataset``, ``save_router_state_dict``,
            ``checkpoint_dir``, ``origin_model``, ``model_type``,
            ``eval_split`` and ``run_output_dir``. ``config.dataset`` is
            replaced by the expanded task list when it holds a single tag
            found in ``TAG2TASK_LIST``.
        loggers: Mapping with a ``"logger"`` entry (used for ``.info``) and an
            optional ``"wandb"`` entry (used for ``.log``).

    Raises:
        ValueError: If both ``config.lora_hub_eval`` and
            ``config.moe_inference`` are set.
    """
    start = datetime.datetime.now()
    if config.lora_hub_eval and config.moe_inference:
        raise ValueError("LoRA Hub evaluation is not supported for MOE inference.")

    model_load_results = load_continual_learning_model(config)
    model = model_load_results["model"]
    lora_module_path_list = model_load_results["path_list"]

    # The benchmark family is derived from the first dataset entry *before*
    # a tag is expanded into its task list below.
    first_dataset = config.dataset[0]
    is_bigbench = "bigbench" in first_dataset or "bb" in first_dataset

    if config.save_router_state_dict:
        torch.save(
            model.router_weight_state_dict(),
            os.path.join(config.checkpoint_dir, "router_state_dict.pt"),
        )
        print("Router weight state dict saved.")

    tokenizer = hf_tokenizer(config.origin_model)
    model.interface = InterfaceMixin(model_type=config.model_type)

    # A single known tag stands for a whole list of tasks.
    if len(config.dataset) == 1 and config.dataset[0] in TAG2TASK_LIST:
        config.dataset = TAG2TASK_LIST[config.dataset[0]]

    all_tasks = config.dataset
    dump_dict = OrderedDict()
    dump_dict["tasks"] = all_tasks

    datamodule_cls = BBCLDataModule if is_bigbench else P3CLDataModule
    datamodule = datamodule_cls(
        config, tokenizer, loggers, is_moe=config.moe_inference, stage=config.eval_split
    )

    final_evaluator = Evaluator(
        config=config,
        eval_tasks=all_tasks,
        tokenizer=tokenizer,
        datamodule=datamodule,
        loggers=loggers,
    )

    results = final_evaluator.eval_all(
        model,
        split=config.eval_split,
        lora_hub_eval=config.lora_hub_eval,
        lora_hub_module_path_list=lora_module_path_list,
    )

    dump_dict[f"{config.eval_split}_results"] = results

    loggers["logger"].info(results)
    if "wandb" in loggers:
        # NOTE(review): float metrics are sent as 4-decimal formatted strings,
        # not numbers — kept as-is; confirm this is intended.
        loggers["wandb"].log(
            {
                f"{task}_test_{m}": f"{v:.4f}"
                for task, metrics in results.items()
                for m, v in metrics.items()
                if isinstance(v, float)
            }
        )

    # Build a filtered copy instead of deleting from ``vars(config)``:
    # ``vars`` returns the live ``__dict__``, so ``del`` would strip
    # ``config.device`` from the caller's object. The device entry is left out
    # of the JSON dump.
    config_dict = {key: value for key, value in vars(config).items() if key != "device"}
    dump_dict["config"] = config_dict

    dataset_config_table = BB_DATASET_CONFIGS if is_bigbench else P3_DATASET_CONFIGS
    dump_dict["dataset_configs"] = OrderedDict(
        (data_tag, dataset_config_table[data_tag]) for data_tag in all_tasks
    )

    with open(os.path.join(config.run_output_dir, "metrics.json"), "w", encoding="utf-8") as file:
        json.dump(dump_dict, file, indent=4)

    # Report whole hours plus the *remaining* minutes (not the total minutes).
    total_minutes = (datetime.datetime.now() - start) // datetime.timedelta(minutes=1)
    elapsed_hours, elapsed_minutes = divmod(total_minutes, 60)
    loggers["logger"].info(
        f"\nTook {elapsed_hours} hours {elapsed_minutes} minutes."
    )


def main_setup():
    """Parse the command line, build the run config and create the loggers.

    Returns:
        A ``(config, loggers)`` tuple. ``loggers`` always contains a
        ``"logger"`` entry; ``"tb"`` and ``"wandb"`` entries are added unless
        ``config.debug`` is truthy.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("-c", "--config_files", required=True)
    arg_parser.add_argument("-k", "--kwargs", nargs="*", action=ParseKwargs, default={})
    cli_args = arg_parser.parse_args()

    config = Config(cli_args.config_files, cli_args.kwargs)

    logging_config_dir = os.path.join(config.project_dir, "utils/")
    text_logger = get_logger("log.txt", f"{config.log_dir}/", logging_config_dir)

    text_logger.info(f"Start experiment {config.project_name}/{config.name}")
    text_logger.info(config.to_json())

    # The device is attached only after the config has been logged as JSON.
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    config.device = torch.device(device_name)

    loggers = {"logger": text_logger}
    if config.debug:
        # Debug runs get the text logger only.
        return config, loggers

    loggers["tb"] = SummaryWriter(config.run_output_dir)
    wandb_run, _, _ = setup_wandb_logger(config.__dict__)
    loggers["wandb"] = wandb_run

    # Record the launching command line in the wandb run.
    launch_command = subprocess.list2cmdline(["python", *sys.argv])
    wandb_run.log({"command": launch_command})

    return config, loggers


if __name__ == "__main__":
    # Script entry point: build (config, loggers) and hand them straight to main().
    main(*main_setup())

