"""
Miscellaneous helpers that aren't clearly meant for only online or offline learning.
"""

from datetime import datetime
import os
import pickle

import yaml


def rollout_policy_in_env(env, policy):
    """
    Run a single episode of `policy` in `env` and record it as a flat list.

    The environment is reset once, then stepped with the action returned by
    `policy.get_action(obs)` until `env.step` reports that the episode is done.
    Only the first entry of every observation (`obs[0]`) is stored as the state.

    Returns the trajectory [s1, a1, r1, s2, a2, r2, ..., sN, aN, rN]; the
    observation produced by the final step is not part of the result.

    (maybe) TODO: storing obs[0] assumes a TabularMDP-style env, i.e. that the
    observation is a 1D array [state] with an integer state. For envs such as
    Minigrid the policy would need the full obs while the trajectory might need
    the underlying state instead (open question).
    """
    observation = env.reset()
    trajectory = [observation[0]]
    while True:
        action = policy.get_action(observation)
        observation, reward, episode_over, _ = env.step(action)
        # obs[0] because of the TabularMDP assumption described above
        trajectory += [action, reward, observation[0]]
        if episode_over:
            break
    # the state reached by the very last step is not part of the trajectory
    del trajectory[-1]
    return trajectory


def pad_metrics(metrics, params):
    """
    Pad every metric list to length params["N_iterations"], in place.

    If the loop terminated early for a seed, some of its metric lists are too
    short. They are extended at the end with a constant: 1 for
    avg_rewards_best_iteration_policy, avg_rewards_true_opt and pi_set_sizes,
    and 0 for every other metric. Lists that are already long enough are left
    untouched.

    When params["verbose"] contains "plotting" or "full", each padding (or the
    fact that none was needed) is reported on stdout.

    Returns the same `metrics` object that was passed in.
    """
    padded_with_one = (
        "avg_rewards_best_iteration_policy",
        "avg_rewards_true_opt",
        "pi_set_sizes",
    )
    anything_padded = False
    for seed, seed_metrics in enumerate(metrics):
        for metric, values in seed_metrics.items():
            current_length = len(values)
            shortfall = params["N_iterations"] - current_length
            if shortfall <= 0:
                continue
            filler = 1 if metric in padded_with_one else 0
            values.extend([filler] * shortfall)
            if "plotting" in params["verbose"] or "full" in params["verbose"]:
                print(f"metric {metric} padded {shortfall} for seed {seed}")
            anything_padded = True
    if not anything_padded and (
        "plotting" in params["verbose"] or "full" in params["verbose"]
    ):
        print("no padding needed")
    return metrics


def short_env_name(long_env_name):
    """
    Translate a full environment identifier into its short name.

    Several identifiers (e.g. "HalfCheetah-v5" and "halfcheetah") may map to
    the same short name. Raises ValueError for any identifier that is not
    listed below.
    """
    # (accepted identifiers, short name)
    alias_table = (
        (("StarMDP_with_random_flinging",), "StarMDP"),
        (("StarMDP_with_random_staying",), "StarMDPStaying"),
        (("Gridworld",), "Gridworld"),
        (("HalfCheetah-v5", "halfcheetah"), "HalfCheetah"),
        (("Reacher-v5", "reacher"), "Reacher"),
        (("Ant-v5", "ant"), "Ant"),
        (("Walker2d-v5", "walker"), "Walker"),
        (("Hopper-v5", "hopper"), "Hopper"),
        (("Humanoid-v5", "humanoid"), "Humanoid"),
    )
    for accepted_names, short_name in alias_table:
        if long_env_name in accepted_names:
            return short_name
    raise ValueError(f"Unknown environment: {long_env_name}")


