"""
Main entry point for preference-based reinforcement learning experiments.

This single entry point replaces the previous multiple main scripts.
Experiment logic lives in the `core` module; helpers and plotting live in `utils` and `plotting`.
"""

# %%
import os
import sys
from core import run_experiment_multiprocessing, print_config, parse_args
from utils import pad_metrics, dirs_and_loads, save_metrics
from plotting import (
    plot_suboptimalities_multimetrics,
    plot_pi_set_sizes_multimetrics,
    get_mle_policy_avg_reward,
)
import copy
import matplotlib.pyplot as plt
import yaml

import time


def _run_condition(
    params, multi_metrics, metrics_path, *, key, label, enable_flag, do_offline_BC
):
    """Run one experimental condition unless it is disabled or already present.

    Args:
        params: experiment parameter dict. ``params["do_offline_BC"]`` is
            overwritten in place before running (same as the original inline
            code), so the value left behind is that of the last condition run.
        multi_metrics: dict of per-condition metrics; ``multi_metrics[key]`` is
            set in place when the condition runs.
        metrics_path: path forwarded to ``save_metrics``.
        key: metrics key for this condition (e.g. "purely_online").
        label: name used in the progress messages.
        enable_flag: name of the boolean entry in ``params`` enabling this
            condition (e.g. "run_baseline").
        do_offline_BC: value written to ``params["do_offline_BC"]`` before running.
    """
    # Skip if loaded from a previous run, or disabled. The membership test comes
    # first so a missing enable flag is not looked up when the key is present.
    if key in multi_metrics or not params[enable_flag]:
        print(
            f"%%%%% EXP: {label} already in metrics, or {enable_flag} is False. skipping %%%%%"
        )
        return
    print(f"%%%%% EXP: {label} %%%%%")
    params["do_offline_BC"] = do_offline_BC
    metrics_per_seed, _, _ = run_experiment_multiprocessing(params)
    multi_metrics[key] = pad_metrics(metrics_per_seed, params)
    save_metrics(multi_metrics, metrics_path, key, params)


def _save_params(params, run_dir):
    """Dump ``params`` as YAML to ``run_dir/params.yaml`` and ``configs/exps/<run_id>.yaml``."""
    # Copy next to the results (useful for notebook-style runs).
    params_path = os.path.join(run_dir, "params.yaml")
    with open(params_path, "w", encoding="utf-8") as f:
        yaml.dump(params, f, default_flow_style=False)
    # Reusable CLI config named after the run directory. basename/normpath
    # instead of split("/") so OS-native separators and a trailing separator
    # still yield the directory name.
    run_id = os.path.basename(os.path.normpath(run_dir))
    config_path = os.path.join("configs", "exps", f"{run_id}.yaml")
    os.makedirs(os.path.dirname(config_path), exist_ok=True)
    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(params, f, default_flow_style=False)


