import functools
from typing import Any, Callable, Iterator, Literal, TypeVar, Optional, Union
import pickle
from pathlib import Path

import torch
import matplotlib.pyplot as plt

from vis_models.io.persistence import get_best_checkpoint_path
from vis_models.training.supervised import SupervisedLearning
from vis_models.io.dirs import get_results_dir


def load_model(
    task_name: list[str],
    model: torch.nn.Module,
) -> torch.nn.Module:
    """Restore a model from the checkpoint selected for ``task_name``.

    The checkpoint path comes from ``get_best_checkpoint_path(task_name)``
    and is handed, as a string, to
    ``SupervisedLearning.load_from_checkpoint`` together with ``model``.

    Args:
        task_name: Name components identifying the task whose checkpoint
            should be looked up.
        model: Module instance forwarded to ``load_from_checkpoint`` as the
            ``model`` keyword argument.

    Returns:
        The ``model`` attribute of the restored task.
        NOTE(review): presumably this is ``model`` with the checkpointed
        weights loaded -- confirm against ``SupervisedLearning``.
    """
    checkpoint_path = str(get_best_checkpoint_path(task_name))
    restored_task = SupervisedLearning.load_from_checkpoint(
        checkpoint_path,
        model=model,
    )
    return restored_task.model

def load_models(
    task_name: list[str],
    models: dict[str, torch.nn.Module],
) -> dict[str, torch.nn.Module]:
    """Restore every model in ``models`` from its own checkpoint.

    Each model is loaded with ``load_model`` under the task name
    ``[*task_name, model_name]``, i.e. the model's key is appended to the
    shared task name.

    Args:
        task_name: Name components shared by all models.
        models: Mapping from model name to the module instance to restore.

    Returns:
        A new dict with the same keys as ``models`` whose values are the
        results of ``load_model``.
    """
    print(f"Loading previous models for task {task_name}")
    restored: dict[str, torch.nn.Module] = {}
    for model_name, model in models.items():
        # Each model's checkpoint lives under the task name extended by its key.
        per_model_task = [*task_name, model_name]
        restored[model_name] = load_model(per_model_task, model)
    return restored


# File stem (without the ``.pkl`` suffix) used when no result name is given.
DEFAULT_RESULT_NAME = "result"

def save_result(
    exp_name: list[str],
    result: Any,
    result_name: str = DEFAULT_RESULT_NAME,
) -> None:
    """Pickle ``result`` into the results directory of an experiment.

    The target is ``<results dir>/<result_name>.pkl``, where the results
    directory is obtained from ``get_results_dir(exp_name)``. The directory
    (including missing parents) is created if needed, and an existing file
    with the same name is overwritten.

    Args:
        exp_name: Name components identifying the experiment.
        result: Any picklable object.
        result_name: File stem of the pickle file.
    """
    target_dir = get_results_dir(exp_name)
    target_dir.mkdir(parents=True, exist_ok=True)
    target_path = target_dir / f"{result_name}.pkl"
    with open(target_path, "wb") as sink:
        pickle.dump(result, sink)

def load_result(exp_name: list[str], result_name: str) -> Any:
    """Unpickle ``<results dir>/<result_name>.pkl`` for an experiment.

    The results directory is obtained from ``get_results_dir(exp_name)``.

    NOTE: unpickling can execute arbitrary code, so this should only be
    pointed at files written by trusted code (e.g. ``save_result``).

    Args:
        exp_name: Name components identifying the experiment.
        result_name: File stem of the pickle file.

    Returns:
        The object stored in the pickle file.
    """
    pickle_path = get_results_dir(exp_name) / f"{result_name}.pkl"
    with open(pickle_path, "rb") as source:
        return pickle.load(source)

def _result_exists(exp_name: list[str], result_name: str) -> bool:
    """Return whether ``<results dir>/<result_name>.pkl`` exists.

    The results directory is obtained from ``get_results_dir(exp_name)``.
    """
    candidate = get_results_dir(exp_name) / f"{result_name}.pkl"
    return candidate.exists()