def get_run_dir(params):
    """
    Look at params and return the directory in which the run will be saved.

    The per-environment parent directory ``exps/<env_name>`` is created if it
    does not exist yet; the run directory itself is NOT created here.

    - If params has a run_ID that is not None: return ``exps/<env_name>/<run_ID>``
      (whether or not that directory exists).
    - Otherwise: return ``exps/<env_name>/<NNN>``, where NNN is one more than the
      highest purely numeric sub-directory name, zero-padded to three digits
      ("001" if there is no numeric sub-directory yet).

    The environment is read from params["env"], falling back to params["env_id"].

    Raises:
        ValueError: if the environment name is not known to short_env_name.
    """
    env_name = short_env_name(params.get("env") or params.get("env_id"))
    env_dir = f"exps/{env_name}"
    os.makedirs(env_dir, exist_ok=True)  # ensure dir for env exists

    # user specified a run_ID: the run lives in (or will live in) that directory
    run_id = params.get("run_ID")
    if run_id is not None:
        return f"{env_dir}/{run_id}"

    # otherwise pick the next free numeric ID among the folders in exps/{env_name}.
    # isdecimal() (rather than isdigit()) only admits names that int() can parse.
    numeric_ids = [
        int(name)
        for name in os.listdir(env_dir)
        if name.isdecimal() and os.path.isdir(f"{env_dir}/{name}")
    ]
    # default=0 covers both an empty env dir and one that only holds
    # non-numeric folders (e.g. runs saved under a custom run_ID)
    highest_number = max(numeric_ids, default=0)
    return f"{env_dir}/{highest_number + 1:03d}"


def get_fig_names(params):
    """
    Create the names of the figures that will be saved.

    Every name has the form ``<stem>_<plot kind>_<timestamp>.png``. The stem
    encodes run characteristics and depends on the environment family:

    - StarMDP / StarMDPStaying / Gridworld: env move probability, phi name,
      number of offline trajectories, episode length, offline radius formula,
      number of experiments and iterations, gamma_t ("calc" when
      params["gamma_t_hardcoded_value"] is falsy), online confset bonus
      multiplier and confset construction method.
    - HalfCheetah / Reacher / Ant / Hopper / Walker / Humanoid: embedding name,
      number of offline trajectories, episode length, number of experiments and
      iterations, confset size and policy selection method.

    The timestamp is the current local time formatted as ``%m%d_%H%M%S`` and is
    shared by all names returned from one call.

    Side effect: if params["which_plot_subopt"] is a single string, it is
    replaced inside `params` by a one-element list.

    Returns:
        (fig_names_subopt, fig_name_pi_set_sizes, fig_name_mujoco), where
        fig_names_subopt holds one name per entry of params["which_plot_subopt"].

    Raises:
        ValueError: for an unknown environment, offline radius formula, confset
            construction method, policy selection method or subopt plot type.
    """
    env_name = short_env_name(params.get("env") or params.get("env_id"))

    if env_name in ["StarMDP", "StarMDPStaying", "Gridworld"]:
        radius_formula = params["offlineradius_formula"]
        if radius_formula == "full":
            radius = "full"
        elif radius_formula == "ignore_bracket":
            radius = "IB"
        elif radius_formula == "only_alpha":
            radius = "OA"
        elif radius_formula == "hardcode_radius_scaled":
            radius = "HCscaled-" + str(params["offlineradius_override_value"])
        elif radius_formula == "hardcode_radius":
            radius = "HC-" + str(params["offlineradius_override_value"])
        else:
            raise ValueError(f"Unknown offlineradius formula: {params['offlineradius_formula']}")

        # any falsy value (None, 0, "") means gamma_t is calculated, not hardcoded
        gamma_t = params["gamma_t_hardcoded_value"] or "calc"

        bonus_multiplier = params["online_confset_bonus_multiplier"]

        # avoid a "." in the middle of the file name
        env_move_prob_str = str(params["env_move_prob"]).replace(".", "-")

        confset_method = params["which_confset_construction_method"]
        if confset_method == "noise-matrices":
            confsize = "NM" + str(params["N_confset_size"])
        elif confset_method == "rejection-sampling-from-all":
            confsize = "RSall"
        elif confset_method == "rejection-sampling-from-sample":
            confsize = "RSsamp"
        else:
            raise ValueError(
                f"Unknown confset construction method: {params['which_confset_construction_method']}"
            )

        stem = (
            f"{env_name}_mp{env_move_prob_str}_"
            f"F{params['phi_name']}_"
            f"T{params['N_offline_trajs']}_"
            f"H{params['episode_length']}_"
            f"rad{radius}_"
            f"exp{params['N_experiments']}_"
            f"its{params['N_iterations']}_"
            f"gamma{gamma_t}_"
            f"B{bonus_multiplier}_"
            f"confset{confsize}"
        )
    elif env_name in ["HalfCheetah", "Reacher", "Ant", "Hopper", "Walker", "Humanoid"]:
        # "Humanoid" is a valid result of short_env_name and must be handled
        # here, otherwise `stem` would be left undefined.
        policy_selection = params["which_policy_selection"]
        if policy_selection == "random":
            polsel = "Rand"
        elif policy_selection == "ucb":
            polsel = "UCB"
        elif policy_selection == "max_uncertainty":
            polsel = "MaxUnc"
        else:
            raise ValueError(f"Unknown policy selection method: {params['which_policy_selection']}")
        # no trailing "_" here: the separator is added when the names are built
        # below (a trailing one would yield "__" in every file name)
        stem = (
            f"{env_name}_"
            f"F{params['embedding_name']}_"
            f"T{params['N_offline_trajs']}_"
            f"H{params['episode_length']}_"
            f"exp{params['N_experiments']}_"
            f"its{params['N_iterations']}_"
            f"confset{params['N_confset_size']}_"
            f"PolSel{polsel}"
        )
    else:
        # fail loudly instead of hitting an unbound `stem` further down
        raise ValueError(f"No figure naming scheme for environment: {env_name}")

    timestamp = datetime.now().strftime("%m%d_%H%M%S")
    subopt_type_mapping = {
        "suboptimality_percent": "subopt_percent",
        "regret": "iter_regret",
        "cumulative_regret": "cum_regret",
        "raw_reward": "raw_reward",
        "regret_indiv": "iter_regret_indiv",
    }

    # accept a single plot type as well as a list; the normalised list is
    # written back into params (kept as-is, callers may rely on it)
    if isinstance(params["which_plot_subopt"], str):
        params["which_plot_subopt"] = [params["which_plot_subopt"]]

    fig_names_subopt = []
    for plot_type in params["which_plot_subopt"]:
        if plot_type not in subopt_type_mapping:
            raise ValueError(f"Unknown subopt type: {plot_type}")
        subopt_type = subopt_type_mapping[plot_type]
        fig_names_subopt.append(f"{stem}_{subopt_type}_{timestamp}.png")

    fig_name_pi_set_sizes = f"{stem}_pi_set_sizes_{timestamp}.png"
    fig_name_mujoco = f"{stem}_mujocometrics_{timestamp}.png"

    return fig_names_subopt, fig_name_pi_set_sizes, fig_name_mujoco