def main():
    """Run the purely-online baseline and BRIDGE experiments, plot results, save params.

    Parameters come from the in-file defaults below when the script runs without
    CLI arguments (e.g. notebook-style in an IDE); otherwise ``parse_args()``
    supplies them (user args or a ``--config`` file). ``dirs_and_loads`` may
    replace ``params`` and pre-populate ``multi_metrics`` when an existing run
    is loaded; conditions already present in the metrics are skipped.

    Returns:
        0 when the run finishes.
    """
    params = {
        # environment and broad experiment params
        "env": "StarMDP_with_random_flinging",
        "N_experiments": 2,  # seeds. 3-30
        "N_iterations": 2,  # online iterations per seed. 2-15
        "episode_length": 8,  # 8 for StarMDP, 10 for Gridworld
        "env_move_prob": 0.7,  # 0.7 for StarMDP, 0.8 for Gridworld
        "phi_name": "id_short",  # options: state_counts, id_short, id_long, final_state
        "do_offline_BC": True,
        "N_offline_trajs": 2,  # starmdp: 2, gridworld: 10 (more and BC solves it)
        # offline learning
        "delta_offline": 0.05,
        "N_sampled_initial_policies": 100,  # for offline confset construction: (noise-matrices, or random generation in online_learning if no BC)
        "which_confset_construction_method": "noise-matrices",  # noise-matrices, rejection-sampling
        "which_hellinger_calc": "approx",  # options: exact (bhattacharyya), approx (local-avg)
        "n_transition_model_epochs_offline": 5,
        "offlineradius_formula": "hardcode_radius",  # options: full, ignore_bracket, only_alpha (formerly ignore_beta_in_confset_radius), hardcode_radius_scaled, hardcode_radius (formerly via providing float value to override_offlineradius)
        "offlineradius_override_value": 1,
        "replace_mle_with_optimal_policy_in_offline_confset": True,  # in offline confset, replace pi_MLE with pi_true
        # online learning
        "N_rollouts": 10,  # how many traj's sampled & added from (pi1, pi2) per iteration
        "delta_online": 0.05,
        "W": 1,  # weight bound
        "w_MLE_epochs": 10,  # number of epochs for w_MLE training
        "w_initialization": "uniform",  # options: uniform, random (uniform seems more stable)
        "w_sigmoid_slope": 10,
        "xi_formula": "smaller_start",  # full, smaller_start (both have same scaling in N, but smaller_start starts with smaller value)
        "n_transition_model_epochs_online": 5,  # number of epochs for transition model training
        "online_confset_recalc_phi": False,  # whether to use precomputed values or not (True means recomputing, slower)
        "online_confset_bonus_multiplier": 0.01,  # 0.01 for starMDP, 0.008 for gridworld. leave "1" for 'no multiplier'. (formerly multiply_bonus_inside_online_confset)
        "use_true_T_in_online": False,
        "gamma_t_hardcoded_value": 0.2,  # 0.2 for starMDP, 0.15 for gridworld. (formerly override_gamma_t)
        "baseline_search_space": "all_policies",  # options: all_policies, random_sample
        # verbosity
        "verbose": [],  # list, either [] or any combination of 'full', 'loop-summary', 'radius-calc', 'offline-confset', 'online-confset', 'warnings', 'losses'
        "run_baseline": True,
        "run_bridge": True,
        # saving
        "save_results": True,
        "run_ID": None,  # options: None (creates unique 3-digit ID), or string. If string, checks if dir exists -- if yes, loads & does what's specified in 'loaded_run_purpose', if no, runs new experiment.
        "loaded_run_behaviour": None,  # options: None (defaults to continue), "continue" (load metrics, sim what's missing, re-plot), "redo" (load params, re-sim, re-plot), "overwrite" (don't load anything, write to dir with current params)
        # plotting
        "which_plot_subopt": "cumulative_regret",  # or "suboptimality_percent" or "regret"
    }

    print("Running preference-based RL experiment with refactored codebase...")

    # With CLI arguments the dict above is discarded and parse_args() supplies
    # all params; without them (e.g. notebook-style in an IDE) it is used as-is.
    if len(sys.argv) > 1:
        params = parse_args()

    print_config(params)

    (
        run_dir,
        params,
        multi_metrics,
        fig_paths_subopt,
        fig_path_pi_set_sizes,
        _,  # fig_path_mujoco (unused here)
        metrics_path,
    ) = dirs_and_loads(params)

    # Experiment runs: each condition is skipped if already in (loaded) metrics
    # or disabled by its flag; otherwise it is run and its metrics saved.
    _run_condition(
        params,
        multi_metrics,
        metrics_path,
        key="purely_online",
        label="PURELY ONLINE",
        enable_flag="run_baseline",
        do_offline_BC=False,
    )
    _run_condition(
        params,
        multi_metrics,
        metrics_path,
        key="bridge",
        label="BRIDGE",
        enable_flag="run_bridge",
        do_offline_BC=True,
    )

    # MLE policy performance as a benchmark. Work on a copy so the params can be
    # tweaked here (e.g. a different "N_offline_trajs" for pi_MLE) without
    # affecting the saved params.
    params_for_mle_policy = copy.deepcopy(params)
    mle_policy_avg_reward, opt_policy_avg_reward = get_mle_policy_avg_reward(
        params_for_mle_policy, N_seeds=10
    )

    # Figure(s): suboptimalities, one per requested plot type.
    plot_types = (
        [params["which_plot_subopt"]]
        if isinstance(params["which_plot_subopt"], str)
        else params["which_plot_subopt"]
    )
    # Guard against a single path string: zip() would otherwise pair plot types
    # with individual characters of the path.
    if isinstance(fig_paths_subopt, str):
        fig_paths_subopt = [fig_paths_subopt]
    # "plot_slim" is not among the in-file defaults, so fall back to False.
    figsize = (6, 3) if params.get("plot_slim", False) else None
    for plot_type, fig_path_subopt in zip(plot_types, fig_paths_subopt):
        plt.figure()
        plot_suboptimalities_multimetrics(
            multi_metrics,
            fig_path_subopt,
            paper_style=True,
            mle_policy_avg_reward=mle_policy_avg_reward,
            opt_policy_avg_reward=opt_policy_avg_reward,
            which_plot=plot_type,
            figsize=figsize,
        )
        plt.close()

    # Figure: pi set sizes.
    plt.figure()
    plot_pi_set_sizes_multimetrics(
        multi_metrics,
        params,
        fig_path_pi_set_sizes,
        paper_style=True,
        figsize=figsize,
    )
    plt.close()

    print("about to save params")
    if params["save_results"]:
        _save_params(params, run_dir)

    print("run ended successfully")
    return 0


if __name__ == "__main__":
    # perf_counter is monotonic, so the reported duration is unaffected by
    # system clock adjustments (time.time() is wall-clock and can jump).
    # main()'s return value is intentionally not passed to sys.exit: the file is
    # also run cell-by-cell in interactive sessions, where SystemExit is noisy.
    start_time = time.perf_counter()
    main()
    elapsed = time.perf_counter() - start_time
    print(f"Time taken: {elapsed:.2f} seconds")
# %%