def load_experiment_result(
    exp_name: list[str],
    seed: tuple[int, int],
    postfix: Optional[list[str]] = None,
    result_name: str = DEFAULT_RESULT_NAME,
) -> object:
    """Load the stored result of one seeded experiment run.

    The full experiment name is built with ``get_experiment_name`` from
    ``exp_name``, the two seeds and ``postfix``; the result is then read
    with ``load_result``.

    Args:
        exp_name: Leading name components of the experiment.
        seed: ``(config_seed, sampling_seed)`` pair.
        postfix: Optional trailing name components. ``None`` (the default)
            means no trailing components.
        result_name: File stem of the pickle file to load.

    Returns:
        Whatever object ``load_result`` returns for that experiment.
    """
    config_seed, sampling_seed = seed
    # ``None`` replaces the former shared ``[]`` default; normalise it here so
    # the helper always receives a list.
    trailing = [] if postfix is None else postfix
    experiment_name = get_experiment_name(
        exp_name, config_seed, sampling_seed, trailing,
    )
    return load_result(experiment_name, result_name)

# Accepted values for the ``rerun_if`` argument of functions wrapped by
# ``cached_result``.
RerunCondition = Literal["always", "no_prior_res", "never"]
# Type of the values yielded by a function wrapped with ``cached_result``.
Res = TypeVar("Res")

def cached_result(
    func: Callable[..., Iterator[Res]]
) -> Callable[..., Optional[Res]]:
    """Decorate a result-yielding generator function with pickle caching.

    The returned wrapper takes two extra leading arguments and one extra
    keyword-only argument that it consumes itself; everything else is
    forwarded to ``func`` unchanged:

    * ``exp_name``: name components identifying where results are stored
      (passed to ``save_result`` / ``load_result``).
    * ``rerun_if``: controls whether ``func`` is executed:

      - ``"always"``: run ``func``.
      - ``"no_prior_res"``: run ``func`` only if no stored result exists;
        otherwise load and return the stored result.
      - ``"never"``: neither run ``func`` nor load anything; return
        ``None``.

    * ``result_name`` (keyword-only): file stem of the stored result;
      defaults to ``"result"``, the same value as ``DEFAULT_RESULT_NAME``.

    When ``func`` runs, every value it yields is saved immediately under the
    same name, so the stored result is always the most recently yielded
    value and earlier ones are overwritten.

    Args:
        func: Generator function yielding successive results.

    Returns:
        The wrapper. It returns the last value yielded by ``func``, the
        loaded stored result, or ``None`` (for ``"never"``, or when ``func``
        yields nothing).

    Raises:
        ValueError: Raised by the wrapper if ``rerun_if`` is not one of the
            accepted values.
    """
    @functools.wraps(func)
    def wrapper_func(
        exp_name: list[str],
        rerun_if: RerunCondition,
        *args,
        result_name: str = "result",
        **kwargs,
    ) -> Optional[Res]:
        if rerun_if not in {"always", "no_prior_res", "never"}:
            raise ValueError(
                f"Unknown rerun_if value {rerun_if!r}; expected one of "
                "'always', 'no_prior_res' or 'never'."
            )

        # "never" skips both the computation and the cache lookup.
        if rerun_if == "never":
            return None

        if rerun_if == "no_prior_res" and _result_exists(
            exp_name, result_name
        ):
            print("Loading cached results")
            return load_result(exp_name, result_name)

        # Either "always", or "no_prior_res" without a stored result: run
        # the generator and persist each yielded value as it arrives.
        result: Optional[Res] = None
        for result in func(*args, **kwargs):
            save_result(exp_name, result, result_name)
        return result
    return wrapper_func


def get_experiment_name(
    prefixes: list[str],
    config_seed: int,
    sampling_seed: int,
    postfixes: Optional[list[str]] = None,
) -> list[str]:
    """Build the name components of a seeded experiment run.

    Args:
        prefixes: Leading name components.
        config_seed: Seed embedded after ``cs_`` in the seed component.
        sampling_seed: Seed embedded after ``ss_`` in the seed component.
        postfixes: Optional trailing name components. ``None`` (the default)
            means no trailing components.

    Returns:
        A new list ``[*prefixes, "cs_<config_seed>_ss_<sampling_seed>",
        *postfixes]``.
    """
    seed_component = f"cs_{config_seed}_ss_{sampling_seed}"
    # ``None`` replaces the former shared ``[]`` default argument.
    trailing = [] if postfixes is None else postfixes
    return [*prefixes, seed_component, *trailing]


def save_figure(figure: plt.Figure, output_file: Union[Path, str]) -> None:
    """Write ``figure`` to ``output_file`` using the figure's ``savefig``.

    Args:
        figure: Figure to save.
        output_file: Destination path, passed to ``savefig`` unchanged.
    """
    figure.savefig(output_file)
