"""
Run and plot the radius ablation (baseline plus BRIDGE at several radii) on a single machine.

Alternatively, do the runs with separate main_mujoco calls and use ablation_radii.py to collate and plot them (see the docstring of ablation_radii.py for how to do this).
"""

# %%
import sys
import os
import yaml

from experiment_runner_mujoco import (
    run_experiment_mp,
    preprocess_params_dict,
)
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt

import time
from utils.misc_helpers import dirs_and_loads, save_metrics
from utils.mujoco_helpers import (
    parse_args_mujoco,
    Tee,
    pad_metrics_mujoco,
)
from datetime import datetime
import gc
from ablation_radii import plot_radius_ablation


def main():
    """Run the baseline and BRIDGE radius-ablation experiments, then plot them.

    When CLI arguments are given, parameters come from ``parse_args_mujoco``.
    Otherwise an empty parameter dict is handed to ``dirs_and_loads``.

    An experiment whose key ("baseline" or "bridge_<radius>") is already in
    ``multi_metrics`` (as returned by ``dirs_and_loads``) is skipped. Presumably
    this lets an interrupted ablation be resumed from saved metrics -- TODO
    confirm that ``dirs_and_loads`` reloads them.

    Side effects:
        * stdout is mirrored (via ``Tee``) to ``log_<pid>_<timestamp>.txt``
          inside ``run_dir`` for the duration of the call, then restored;
        * ``os.environ["EXPERIMENT_LOG_FILE"]`` is set to that log path;
        * metrics are saved to ``metrics_path`` after each finished experiment;
        * the ablation plot is written as ``plot.pdf`` and ``plot.png`` under
          ``exps/ablations/<run_dir without its first path component>``.

    Raises:
        Any exception raised by the experiments or the plotting is re-raised
        after its traceback has been written to the run log.
    """
    import traceback  # only needed to record failures in the run log

    # TODO: support loading a base config file; currently only CLI parsing exists.
    main_time = time.time()
    params_dict = {}
    if len(sys.argv) > 1:
        params_dict = parse_args_mujoco(args=None, base_config=params_dict)

    (
        run_dir,
        params_dict,
        multi_metrics,
        _fig_paths_subopt,
        _fig_path_pi_set_sizes,
        _fig_path_mujoco,
        metrics_path,
    ) = dirs_and_loads(params_dict)
    params_dict["run_dir"] = run_dir

    # ---- logging: mirror stdout into a per-process log file inside run_dir ----
    pid = os.getpid()
    log_file_path = os.path.join(run_dir, f"log_{pid}_{datetime.now().strftime('%y%m%d-%H%M')}.txt")
    cli_call = " ".join(sys.argv)
    tee_output = Tee(log_file_path)
    original_stdout = sys.stdout
    sys.stdout = tee_output

    # Everything after the redirect sits inside try/finally so stdout is
    # always restored and the Tee always closed.
    try:
        os.environ["EXPERIMENT_LOG_FILE"] = log_file_path

        print(f"cli call:\n{cli_call}")
        print(f"logging to: {log_file_path}")
        print(f"Starting experiment at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

        params, params_dict = preprocess_params_dict(params_dict)

        # ==== EXPERIMENT RUNS ====
        # baseline
        if "baseline" not in multi_metrics and params.run_baseline:
            print("Running baseline experiment")
            params.baseline_or_bridge = "baseline"
            metrics_per_seed_baseline, avg_expert_reward, avg_bc_reward = run_experiment_mp(params)
            multi_metrics["avg_expert_reward"] = avg_expert_reward
            multi_metrics["avg_bc_reward"] = avg_bc_reward
            multi_metrics["baseline"] = pad_metrics_mujoco(metrics_per_seed_baseline, params_dict)
            save_metrics(multi_metrics, metrics_path, "baseline", params_dict)
            plt.close("all")
            gc.collect()

        # BRIDGE runs, one per radius; keys look like "bridge_0_3", "bridge_10".
        radii = [0.3, 0.65, 1.5, 3, 5, 10]
        for radius in radii:
            exp_str = "bridge_" + str(radius).replace(".", "_")
            if exp_str not in multi_metrics and params.run_bridge:
                print(f"Running BRIDGE experiment {exp_str}")
                params.radius = radius
                params.baseline_or_bridge = "bridge"
                metrics_per_seed_bridge, _, _ = run_experiment_mp(params)
                multi_metrics[exp_str] = pad_metrics_mujoco(metrics_per_seed_bridge, params_dict)
                save_metrics(multi_metrics, metrics_path, exp_str, params_dict)
                plt.close("all")
                gc.collect()

        # ==== PLOTTING ====
        # Derive the axis limit from the radii actually run instead of a
        # separately hard-coded constant.
        plot_radius_ablation(multi_metrics, max_radius=max(radii), only_bridge=False)

        # Mirror run_dir under exps/ablations/, dropping its first path
        # component (e.g. "a/b/c" -> "exps/ablations/b/c"). split() with
        # maxsplit=1 matches the old str.index() slicing whenever a "/" is
        # present, but no longer raises ValueError when there is none;
        # os.sep is normalized so Windows-style paths are handled too.
        relative_run_dir = run_dir.replace(os.sep, "/").split("/", 1)[-1]
        save_dir = "exps/ablations/" + relative_run_dir

        os.makedirs(save_dir, exist_ok=True)
        pdf_path = os.path.join(save_dir, "plot.pdf")
        png_path = os.path.join(save_dir, "plot.png")
        plt.savefig(pdf_path, bbox_inches="tight")
        plt.savefig(png_path, bbox_inches="tight", dpi=300)
        plt.close("all")
        print("Plots saved to:")
        print(f"  {pdf_path}")
        print(f"  {png_path}")
        print(f"Time taken: {time.time() - main_time:.2f} seconds")

    except Exception:
        # The interpreter prints an uncaught traceback only after `finally`
        # has restored stdout, so it would never reach the log file; write it
        # through the Tee here, then propagate unchanged.
        traceback.print_exc(file=sys.stdout)
        raise

    finally:
        sys.stdout = original_stdout
        tee_output.close()
        print(f"log file saved to {log_file_path}")


if __name__ == "__main__":
    # Script entry point: run the full ablation only when executed directly,
    # never as a side effect of importing this module.
    main()
