import argparse
import spaghettini
from spaghettini import quick_register
import os
import json
import tqdm
from functools import partial
import wandb

import torch

from src.mains.task_getters import get_system_and_trainer
from src.utils.misc import stdlog

from src.analysis.logging.common_logging_utils import LogAccuracyWithDifferentForwardDepths
from src.analysis.logging.path_independence_logging import PathIndependenceQuantifying
from src.utils.saving_loading import get_most_recent_checkpoint_filepath, get_dirs_dict, get_logger_id, save_logger_id
from src.utils.misc import prepend_string_to_dict_keys, postpend_string_to_dict_keys

# Template file names searched for (in this order) inside the top-level directory;
# the first one that exists is used to rebuild the model.
POSSIBLE_TEMPLATE_NAMES = ["template.yaml", "cfg.yaml"]

# Default evaluation loggers. Only used as the default of
# `eval_checkpoint(eval_loggers=...)` in this file.
EVAL_LOGGERS = (
    # Accuracy measured at several forward-iteration depths.
    partial(
        LogAccuracyWithDifferentForwardDepths,
        all_forward_iters=[4, 32, 256, 1024, 2048],
    ),
    # Path-independence metrics with the thresholds below.
    partial(
        PathIndependenceQuantifying,
        rel_diff_threshold=0.01,
        diff_l2_threshold=0.001,
        num_forward_iter=2048,
    ),
)


def _get_checkpoint_step(ckpt_path):
    return int(ckpt_path.split("__step_")[1].split("__2022")[0])


def use_gpu():
    """Report whether torch can see a CUDA device."""
    cuda_ready = torch.cuda.is_available()
    return cuda_ready


@quick_register
def test_til(eval_fns, topdir, eval_last_checkpoint=False):
    """Run path independence evaluation over every checkpoint found under ``topdir``.

    Args:
        eval_fns: Iterable of evaluation callables forwarded to ``eval_checkpoint``.
        topdir: Directory searched recursively for ``.ckpt`` files. It must also
            contain one of ``POSSIBLE_TEMPLATE_NAMES`` (used to rebuild the model).
        eval_last_checkpoint: If True, only the checkpoint with the highest step
            is evaluated.

    Returns:
        dict mapping each evaluated checkpoint path to the metric dict returned
        by ``eval_checkpoint``.

    Exits the process with status -1 when no checkpoint or no template is found.
    """
    import sys

    # Gather all the checkpoints under topdir.
    all_checkpoints = gather_all_checkpoints(expdir=topdir)
    if not all_checkpoints:
        print(f"Found no checkpoints under {topdir}. Terminating. ")
        sys.exit(-1)

    # Sort the checkpoints based on their training step.
    all_checkpoints = sorted(all_checkpoints, key=_get_checkpoint_step)

    # If asked, only evaluate the last checkpoint.
    if eval_last_checkpoint:
        all_checkpoints = all_checkpoints[-1:]

    # Use the first template name (in POSSIBLE_TEMPLATE_NAMES order) that exists.
    candidate_paths = (os.path.join(topdir, name) for name in POSSIBLE_TEMPLATE_NAMES)
    template_path = next((path for path in candidate_paths if os.path.exists(path)), None)
    if template_path is None:
        print("No template found, which is needed to load the model. Aborting. ")
        sys.exit(-1)

    # Initialize the logger. Creating it starts the wandb experiment that
    # eval_checkpoint's wandb.log calls write into; the object itself is not used again.
    logger = init_logger(checkpoint_path=all_checkpoints[0], template_path=template_path)

    # Extract results from each checkpoint.
    results_dict = dict()
    for ckpt in tqdm.tqdm(all_checkpoints):
        print(f"Evaluating checkpoint {ckpt}.")
        results_dict[ckpt] = eval_checkpoint(eval_fns=eval_fns, expdir=topdir, checkpoint_path=ckpt,
                                             template_path=template_path)

    return results_dict


