import os
import glob
from spaghettini import quick_register
import numpy as np
from munch import Munch
import pathlib
import datetime
import wandb

from src.utils.misc import getnow

# Checkpoint suffix including the leading dot.
# NOTE(review): not referenced in the code visible here — confirm it is still used elsewhere.
POSTPATTERN = ".ckpt"
# File, stored next to checkpoints, holding the logger's run id (read back to resume logging).
LOGGER_ID_FILENAME = "logger_id.txt"
# File, stored next to checkpoints, holding the experiment (run) name.
EXPNAME_FILENAME = "experiment_name.txt"
# Checkpoint kinds accepted by get_most_recent_checkpoint_filepath.
SUPPORTED_CHECKPOINT_TYPES = ["recent", "best"]
# File extensions (without the leading dot) that distinguish the two checkpoint kinds on disk.
RECENT_CHECKPOINT_EXTENSION = "ckpt"
BEST_CHECKPOINT_EXTENSION = "bestckpt"


@quick_register
def load_numpy_array():
    """Build a loader for ``.npy`` files; registered via spaghettini's ``quick_register``.

    Returns:
        A one-argument callable that takes a file path and returns ``np.load(path)``.
    """
    def _read_array(path):
        # Uses numpy's default loading options unchanged.
        return np.load(path)

    loader = _read_array
    return loader


def get_most_recent_checkpoint_filepath(dirs_dict, checkpoint_type):
    """Return the path of the newest checkpoint of the requested kind, or None.

    The temporary storage root (``dirs_dict.tmp_dir``) is searched first, then the master
    directory (``dirs_dict.cfg_dir_rel``). Each root is searched recursively and the first
    root containing a match wins; within a root, the file with the latest ctime is chosen.

    Args:
        dirs_dict: Directory bundle as built by ``get_dirs_dict`` (reads ``tmp_ckpt_dir_abs``,
            ``tmp_dir`` and ``cfg_dir_rel``).
        checkpoint_type: Either "recent" or "best".

    Returns:
        The checkpoint path as a string, or None if no checkpoint exists.

    Raises:
        ValueError: If ``checkpoint_type`` is not one of SUPPORTED_CHECKPOINT_TYPES.
    """
    if checkpoint_type not in SUPPORTED_CHECKPOINT_TYPES:
        message = f"Checkpoint type {checkpoint_type} not amongst supported types {SUPPORTED_CHECKPOINT_TYPES}."
        raise ValueError(message)

    # Each checkpoint kind is recognised purely by its file extension.
    extension = RECENT_CHECKPOINT_EXTENSION if checkpoint_type == "recent" else BEST_CHECKPOINT_EXTENSION

    search_roots = []
    # The whole tmp_dir is searched (not just this run's freshly created tmp_ckpt_dir_abs) so
    # that checkpoints left by earlier runs are found. Both attributes are checked so that an
    # unset tmp_dir is skipped rather than crashing inside os.path.join.
    if dirs_dict.tmp_ckpt_dir_abs is not None and dirs_dict.tmp_dir is not None:
        search_roots.append(dirs_dict.tmp_dir)
    search_roots.append(dirs_dict.cfg_dir_rel)

    for root in search_roots:
        ckpt_paths = get_date_sorted_ckpt_paths(root, extension=extension)
        if ckpt_paths:
            # Paths come back newest-first.
            print(f"\n Found checkpoint at {ckpt_paths[0]}\n")
            return ckpt_paths[0]

    print("No available checkpoint found. ")
    return None


def get_date_sorted_ckpt_paths(dirpath, extension="ckpt"):
    """Return every ``*.<extension>`` file under ``dirpath`` (recursively), newest first.

    Files are ordered by ``st_ctime`` descending; ties keep glob order. On POSIX systems
    ``st_ctime`` is the last inode-change time rather than a true creation time, which for
    checkpoint files written once is effectively when they were written.

    Args:
        dirpath: Directory to search; the directory itself and all subdirectories are included.
        extension: File extension to match, without the leading dot.

    Returns:
        A list of matching path strings, empty when nothing matches.
    """
    # Escape the directory so characters such as "[" in its name are taken literally
    # rather than interpreted as glob wildcards.
    pattern = os.path.join(glob.escape(dirpath), f"**/*.{extension}")
    ckpt_paths = glob.glob(pattern, recursive=True)

    # Sort on the raw timestamp. Converting to naive local datetimes (as before) can
    # mis-order files written around a DST fall-back, when local clock times repeat.
    return sorted(ckpt_paths, key=os.path.getctime, reverse=True)