def dirs_and_loads(params):
    """
    Initialize experiment by setting & creating run directory and loading params and/or metrics.

    If params has a run_ID and the corresponding run directory already exists,
    params["loaded_run_behaviour"] decides what is loaded from it:

    - "continue" or None/missing: load the metrics from the single .pkl file in
      the run directory (nothing is loaded if there is no .pkl file).
    - "redo": load the params from ``params.yaml`` in the run directory and
      return those instead of the given ones.
    - "overwrite": load nothing.
    - anything else: a hint is printed and nothing is loaded.

    The run directory is created if it does not exist (regardless of
    params["save_results"]).

    Returns:
        A tuple (run_dir, return_params, return_metrics, fig_paths_subopt,
        fig_path_pi_set_sizes, fig_path_mujoco, metrics_path). return_params is
        a shallow copy of `params` unless params were loaded ("redo");
        return_metrics is {} unless metrics were loaded ("continue"). If
        params["save_results"] is falsy, fig_paths_subopt is [] and the two
        other figure paths are None. Figure names are derived from the given
        `params`, not from loaded ones.

    Raises:
        RuntimeError: in "continue" mode, if the run directory holds more than
            one .pkl file.
    """

    run_dir = get_run_dir(params)
    return_params = params.copy()
    return_metrics = {}

    # .get(): get_run_dir treats a missing "run_ID" like None, and None is the
    # documented default for "loaded_run_behaviour", so missing keys must not
    # raise KeyError here either.
    run_id = params.get("run_ID")
    behaviour = params.get("loaded_run_behaviour")

    # if user specified a run_ID and its directory exists: 'load' branch.
    # check what user wants to do when loading (continue previous run, redo
    # loaded run, or overwrite with current params)
    if run_id is not None and os.path.exists(run_dir):
        # if continue: load metrics (-> sim only what's missing).
        # None defaults to continue as it's not destructive
        if behaviour is None or behaviour == "continue":
            pkl_files = [f for f in os.listdir(run_dir) if f.endswith(".pkl")]
            if len(pkl_files) > 1:
                raise RuntimeError(f"More than one .pkl file found in {run_dir}: {pkl_files}")
            if not pkl_files:
                # TODO: decide whether this should raise FileNotFoundError instead
                print(
                    "no pkl found even though run_ID specified and dir exists, so doing 'continue'"
                )
            else:
                # NOTE: unpickling can execute arbitrary code; only point
                # run_ID at run directories you trust.
                with open(os.path.join(run_dir, pkl_files[0]), "rb") as f:
                    return_metrics = pickle.load(f)
                print(f"Continuing loaded run: Loaded metrics from {run_dir}/{pkl_files[0]}")
        # if redo: load params (-> re-sim with loaded params)
        elif behaviour == "redo":
            # NOTE: FullLoader can construct Python objects beyond plain YAML
            # types; only load params.yaml files you trust.
            with open(os.path.join(run_dir, "params.yaml"), "r") as f:
                return_params = yaml.load(f, Loader=yaml.FullLoader)
            print(f"Redoing loaded run: Loaded params from {run_dir}/params.yaml")
            return_params["loaded_run_behaviour"] = "redo"  # relevant when saving metrics later
        # if overwrite: no loading (-> re-sim with current params)
        elif behaviour == "overwrite":
            print("Overwriting existing run: Not loading anything")
        else:
            print("Run already exists. Specify 'loaded_run_behaviour'.")

    # fresh run (no run_ID, or run_ID whose dir doesn't exist yet): create the
    # run dir; for an existing dir this is a no-op
    os.makedirs(run_dir, exist_ok=True)

    fig_names_subopt, fig_name_pi_set_sizes, fig_name_mujoco = get_fig_names(params)
    if params["save_results"]:
        fig_paths_subopt = [f"{run_dir}/{fig_name}" for fig_name in fig_names_subopt]
        fig_path_pi_set_sizes = f"{run_dir}/{fig_name_pi_set_sizes}"
        fig_path_mujoco = f"{run_dir}/{fig_name_mujoco}"
    else:
        # if not saving results, don't save figs
        fig_paths_subopt = []
        fig_path_pi_set_sizes = None
        fig_path_mujoco = None
    metrics_path = f"{run_dir}/multi_metrics.pkl"
    return (
        run_dir,
        return_params,
        return_metrics,
        fig_paths_subopt,
        fig_path_pi_set_sizes,
        fig_path_mujoco,
        metrics_path,
    )


