#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict, List, Optional

import matplotlib.pyplot as plt


def load_json(path: Path) -> Optional[Dict]:
    if not path.exists():
        print(f"[skip] missing {path}")
        return None
    try:
        return json.loads(path.read_text())
    except json.JSONDecodeError as exc:
        print(f"[skip] failed to parse {path}: {exc}")
        return None


def save_fig(fig: plt.Figure, out_path: Path) -> None:
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)
    print(f"[saved] {out_path}")


def plot_d3qn(real_path: Path, cf_path: Path, out_dir: Path) -> None:
    real = load_json(real_path) or {}
    cf = load_json(cf_path) or {}

    fig, ax = plt.subplots(1, 2, figsize=(12, 4))
    for name, data, color in [
        ("real", real, "tab:blue"),
        ("cf", cf, "tab:orange"),
    ]:
        if not data:
            continue
        ax[0].plot(data.get("total_losses", []), label=f"{name} total", color=color)
        ax[0].plot(
            data.get("td_losses", []),
            label=f"{name} td",
            color=color,
            alpha=0.5,
            linestyle="--",
        )
        evals = data.get("eval_returns") or []
        if evals:
            xs: List[float] = []
            ys: List[float] = []
            for i, e in enumerate(evals):
                if isinstance(e, dict):
                    xs.append(float(e.get("epoch", i + 1)))
                    ys.append(float(e.get("mean", 0.0)))
                elif isinstance(e, (list, tuple)) and len(e) >= 1:
                    try:
                        xs.append(float(i + 1))
                        ys.append(float(e[0]))
                    except (TypeError, ValueError):
                        continue
            if xs and ys:
                ax[1].plot(xs, ys, label=name, marker="o", color=color)
    ax[0].set_title("D3QN losses")
    ax[0].set_xlabel("epoch")
    ax[0].legend()
    ax[1].set_title("D3QN eval mean return")
    ax[1].set_xlabel("epoch")
    ax[1].legend()
    save_fig(fig, out_dir / "d3qn_overview.png")


def plot_bar_comparison(
    results: Dict[str, Path], out_dir: Path, eval_overrides: Dict[str, Path]
) -> None:
    labels: List[str] = []
    vals: List[float] = []
    for name, path in results.items():
        override_path = eval_overrides.get(name)
        if override_path and override_path.exists():
            override = load_json(override_path)
            if override and "mean_return" in override:
                labels.append(name)
                vals.append(float(override["mean_return"]))
                continue

        data = load_json(path)
        if not data:
            continue
        evals = data.get("eval_returns") or []
        if not evals:
            continue
        last = evals[-1]
        try:
            if isinstance(last, dict):
                vals.append(float(last.get("mean", 0.0)))
            else:
                vals.append(float(last[0]))
        except (TypeError, ValueError, IndexError):
            continue
        labels.append(name)
    if not labels:
        print("[skip] no eval data for bar chart")
        return
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.bar(labels, vals, color="tab:green")
    ax.set_ylabel("Mean return")
    ax.set_title("Final eval comparison")
    save_fig(fig, out_dir / "comparison_bar.png")


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Generate plots for CTRL CartPole experiments.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--rainbow-metrics",
        type=Path,
        default=Path("results/cartpole/rainbow_reboot/20251208-001041/metrics.json"),
        help="Path to Rainbow metrics.json.",
    )
    parser.add_argument(
        "--d3qn-real",
        type=Path,
        default=Path("results/cartpole/d3qn_real/metrics.json"),
        help="Path to real-only D3QN metrics.json.",
    )
    parser.add_argument(
        "--d3qn-cf",
        type=Path,
        default=Path("results/cartpole/d3qn_cf/metrics.json"),
        help="Path to counterfactual D3QN metrics.json.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("results/plots"),
        help="Directory to write png plots.",
    )
    parser.add_argument(
        "--eval-override",
        action="append",
        help="Override eval path for the bar chart (NAME=PATH).",
    )
    return parser


def main() -> None:
    args = build_parser().parse_args()
    out_dir = args.output_dir

    plot_d3qn(args.d3qn_real, args.d3qn_cf, out_dir)
    overrides: Dict[str, Path] = {
        "Rainbow": Path("results/eval/rainbow/20251208-003300/rainbow.json"),
        "D3QN real": Path("results/eval/d3qn/20251208-003243/d3qn.json"),
        "D3QN CF": Path("results/eval/d3qn/20251208-003251/d3qn.json"),
    }
    if args.eval_override:
        for entry in args.eval_override:
            if "=" not in entry:
                print(f"[skip] invalid override '{entry}' (expected NAME=PATH)")
                continue
            name, path_str = entry.split("=", 1)
            overrides[name.strip()] = Path(path_str.strip())
    plot_bar_comparison(
        {
            "Rainbow": args.rainbow_metrics,
            "D3QN real": args.d3qn_real,
            "D3QN CF": args.d3qn_cf,
        },
        out_dir,
        overrides,
    )


if __name__ == "__main__":
    main()