def get_dirs_dict(cfg_dir_rel, tmp_dir):
    """Build the directory bundle for this run and create its checkpoint directories.

    Args:
        cfg_dir_rel: A path whose parent directory becomes the master directory
            (presumably the config file path — confirm against callers).
        tmp_dir: Root of temporary storage, or None to run without a temporary
            checkpoint directory.

    Returns:
        A Munch with keys ``project_dir_abs``, ``cfg_dir_rel``, ``tmp_dir``,
        ``ckpt_dir_rel`` and ``tmp_ckpt_dir_abs``. ``tmp_ckpt_dir_abs`` is None when
        ``tmp_dir`` is None.
    """
    dirs_dict = Munch()

    # The python file is run from the project directory.
    dirs_dict.project_dir_abs = os.getcwd()

    # Master directory: the parent of the given config path.
    dirs_dict.cfg_dir_rel = os.path.dirname(cfg_dir_rel)

    dirs_dict.tmp_dir = tmp_dir

    # A single timestamp is shared so this run's master and temporary checkpoint
    # directories have matching names.
    now = getnow()
    dirs_dict.ckpt_dir_rel = os.path.join(dirs_dict.cfg_dir_rel, f"{now}", "checkpoints")
    os.makedirs(dirs_dict.ckpt_dir_rel, exist_ok=True)

    # Without a temporary root there is no temporary checkpoint directory; callers
    # check tmp_ckpt_dir_abs for None instead of crashing on os.path.join(None, ...).
    if tmp_dir is None:
        dirs_dict.tmp_ckpt_dir_abs = None
    else:
        dirs_dict.tmp_ckpt_dir_abs = os.path.join(tmp_dir, f"{now}", "checkpoints")
        os.makedirs(dirs_dict.tmp_ckpt_dir_abs, exist_ok=True)

    return dirs_dict


def get_logger_id(load_ckpt_filepath):
    """Get the logger id and experiment name stored alongside a checkpoint.

    The directory holding the id files is derived from the checkpoint kind, which is
    inferred from the file extension:

    - "recent" checkpoint: the id files are in the same directory as the checkpoint.
    - "best" checkpoint: the id files are one directory above the checkpoint's directory.

    Args:
        load_ckpt_filepath: Path to a ".ckpt" or ".bestckpt" checkpoint file.

    Returns:
        A tuple ``(logger_id, expname)``. ``logger_id`` is read from LOGGER_ID_FILENAME,
        or is a newly generated wandb id if that file is missing. ``expname`` is read from
        EXPNAME_FILENAME, or is None if that file is missing. Surrounding whitespace is
        stripped from both values.

    Raises:
        ValueError: If the checkpoint kind cannot be inferred from the extension.
    """
    if load_ckpt_filepath.endswith(f".{RECENT_CHECKPOINT_EXTENSION}"):
        ckpt_dir = os.path.dirname(load_ckpt_filepath)
    elif load_ckpt_filepath.endswith(f".{BEST_CHECKPOINT_EXTENSION}"):
        # Two dirname() calls climb one level above the checkpoint's own directory.
        # Splitting on "/" and re-joining dropped the leading "/" of absolute paths
        # and crashed on paths with fewer than three components.
        ckpt_dir = os.path.dirname(os.path.dirname(load_ckpt_filepath))
    else:
        message = f"The checkpoint type cannot be inferred from its extension. " \
                  f"Supported extensions are '{RECENT_CHECKPOINT_EXTENSION}' and '{BEST_CHECKPOINT_EXTENSION}'"
        raise ValueError(message)

    logger_id_filepath = os.path.join(ckpt_dir, LOGGER_ID_FILENAME)
    expname_filepath = os.path.join(ckpt_dir, EXPNAME_FILENAME)

    if os.path.exists(logger_id_filepath):
        with open(logger_id_filepath, "r") as f:
            # Strip so a trailing newline (e.g. from a hand-edited file) cannot
            # corrupt the id used to resume logging.
            logger_id = f.read().strip()
    else:
        logger_id = wandb.util.generate_id()
        print(f"Logger id cannot be found, even though a checkpoint directory was found. Generated new id: {logger_id}")

    if os.path.exists(expname_filepath):
        with open(expname_filepath, "r") as f:
            expname = f.read().strip()
    else:
        print("Experiment name could not be found. Will regenerate. ")
        expname = None

    print(f"\n Logger id found: {logger_id}. Experiment name: {expname} \n")
    return logger_id, expname


def save_logger_id(logger, tmp_ckpt_dir_abs, ckpt_dir_rel):
    """Persist the logger's run id and experiment name next to the checkpoints.

    Both values are written, as LOGGER_ID_FILENAME and EXPNAME_FILENAME, into the
    temporary checkpoint directory and the master checkpoint directory. This lets a
    resumed run find them whichever location its checkpoint is loaded from.

    Args:
        logger: Object exposing ``version`` (the run id) and ``experiment.name``
            (presumably a Lightning WandbLogger — confirm against callers).
        tmp_ckpt_dir_abs: Temporary checkpoint directory, or None to skip it.
        ckpt_dir_rel: Master checkpoint directory.
    """
    logger_id = logger.version
    expname = logger.experiment.name

    # An unconfigured (None) temporary directory is skipped rather than crashing in os.path.join.
    target_dirs = [d for d in (tmp_ckpt_dir_abs, ckpt_dir_rel) if d is not None]

    print(f"\nSaving logger id: {logger_id} \n")
    for target_dir in target_dirs:
        with open(os.path.join(target_dir, LOGGER_ID_FILENAME), "w") as f:
            # str() guards against loggers whose version is not already a string
            # (file.write only accepts str).
            f.write(str(logger_id))

    print(f"\n Saving experiment name: {expname} \n")
    for target_dir in target_dirs:
        with open(os.path.join(target_dir, EXPNAME_FILENAME), "w") as f:
            f.write(expname)
