"""
On 1 machine: run and plot N_offline_trajs ablation

Alternatively: do runs with separate main_mujoco calls and use ablation_N_offline.py to collate & plot (see doc of ablation_N_offline.py on HowTo)
"""

# %%
import sys
import os
import yaml

from experiment_runner_mujoco import (
    run_experiment_mp,
    preprocess_params_dict,
)
import matplotlib

# Select the Agg backend before pyplot is imported, so figures are rendered
# straight to files and no display is needed.
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import time
from utils.misc_helpers import dirs_and_loads, save_metrics
from utils.mujoco_helpers import (
    parse_args_mujoco,
    Tee,
    pad_metrics_mujoco,
)
from datetime import datetime
import gc
from ablation_N_offline import plot_N_offline_ablation


def main():
    """Run the bridge N_offline_trajs ablation sweep, then plot and save it.

    Steps, in order:

    1. Build ``params_dict`` (from the CLI via ``parse_args_mujoco`` when any
       argument is given) and pass it to ``dirs_and_loads``, which returns the
       run directory, the updated params dict, the ``multi_metrics`` mapping
       and several output paths.
    2. Replace ``sys.stdout`` with a ``Tee`` opened on a per-process log file
       inside the run directory and export that path as the
       ``EXPERIMENT_LOG_FILE`` environment variable.
    3. For every ``(noffline_trajs, radius)`` pair of the sweep, run
       ``run_experiment_mp`` unless its key is already in ``multi_metrics``,
       and save the metrics after each run.
    4. Call ``plot_N_offline_ablation`` and save the current figure as
       ``plot.pdf`` and ``plot.png`` under ``exps/ablations/``.

    ``sys.stdout`` is always restored and the ``Tee`` closed, even when a step
    raises; the exception itself still propagates to the caller.
    """
    params_dict = {}
    # TODO: add loading and CLI parsing
    main_time = time.time()
    if len(sys.argv) > 1:
        params_dict = parse_args_mujoco(args=None, base_config=params_dict)

    (
        run_dir,
        params_dict,
        multi_metrics,
        fig_paths_subopt,
        fig_path_pi_set_sizes,
        fig_path_mujoco,
        metrics_path,
    ) = dirs_and_loads(params_dict)
    params_dict["run_dir"] = run_dir

    # logging
    pid = os.getpid()
    log_file_path = os.path.join(run_dir, f"log_{pid}_{datetime.now().strftime('%y%m%d-%H%M')}.txt")
    cli_call = " ".join(sys.argv)
    tee_output = Tee(log_file_path)
    original_stdout = sys.stdout
    sys.stdout = tee_output

    # Everything after the redirection lives inside the try, so stdout is
    # restored and the Tee closed no matter where a failure happens.
    try:
        os.environ["EXPERIMENT_LOG_FILE"] = log_file_path

        print(f"cli call:\n{cli_call}")
        print(f"logging to: {log_file_path}")
        print(f"Starting experiment at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

        params, params_dict = preprocess_params_dict(params_dict)

        # ==== EXPERIMENT RUNS ====
        # do runs on bridge; noffline_trajs={10, 15, 20, 25, 30}, radius={1.1, 0.9, 0.65, 0.45, 0.4}
        # The two lists are paired element-wise: the i-th N runs with the i-th radius.
        n_offline_values = [10, 15, 20, 25, 30]
        radii = [1.1, 0.9, 0.65, 0.45, 0.4]
        for n_offline, radius in zip(n_offline_values, radii):
            # Metrics key, e.g. "bridge_N10_r1_1" ("." in the radius becomes "_").
            radius_str = str(radius).replace(".", "_")
            exp_str = f"bridge_N{n_offline}_r{radius_str}"
            if exp_str in multi_metrics:
                # Metrics for this configuration are already present; do not re-run.
                continue

            params.noffline_trajs = n_offline
            params.radius = radius
            params.baseline_or_bridge = "bridge"
            metrics_per_seed_bridge, _, _ = run_experiment_mp(params)
            multi_metrics[exp_str] = pad_metrics_mujoco(metrics_per_seed_bridge, params_dict)
            # Save after every configuration so finished runs are written out
            # before the next one starts.
            save_metrics(multi_metrics, metrics_path, exp_str, params_dict)
            # Close all open figures and force a garbage-collection pass
            # before the next configuration.
            plt.close("all")
            gc.collect()

        ## TODO: plotting func
        plot_N_offline_ablation(multi_metrics, max_N_offline=30)

        # Mirror the run directory under exps/ablations/, dropping its first
        # path component. A run_dir without any "/" is used as-is instead of
        # raising ValueError after all the runs have already completed.
        _, sep, run_subpath = run_dir.partition("/")
        if not sep:
            run_subpath = run_dir
        save_dir = "exps/ablations/" + run_subpath

        os.makedirs(save_dir, exist_ok=True)
        pdf_path = os.path.join(save_dir, "plot.pdf")
        png_path = os.path.join(save_dir, "plot.png")
        # savefig acts on the current pyplot figure, i.e. whatever
        # plot_N_offline_ablation left active.
        plt.savefig(pdf_path, bbox_inches="tight")
        plt.savefig(png_path, bbox_inches="tight", dpi=300)
        print("Plots saved to:")
        print(f"  {pdf_path}")
        print(f"  {png_path}")
        print(f"Time taken: {time.time() - main_time:.2f} seconds")

    finally:
        # Restore the real stdout before closing the Tee, so the final message
        # goes to the original stream.
        sys.stdout = original_stdout
        tee_output.close()
        print(f"log file saved to {log_file_path}")
        # sys.exit()

# Run the ablation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
