import os
import glob
from spaghettini import quick_register, load
import numpy as np
from munch import Munch
import pathlib
import datetime

import torch

from src.utils.misc import getnow

# Presumably the filename suffix of checkpoint files.
# NOTE(review): not referenced in the visible part of this file -- confirm it is used elsewhere.
POSTPATTERN = ".ckpt"
# Name of the text file holding the logger id; it is written by save_logger_id
# and read back by get_logger_id from the directory of a checkpoint file.
LOGGER_ID_FILENAME = "logger_id.txt"


def get_system_with_trained_weights(cfg_path_rel, tmp_ckpt_dir_abs):
    """Build the system described by a config file and load its newest checkpoint.

    Args:
        cfg_path_rel: Path to the config file handed to ``load``. The loaded
            config must expose a ``system`` attribute.
        tmp_ckpt_dir_abs: Temporary storage directory; forwarded to
            ``get_dirs_dict`` as its ``tmp_dir`` argument.

    Returns:
        ``cfg.system`` after ``load_state_dict`` was called with the
        ``"state_dict"`` entry of the most recent checkpoint.

    Raises:
        AssertionError: If no checkpoint file was found.

    Note:
        ``get_dirs_dict`` creates timestamped checkpoint directories as a side
        effect, so calling this function touches the filesystem.
    """
    # ____ Get the system. ____
    cfg = load(cfg_path_rel, verbose=False, record_config=False)
    system = cfg.system

    # ____ Get the directories dict. ____
    dirs_dict = get_dirs_dict(cfg_dir_rel=cfg_path_rel, tmp_dir=tmp_ckpt_dir_abs)

    # ____ Find the checkpoint path. ____
    load_ckpt_filepath = get_most_recent_checkpoint_filepath(dirs_dict)
    if load_ckpt_filepath is None:
        # Raised explicitly (rather than via `assert`) so the check survives
        # `python -O` and the exception actually carries the message; the
        # exception type is kept for callers that catch AssertionError.
        raise AssertionError("No checkpoint found. Terminating. ")

    # ____ Load the weights. ____
    # The identity map_location keeps every storage where it was deserialized
    # (i.e. on CPU) instead of moving it to the device it was saved from.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(f=load_ckpt_filepath, map_location=lambda storage, loc: storage)
    system.load_state_dict(checkpoint["state_dict"])

    return system


def get_most_recent_checkpoint_filepath(dirs_dict):
    """Return the path of the newest checkpoint, preferring temporary storage.

    Args:
        dirs_dict: Object exposing ``tmp_ckpt_dir_abs``, ``tmp_dir`` and
            ``cfg_dir_rel`` attributes (e.g. the result of ``get_dirs_dict``).

    Returns:
        The first entry returned by ``get_date_sorted_ckpt_paths`` for
        ``tmp_dir`` if any checkpoint exists there, otherwise the first entry
        for ``cfg_dir_rel``, otherwise ``None``.
    """
    # Search the temporary storage directory. The search root is `tmp_dir`
    # (searched recursively), so it must be usable as well -- guarding only
    # `tmp_ckpt_dir_abs` would let a None `tmp_dir` reach os.path.join.
    # NOTE(review): presumably the root is searched rather than the timestamped
    # checkpoint directory because the latter is new for every run -- confirm.
    if dirs_dict.tmp_ckpt_dir_abs is not None and dirs_dict.tmp_dir is not None:
        tmp_ckpt_paths = get_date_sorted_ckpt_paths(dirs_dict.tmp_dir, extension="ckpt")
        if tmp_ckpt_paths:  # Return the most recent.
            print(f"\n Found checkpoint at {tmp_ckpt_paths[0]}\n")
            return tmp_ckpt_paths[0]

    # Search the master checkpoint directory.
    master_ckpt_paths = get_date_sorted_ckpt_paths(dirs_dict.cfg_dir_rel, extension="ckpt")
    if master_ckpt_paths:  # Return the most recent.
        print(f"\n Found checkpoint at {master_ckpt_paths[0]}\n")
        return master_ckpt_paths[0]

    print("No available checkpoint found. ")

    return None


