from __future__ import annotations

import argparse
import pickle
from pathlib import Path
from typing import Any, Dict

from plot import plot_regret, plot_times


def _load_summary(path: Path) -> Dict[str, Any]:
    with path.open("rb") as fh:
        return pickle.load(fh)


def _maybe_plot(summary: Dict[str, Any], stem: str, *, plot_running_time: bool) -> None:
    avg_regret = summary.get("avg_regret")
    std_regret = summary.get("std_regret", {})
    if not avg_regret:
        return
    horizon = max((len(values) for values in avg_regret.values()), default=0)
    plot_regret(avg_regret, std_regret, horizon, exp_name=stem)
    if plot_running_time:
        timings = summary.get("timings")
        if timings:
            plot_times(timings, exp_name=stem)


def _print_summary(summary: Dict[str, Any], label: str) -> None:
    avg_regret = summary.get("avg_regret", {})
    if not avg_regret:
        print(f"{label}: no regret data found.")
        return
    final_regrets = {
        alg: float(values[-1])
        for alg, values in avg_regret.items()
        if len(values) > 0
    }
    print(f"{label}: final regrets -> {final_regrets}")


def main(argv: list[str] | None = None) -> None:
    """Parse command-line options, then summarise and optionally plot each result file.

    For every result path the pickled summary is loaded, its final regret
    values are printed, and (unless ``--no-plot`` is given) the regret curves
    are plotted; ``--plot-times`` additionally requests the timing plot.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read the process arguments, which is the original behaviour.
    """
    parser = argparse.ArgumentParser(
        description="Visualise results stored in run_experiments.py pickle outputs."
    )
    parser.add_argument("results", nargs="+", type=Path, help="Result .pkl files.")
    parser.add_argument(
        "--no-plot",
        action="store_true",
        help="Skip plotting, only print final regret values.",
    )
    parser.add_argument(
        "--plot-times",
        action="store_true",
        help="Plot the average running times alongside regret.",
    )
    parser.add_argument(
        "--label",
        help=(
            "Override plot label stem (defaults to individual file stem). "
            "With several result files the file stem is appended to keep "
            "labels distinct."
        ),
    )
    args = parser.parse_args(argv)

    # A single --label shared by several files would make every printed line
    # (and every exp_name handed to the plot functions) identical, so the
    # file stem is appended in that case. Single-file runs are unaffected.
    needs_suffix = bool(args.label) and len(args.results) > 1

    for path in args.results:
        summary = _load_summary(path)
        if needs_suffix:
            stem = f"{args.label}_{path.stem}"
        else:
            stem = args.label or path.stem
        _print_summary(summary, stem)
        if not args.no_plot:
            _maybe_plot(summary, stem, plot_running_time=args.plot_times)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
