"""
Main entry point for preference-based reinforcement learning experiments.

This is the single, clean entry point that replaces the multiple main files.
All experiment logic is now organized in the core module.
"""

# %%
import sys
import os
import yaml

from experiment_runner_mujoco import (
    run_experiment_mp,
    run_experiment_sp,
    preprocess_params_dict,
)
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt

import time
from plotting.plotting_common import (
    plot_suboptimalities_multimetrics,
    plot_pi_set_sizes_multimetrics,
    plot_mujoco_multimetrics,
)
from utils.misc_helpers import dirs_and_loads, save_metrics
from utils.mujoco_helpers import (
    parse_args_mujoco,
    Tee,
    pad_metrics_mujoco,
)
from datetime import datetime
import gc


def _run_variant(variant, header, params, params_dict, multi_metrics, metrics_path):
    """Simulate one algorithm variant and persist its metrics.

    Args:
        variant: Key under which the results are stored in ``multi_metrics``
            and the value written to ``params.baseline_or_bridge``
            ("bridge" or "baseline").
        header: Banner line printed before the run starts.
        params: Preprocessed parameter object passed to ``run_experiment_mp``.
        params_dict: Dict form of the parameters, forwarded to
            ``pad_metrics_mujoco`` and ``save_metrics``.
        multi_metrics: Dict of metrics; updated in place with the padded
            per-seed metrics and the average expert / BC rewards.
        metrics_path: Destination handed to ``save_metrics``.
    """
    print(header)
    params.baseline_or_bridge = variant
    metrics_per_seed, avg_expert_reward, avg_bc_reward = run_experiment_mp(params)
    multi_metrics["avg_expert_reward"] = avg_expert_reward
    multi_metrics["avg_bc_reward"] = avg_bc_reward
    multi_metrics[variant] = pad_metrics_mujoco(metrics_per_seed, params_dict)
    save_metrics(multi_metrics, metrics_path, variant, params_dict)
    # Release figures and other garbage accumulated during the run before the
    # next variant / the plotting stage starts.
    plt.close("all")
    gc.collect()


def _save_params(params_dict, run_dir, env_id, run_id):
    """Write ``params_dict`` as YAML to the run directory and to the CLI config tree.

    Two copies are written:
      * ``<run_dir>/params.yaml`` (for notebook-style runs)
      * ``configs/exps/<env_id>/<run_id>.yaml`` (for re-running from the CLI)
    """
    params_path = os.path.join(run_dir, "params.yaml")
    with open(params_path, "w", encoding="utf-8") as f:
        yaml.dump(params_dict, f, default_flow_style=False)

    config_path = os.path.join("configs", "exps", f"{env_id}", f"{run_id}.yaml")
    os.makedirs(os.path.dirname(config_path), exist_ok=True)
    with open(config_path, "w", encoding="utf-8") as f:
        yaml.dump(params_dict, f, default_flow_style=False)