def eval_checkpoint(eval_fns, expdir, checkpoint_path, template_path, eval_loggers=EVAL_LOGGERS):
    """Evaluate one checkpoint, log its metrics to wandb and persist them as JSON.

    Args:
        eval_fns: Iterable of callables. Each receives the running ``metric_logs``
            dict, the loaded system and bookkeeping kwargs, and returns the
            updated ``metric_logs`` dict.
        expdir: Not used by this function; kept for interface compatibility.
        checkpoint_path: Path of the ``.ckpt`` file to evaluate.
        template_path: Path of the spaghettini template used to build the system.
        eval_loggers: Not used by this function (``eval_fns`` is what runs); kept
            for interface compatibility.

    Returns:
        The merged metric dict written to
        ``<checkpoint path without extension>_eval_results/metric_logs.json``:
        any previously saved metrics, updated with this run's metrics.
    """
    # Load the spaghettini template for training.
    cfg, _ = spaghettini.load(path=template_path, gather_hparams=False, verbose=False, record_config=False)

    # Load the pytorch lightning system from checkpoint.
    pls = load_system_from_checkpoint(cfg=cfg, checkpoint_path=checkpoint_path, template_path=template_path)
    if use_gpu():
        pls = pls.cuda()

    # Only global_step/epoch are read from the raw checkpoint dict, so keep every
    # tensor on CPU. This avoids failures when a GPU-saved checkpoint is read on a
    # CPU-only host and avoids a second copy of the weights on the GPU.
    ckpt_dict = torch.load(f=checkpoint_path, map_location="cpu")
    global_step, epoch = ckpt_dict["global_step"], ckpt_dict["epoch"]

    # Results live next to the checkpoint. splitext only strips the file's final
    # extension, so an extension-less path does not collapse to a cwd-relative dir.
    save_dir = os.path.splitext(checkpoint_path)[0] + "_eval_results"
    os.makedirs(save_dir, exist_ok=True)

    # Run evaluation; each eval function threads metric_logs through.
    prepend_key = "test/"
    metric_logs = dict()
    for eval_logger in eval_fns:
        metric_logs = eval_logger(metric_logs=metric_logs, pl_system=pls, prepend_key=prepend_key, save_locally=True,
                                  save_dir=save_dir, global_step=global_step, epoch=epoch)
    metric_logs = prepend_string_to_dict_keys(prepend_key=prepend_key, dictinary=metric_logs)
    metric_logs.update({"trainer/global_step": global_step})

    wandb.log(metric_logs)

    # ____ Also save the metric logs as a json file, merging with earlier runs. ____
    # NOTE(review): json.dump requires JSON-serializable values; presumably the
    # eval functions return plain Python numbers — confirm.
    json_path = os.path.join(save_dir, "metric_logs.json")
    if os.path.exists(json_path):
        with open(json_path, "r", encoding="utf-8") as f:
            final_metric_logs = json.load(f)
    else:
        final_metric_logs = dict()

    # Override/add to the existing content.
    final_metric_logs.update(metric_logs)

    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(final_metric_logs, f, indent=2)

    return final_metric_logs


def init_logger(checkpoint_path, template_path):
    """Build the template's logger for the experiment that produced ``checkpoint_path``.

    The logger id and experiment name are recovered from the checkpoint via
    ``get_logger_id`` and passed to the template's ``logger`` factory. The
    ``experiment`` attribute is then accessed so the wandb experiment is set up
    right away with those parameters.
    """
    template_cfg, _unused_hparams = spaghettini.load(
        path=template_path,
        gather_hparams=False,
        verbose=False,
        record_config=False,
    )
    run_id, exp_name = get_logger_id(load_ckpt_filepath=checkpoint_path)
    experiment_logger = template_cfg.logger(name=exp_name, save_dir=None, id=run_id)

    # Touch the property for its side effect: it initializes the wandb experiment.
    experiment_logger.experiment

    return experiment_logger


def load_system_from_checkpoint(cfg, checkpoint_path, template_path):
    """Restore the pytorch lightning system described by ``cfg`` from ``checkpoint_path``.

    Args:
        cfg: Loaded spaghettini config. ``cfg.system`` must be callable and expose
            ``initial_kwargs`` (the system's constructor kwargs from the template).
        checkpoint_path: Path of the checkpoint to restore.
        template_path: Not used; kept for interface compatibility.

    Returns:
        Whatever the system class's ``load_from_checkpoint`` returns.
    """
    # A throwaway instance is built only to discover the concrete system class.
    system_cls = type(cfg.system(dirs_dict=False))

    # Merge instead of passing dirs_dict separately, so a template whose kwargs
    # already contain dirs_dict does not hit a duplicate-keyword TypeError.
    load_kwargs = {**cfg.system.initial_kwargs, "dirs_dict": None}

    # load_from_checkpoint is a classmethod: call it on the class, since recent
    # PyTorch Lightning releases refuse to run it through an instance.
    return system_cls.load_from_checkpoint(checkpoint_path, **load_kwargs)


def gather_all_checkpoints(expdir, ckpt_extension=".ckpt"):
    """Recursively collect checkpoint files below ``expdir``.

    Returns a lexicographically sorted list of ``os.path.join(dirpath, name)``
    paths for every file whose name ends with ``ckpt_extension``. A directory
    that does not exist yields an empty list.
    """
    found = [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(expdir)
        for filename in filenames
        if filename.endswith(ckpt_extension)
    ]
    found.sort()
    return found


if __name__ == "__main__":
    # Run from the repository root, e.g.:
    #   python -m src.mains.tasks.test_til --topdir="runs/dev/tmp"
    parser = argparse.ArgumentParser()
    # Required: without it os.path.join(None, ...) below would crash with a TypeError.
    parser.add_argument("--topdir", type=str, required=True,
                        help="The topdir that contains the expdirs. ")
    parser.add_argument("--eval_last_checkpoint", action='store_true',
                        help="Only evaluate the checkpoint with the highest step.")
    args = parser.parse_args()

    # Load the eval config, which lists the evaluation functions to run.
    eval_config_path = os.path.join(args.topdir, "eval_template.yaml")
    cfg, _ = spaghettini.load(eval_config_path, record_config=False)

    test_til(eval_fns=cfg.logging_functions, topdir=args.topdir, eval_last_checkpoint=args.eval_last_checkpoint)