def get_dirs_dict(cfg_dir_rel, tmp_dir):
    """Assemble the directory layout for a run and create its checkpoint dirs.

    Args:
        cfg_dir_rel: Path whose parent directory is used as the config
            directory (despite the name, the final path component is dropped).
        tmp_dir: Temporary storage directory, or ``None`` when there is none.

    Returns:
        A ``Munch`` with the keys ``project_dir_abs`` (current working
        directory), ``cfg_dir_rel``, ``tmp_dir``, ``ckpt_dir_rel`` and
        ``tmp_ckpt_dir_abs``. ``tmp_ckpt_dir_abs`` is ``None`` when ``tmp_dir``
        is ``None``.

    Side effects:
        Creates ``ckpt_dir_rel`` and, when ``tmp_dir`` is given,
        ``tmp_ckpt_dir_abs`` on disk (existing directories are left alone).
    """
    dirs_dict = Munch()

    # The python file is run from the project directory.
    dirs_dict.project_dir_abs = os.getcwd()

    # Relative config directory: the parent of the given path.
    dirs_dict.cfg_dir_rel = os.path.split(cfg_dir_rel)[0]

    # Temporary storage directory.
    dirs_dict.tmp_dir = tmp_dir

    # Master checkpoint directory, namespaced by the value returned by getnow().
    now = getnow()
    dirs_dict.ckpt_dir_rel = os.path.join(dirs_dict.cfg_dir_rel, f"{now}", "checkpoints")
    os.makedirs(dirs_dict.ckpt_dir_rel, exist_ok=True)

    # Temporary checkpoint directory. Consumers check this entry against None,
    # so a missing tmp_dir is reported as None instead of crashing in os.path.join.
    if tmp_dir is None:
        dirs_dict.tmp_ckpt_dir_abs = None
    else:
        dirs_dict.tmp_ckpt_dir_abs = os.path.join(tmp_dir, f"{now}", "checkpoints")
        os.makedirs(dirs_dict.tmp_ckpt_dir_abs, exist_ok=True)

    return dirs_dict


def get_date_sorted_ckpt_paths(dirpath, extension="ckpt"):
    """Return all ``*.<extension>`` files under *dirpath*, newest first.

    Args:
        dirpath: Directory to search; subdirectories are searched recursively.
        extension: File extension to match, without the leading dot.

    Returns:
        A list of matching paths sorted by ``st_ctime`` in descending order.
        The list is empty when nothing matches.

    Note:
        ``st_ctime`` is the creation time on Windows but the time of the last
        metadata change on Unix.
    """
    ckpt_paths = glob.glob(os.path.join(dirpath, f"**/*.{extension}"), recursive=True)

    # Sort on the raw timestamp: converting to a naive local datetime first adds
    # nothing and can mis-order files across a DST fall-back.
    return sorted(ckpt_paths, key=os.path.getctime, reverse=True)


def get_logger_id(load_ckpt_filepath):
    """Read the logger id stored next to a checkpoint file.

    Args:
        load_ckpt_filepath: Path of a checkpoint file; the id file named
            ``LOGGER_ID_FILENAME`` is looked up in the same directory.

    Returns:
        The full contents of the id file as a string.
    """
    id_file = os.path.join(os.path.dirname(load_ckpt_filepath), LOGGER_ID_FILENAME)

    with open(id_file, "r") as handle:
        logger_id = handle.read()

    print(f"\n Logger id found: {logger_id} \n")
    return logger_id


def save_logger_id(logger, tmp_ckpt_dir_abs, ckpt_dir_rel):
    """Write the logger's version id into the checkpoint directories.

    Args:
        logger: Object exposing a ``version`` attribute.
        tmp_ckpt_dir_abs: Temporary checkpoint directory, or ``None`` to skip it.
        ckpt_dir_rel: Master checkpoint directory, or ``None`` to skip it.

    Side effects:
        Creates/overwrites ``LOGGER_ID_FILENAME`` in every directory that is
        not ``None``. The directories themselves must already exist.
    """
    logger_id = logger.version
    print(f"\nSaving logger id: {logger_id} \n")

    # Temporary directory first, then master -- same order as before.
    for ckpt_dir in (tmp_ckpt_dir_abs, ckpt_dir_rel):
        if ckpt_dir is None:
            # An absent directory is skipped rather than crashing in os.path.join.
            continue
        with open(os.path.join(ckpt_dir, LOGGER_ID_FILENAME), "w") as f:
            # `version` is not guaranteed to be a str (some loggers use ints),
            # and a text-mode write() only accepts str.
            f.write(str(logger_id))


@quick_register
def load_numpy_array():
    """Return a one-argument loader that reads a file with ``np.load``."""

    def load_npy(path):
        """Load *path* with ``np.load`` and return the result."""
        contents = np.load(path)
        return contents

    return load_npy