def main():
    """Run (or resume) an experiment, plot its metrics and save its parameters.

    Steps, in order:
      1. Build the default parameter dict; if CLI arguments are present they
         are merged in through ``parse_args_mujoco``.
      2. Resolve the run directory, figure paths and any previously stored
         metrics via ``dirs_and_loads``.
      3. Mirror stdout into a log file inside the run directory (``Tee``).
      4. Run the "bridge" and/or "baseline" variants whose metrics are not
         already present in the loaded metrics.
      5. Produce the suboptimality, policy-set-size and MuJoCo plots.
      6. Optionally dump the parameters to YAML.

    stdout is always restored and the log file closed, even if a step raises.
    """
    params_dict = {
        # environment and broad experiment params
        "env_id": "Reacher-v5",  # "HalfCheetah-v5", "Reacher-v5"
        "seed": 42,  # initial seed
        "N_experiments": 5,  # seeds. 3-30
        "N_iterations": 100,  # online iterations per seed. 2-15
        "episode_length": 50,  # halfcheetah: 100
        "embedding_name": "avg_sa",  # options: avg_sa , avg_s, last_s, actionenergy, psm, reacher_perf, halfcheetah_xpos
        "N_offline_trajs": 50,  # halfcheetah: 10 (more and BC solves it)
        "fresh_offline_trajs": False,  # if True, will generate new offline trajs even if they already exist
        "initial_pos_noise": None,  # only HalfCheetah. default is 0.1
        # offline learning
        "N_confset_size": 100,
        "confset_noise": None,  # noise added to BC policy to generate confset. if not provided, filled in w/ environment defaults: Reacher: 0.05, HalfCheetah: ???
        "n_bc_epochs": 100,
        "bc_loss": "log-loss",  # "mse" , "log-loss" (tabular BRIDGE)
        "bc_print_evals": False,
        "radius": None,  # if unspecified, uses hardcoded defaults per embedding. for filtering offline confset: L2(embed(π_BC) - embed(π_candidate)) < radius
        # online learning
        "N_rollouts": 10,  # how many trajectories to sample & annotate per online loop
        "filter_pi_t_yesno": False,
        "filter_pi_t_gamma": 1,
        "gamma_debug_mode": False,
        "W": 1,
        "w_trainfunc": "mle",  # "rebuttals" (no batching), "mle" (tabular BRIDGE)
        "w_regularization": None,  # None , "l2" (tabular BRIDGE)
        "w_epochs": 100,  # 100 , 10 (tabular BRIDGE)
        "w_initialization": "uniform",  # "zeros" , "uniform" (tabular BRIDGE)
        "w_sigmoid_slope": 1,  # 1 , 10 (tabular BRIDGE)
        "project_w": False,
        "retrain_w_from_scratch": False,
        "which_policy_selection": "random",  # "ucb", "random" , "max_uncertainty"
        "ucb_beta": 1,
        "V_init": "small",  # "small" , "bounds" (BRIDGE)
        "n_embedding_samples": 200,
        # policy model params
        "hidden_dim": None,  # policy hidden dim. for now, this isn't used, instead we use the SB3 expert's value hardcoded via RLZoo hparams in training_cfg. This is reacher: 64, halfcheetah: 256 (all 2-layers) (SB3 would default to 64 x2).
        # verbosity
        "verbose": [],  # list, either [] or any combination of 'full', 'loop-summary', 'radius-calc', 'offline-confset', 'online-confset', 'warnings', 'losses'
        "run_baseline": True,
        "run_bridge": True,
        "save_results": True,
        "run_ID": None,  # options: None (creates unique 3-digit ID), or string. If string, checks if dir exists -- if yes, loads & does what's specified in 'loaded_run_behaviour', if no, runs new experiment.
        "loaded_run_behaviour": None,  # options: None (defaults to continue), "continue" (load metrics, sim what's missing, re-plot), "redo" (load params, re-sim, re-plot), "overwrite" (don't load anything, write to dir with current params)
        "which_plot_subopt": "cumulative_regret",  # "suboptimality_percent" or "regret" or "cumulative_regret"
        "baseline_or_bridge": None,  # "baseline", "bridge" for single runs
        "plot_scores": False,
        "exclude_outliers": False,  # exclude runs based on cumulative_regret_T outliers: "worst_{bcexpertdist, cumregret}", "95conf_{bcexpertdist, cumregret}" (exclude runs outside 95% conf estimate (mean+1.96*std))
    }

    main_time = time.time()
    # Only hand over to the CLI parser when arguments were actually passed, so
    # that running the file without arguments uses the defaults above.
    if len(sys.argv) > 1:
        params_dict = parse_args_mujoco(args=None, base_config=params_dict)

    (
        run_dir,
        params_dict,
        multi_metrics,
        fig_paths_subopt,
        fig_path_pi_set_sizes,
        fig_path_mujoco,
        metrics_path,
    ) = dirs_and_loads(params_dict)
    params_dict["run_dir"] = run_dir
    # Derived up front so the final status message works even when
    # save_results is False. normpath drops a trailing separator, which would
    # otherwise yield an empty ID.
    run_id = os.path.basename(os.path.normpath(run_dir))

    # logging: mirror everything printed from here on into a per-process log file
    pid = os.getpid()
    log_file_path = os.path.join(run_dir, f"log_{pid}_{datetime.now().strftime('%y%m%d-%H%M')}.txt")
    cli_call = " ".join(sys.argv)
    tee_output = Tee(log_file_path)
    original_stdout = sys.stdout
    sys.stdout = tee_output

    # The try starts immediately after the redirect so stdout is restored no
    # matter where a failure happens.
    try:
        os.environ["EXPERIMENT_LOG_FILE"] = log_file_path

        print(f"cli call:\n{cli_call}")
        print(f"logging to: {log_file_path}")
        print(f"Starting experiment at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

        params, params_dict = preprocess_params_dict(params_dict)

        # ==== EXPERIMENT RUNS ====
        # A variant is skipped when its metrics were already loaded from disk.
        if "bridge" not in multi_metrics and params.run_bridge:
            _run_variant(
                "bridge", "=== BRIDGE ===", params, params_dict, multi_metrics, metrics_path
            )

        if "baseline" not in multi_metrics and params.run_baseline:
            _run_variant(
                "baseline", "\n=== BASELINE ===", params, params_dict, multi_metrics, metrics_path
            )

        # ==== PLOTTING ====
        # which_plot_subopt may be a single plot name or a collection of them.
        plot_types = (
            [params.which_plot_subopt]
            if isinstance(params.which_plot_subopt, str)
            else params.which_plot_subopt
        )
        # NOTE(review): the figure size is hard-coded to (6, 3). A
        # `params.plot_slim` switch was read here before but its result was
        # never passed on -- confirm whether it should drive `figsize`.
        for plot_type, fig_path_subopt in zip(plot_types, fig_paths_subopt):
            plot_suboptimalities_multimetrics(
                multi_metrics,
                fig_path_subopt,
                paper_style=True,
                mle_policy_avg_reward=multi_metrics["avg_bc_reward"],
                opt_policy_avg_reward=multi_metrics["avg_expert_reward"],
                which_plot=plot_type,  # "suboptimality_percent" or "regret" or "cumulative_regret",
                exclude_outliers=params.exclude_outliers,
                figsize=(6, 3),
            )
            plt.close()
        plt.close("all")

        plot_pi_set_sizes_multimetrics(
            multi_metrics,
            params_dict,
            fig_path_pi_set_sizes,
            paper_style=True,
            exclude_outliers=params.exclude_outliers,
            figsize=(6, 3),
        )
        plt.close()

        mujoco_env_ids = (
            "Reacher-v5",
            "HalfCheetah-v5",
            "Walker2d-v5",
            "Hopper-v5",
            "Ant-v5",
            "Humanoid-v5",
        )
        if params.env_id in mujoco_env_ids:
            plt.figure()
            plot_mujoco_multimetrics(
                multi_metrics,
                fig_path_mujoco,
                paper_style=True,
                exclude_outliers=params.exclude_outliers,
            )
            plt.close()

        # ==== SAVING PARAMS ====
        print("about to save params")
        if params.save_results:
            _save_params(params_dict, run_dir, params.env_id, run_id)

        print(
            f"run '{run_id}' ended successfully at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
        )
        print(f"Time taken: {time.time() - main_time:.2f} seconds")

    finally:
        # Restore the real stdout before closing the tee so the last message
        # goes to the console only.
        sys.stdout = original_stdout
        tee_output.close()
        print(f"log file saved to {log_file_path}")

# Run the experiment only when this file is executed as a script; importing
# the module does not call main().
if __name__ == "__main__":
    main()