def save_metrics(multi_metrics, metrics_path, key, params):
    """Save the metrics of run `key` to the pickle file at `metrics_path`.

    If the file does not exist yet, the whole `multi_metrics` dict is written to
    a fresh file. If it exists, only `multi_metrics[key]` is merged into the
    stored dict, which is then written back. This allows two separate processes
    that each do different runs of the same params to save to one metrics dict,
    for later plotting.

    An entry that is already stored under `key` is only replaced when
    params["loaded_run_behaviour"] is "overwrite" or "redo"; otherwise (including
    when that setting is None or missing) the save is aborted with a message.

    Args:
        multi_metrics: dict of metrics, with keys being the run identifier (e.g. 'BRIDGE' or 'purely_online') and values being a list of dicts (one dict per seed)
            i.e.: {'bridge': [{'metric1': [t1, t2, ...], 'metric2': [t1, t2, ...]}, {...}, ..., {...}], 'purely_online': [{..}, ..., {..}]}
        metrics_path: path to the metrics pkl file
        key: the run identifier (e.g. 'BRIDGE' or 'purely_online')
        params: run parameters. params["save_results"] decides whether to save
            or skip; params["loaded_run_behaviour"] (optional) is described above.
    """
    if not params["save_results"]:
        return

    # if metrics file doesn't exist, save yours
    if not os.path.exists(metrics_path):
        with open(metrics_path, "wb") as f:
            pickle.dump(multi_metrics, f)
        print(f"Saved {key} to {metrics_path} (fresh file)")
        return

    # metrics file already exists: load it.
    # NOTE: unpickling can execute arbitrary code; only use trusted metrics files.
    with open(metrics_path, "rb") as f:
        loaded_metrics = pickle.load(f)

    # if loaded metrics already contain the run you're trying to save, and you're
    # not overwriting or redoing the run, abort. A missing setting is treated
    # like None, i.e. as the non-destructive default.
    behaviour = params.get("loaded_run_behaviour")
    if key in loaded_metrics and behaviour not in ("overwrite", "redo"):
        print(f"Key {key} already exists in {metrics_path}. Aborting save.")
        return

    # otherwise add (or replace) the entry and save the dict
    loaded_metrics[key] = multi_metrics[key]
    with open(metrics_path, "wb") as f:
        pickle.dump(loaded_metrics, f)
    print(f"Saved {key} to {metrics_path} (added to existing file)")
