#!/usr/bin/env python3
"""Plot Tiny/Small/Base search results with mapped accuracy values."""

from __future__ import annotations

import argparse
import json
import os
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt

from .pareto import pareto_front

# Supernet identifiers in the order their panels are drawn (left to right).
SUPERNET_ORDER = ["tiny", "small", "base"]

# Panel titles keyed by supernet identifier.
FRIENDLY_NAMES = {
    "tiny": "Supernet-Tiny",
    "small": "Supernet-Small",
    "base": "Supernet-Base",
}

# Scatter styling keyed by a candidate's "origin" field. Insertion order is
# the order the groups are drawn in each panel.
ORIGIN_STYLE = {
    "llm": dict(color="#1f77b4", label="LLM candidates", marker="o", size=28, alpha=0.85),
    "random": dict(color="#ff7f0e", label="Random samples", marker="^", size=24, alpha=0.75),
}

# Public API: only the figure builder is exported by `from <module> import *`;
# the loaders/helpers and `main` are not listed here.
__all__ = ["generate_supernet_comparison"]


def load_metrics(metrics_json: str) -> Tuple[Dict[str, List[Dict]], Dict[str, Tuple[float, float]]]:
    """Load the precomputed search-space metrics file.

    Args:
        metrics_json: Path to a JSON document that may contain a top-level
            ``"supernets"`` mapping (supernet name -> list of metric records)
            and ``"metadata"`` -> ``"accuracy_ranges"`` (supernet name ->
            ``[min, max]``).

    Returns:
        ``(supernets, accuracy_ranges)``. Each is ``{}`` when its key is
        missing or ``null``. Range values are returned exactly as decoded
        from JSON, so they are typically lists rather than tuples.
    """
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(metrics_json, "r", encoding="utf-8") as handle:
        data = json.load(handle)
    # `or {}` also covers an explicit `null`, which `.get(key, {})` would not.
    metadata = data.get("metadata") or {}
    return data.get("supernets") or {}, metadata.get("accuracy_ranges") or {}


def load_results(path: str) -> Tuple[List[Dict], str, Optional[Tuple[float, float]]]:
    """Load one search run's results file.

    Args:
        path: Path to the run's JSON results document.

    Returns:
        ``(points, y_key, bounds)``:

        * ``points`` is the ``"all"`` list (``[]`` when absent).
        * ``y_key`` names the field holding each point's accuracy score
          (``"acc_proxy_norm"`` when absent).
        * ``bounds`` is ``"accuracy_bounds"`` converted to a tuple, or
          ``None`` when the key is missing/``null``.
    """
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)
    raw_bounds = payload.get("accuracy_bounds")
    bounds = tuple(raw_bounds) if raw_bounds is not None else None
    return payload.get("all", []), payload.get("y_key", "acc_proxy_norm"), bounds


def select_points(points: List[Dict], y_key: str) -> List[Dict]:
    """Return the entries that can be placed on the (latency, accuracy) plane.

    An entry is kept only when both ``entry[y_key]`` and
    ``entry["latency_ms"]`` are present, not ``None``, and not a float NaN.
    The original code NaN-checked only the accuracy value, so a NaN latency
    slipped through to plotting and Pareto computation.

    Args:
        points: Candidate dicts.
        y_key: Key holding each candidate's accuracy score.

    Returns:
        The surviving dicts themselves (not copies), in their original order.
    """

    def _missing(value) -> bool:
        # NaN is the only float that compares unequal to itself.
        return value is None or (isinstance(value, float) and value != value)

    return [
        entry
        for entry in points
        if not _missing(entry.get(y_key)) and not _missing(entry.get("latency_ms"))
    ]


def ensure_accuracy(points: List[Dict], y_key: str, bounds: Optional[Tuple[float, float]]) -> Tuple[List[Dict], str]:
    """Map normalised scores onto an absolute accuracy scale, in place.

    Nothing happens, and ``(points, y_key)`` is returned unchanged, in any of
    these cases:

    * ``y_key`` is already ``"est_acc"``;
    * ``bounds`` is falsy;
    * the upper bound does not exceed the lower one.

    Otherwise every entry whose ``y_key`` value is not ``None`` is given
    ``entry["est_acc"] = low + value * (high - low)``. The dicts are mutated
    rather than copied, and the returned key becomes ``"est_acc"``.

    Returns:
        ``(points, key)``: the same list object that was passed in, plus
        the key under which accuracy should now be read.
    """
    already_mapped = y_key == "est_acc"
    if already_mapped or not bounds:
        return points, y_key

    low, high = bounds
    if high <= low:
        # Degenerate/inverted range: no meaningful mapping exists.
        return points, y_key

    width = high - low
    for record in points:
        score = record.get(y_key)
        if score is not None:
            record["est_acc"] = low + score * width
    return points, "est_acc"


def compute_front(points: List[Dict], y_key: str) -> Tuple[List[float], List[float]]:
    """Extract the Pareto frontier of ``points`` as two parallel coordinate lists.

    Frontier membership is decided by ``pareto_front``. It reads accuracy
    from ``y_key`` and latency from ``"latency_ms"``. The frontier's own
    ordering is kept as returned.

    Returns:
        ``(latencies, accuracies)`` for the frontier entries.
    """
    frontier = pareto_front(points, acc_key=y_key, lat_key="latency_ms")
    latencies = [member["latency_ms"] for member in frontier]
    accuracies = [member[y_key] for member in frontier]
    return latencies, accuracies


def _search_space_xy(
    metrics_points: List[Dict], range_bounds: Optional[Tuple[float, float]]
) -> Tuple[List[float], List[float]]:
    """Return (latencies, accuracies) for the precomputed search-space records.

    A record's ``"accuracy"`` is used when present. Otherwise its
    ``"accuracy_norm"`` (default 0.0) is used: mapped through
    ``range_bounds`` when that range is valid, or unscaled when it is not.
    """
    xs = [record["latency_ms"] for record in metrics_points]
    valid_range = bool(range_bounds) and range_bounds[1] > range_bounds[0]
    ys: List[float] = []
    for record in metrics_points:
        if "accuracy" in record:
            # Checked first so the norm arithmetic is never evaluated needlessly.
            ys.append(record["accuracy"])
            continue
        norm = record.get("accuracy_norm", 0.0)
        if valid_range:
            ys.append(range_bounds[0] + norm * (range_bounds[1] - range_bounds[0]))
        else:
            ys.append(norm)
    return xs, ys


def _collect_legend_entries(axes) -> Tuple[list, List[str]]:
    """Gather legend handles from every panel, keeping the first handle per label."""
    handles: list = []
    labels: List[str] = []
    seen = set()
    for ax in axes:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            if label not in seen:
                seen.add(label)
                handles.append(handle)
                labels.append(label)
    return handles, labels


def _plot_panel(
    ax,
    supernet: str,
    metrics_points: List[Dict],
    accuracy_ranges: Dict[str, Tuple[float, float]],
    results_path: str,
) -> None:
    """Draw one supernet's search space, searched candidates and frontiers on ``ax``."""
    xs_space, ys_space = _search_space_xy(metrics_points, accuracy_ranges.get(supernet))
    ax.scatter(xs_space, ys_space, c="#cccccc", s=12, alpha=0.35, label="Search space")

    results, y_key, bounds = load_results(results_path)
    if bounds is None and supernet in accuracy_ranges:
        # Fall back to the metrics file's range when the run didn't record one.
        bounds = tuple(accuracy_ranges[supernet])
    usable = select_points(results, y_key)
    usable, y_key = ensure_accuracy(usable, y_key, bounds)

    for origin, style in ORIGIN_STYLE.items():
        subset = [p for p in usable if p.get("origin") == origin]
        if not subset:
            continue
        ax.scatter(
            [p["latency_ms"] for p in subset],
            [p[y_key] for p in subset],
            c=style["color"],
            s=style["size"],
            alpha=style["alpha"],
            marker=style["marker"],
            label=style["label"],
        )

    if metrics_points:
        # Build the frontier from the same values as the grey scatter. The
        # previous `record.get("accuracy", 0.0)` collapsed records that only
        # carry `accuracy_norm` to 0.0.
        space_records = [{"latency_ms": x, "est_acc": y} for x, y in zip(xs_space, ys_space)]
        xs_front, ys_front = compute_front(space_records, "est_acc")
        ax.plot(xs_front, ys_front, "-", color="#d62728", linewidth=2.0, label="Search-space Pareto")

    llm_only = select_points([p for p in usable if p.get("origin") == "llm"], y_key)
    if llm_only:
        xs_llm, ys_llm = compute_front(llm_only, y_key)
        ax.plot(xs_llm, ys_llm, "--", color="#1f77b4", linewidth=2.0, label="LLM frontier")

    ax.set_title(FRIENDLY_NAMES.get(supernet, supernet))
    ax.set_xlabel("Latency (ms)")


def generate_supernet_comparison(
    metrics_json: str,
    supernet_results: Dict[str, str],
    output_path: str,
    figsize: Tuple[float, float] = (14.0, 4.2),
) -> str:
    """Render one panel per supernet comparing search results to the search space.

    Args:
        metrics_json: Path to the precomputed metrics JSON (see ``load_metrics``).
        supernet_results: Mapping of supernet name to a results JSON path.
            Only names in ``SUPERNET_ORDER`` with a truthy path are plotted.
        output_path: Destination of the raster figure. A PDF with the same
            stem is written next to it. Missing parent directories are created.
        figsize: Matplotlib figure size ``(width, height)`` in inches.

    Returns:
        ``output_path``.

    Raises:
        ValueError: If no supernet in ``SUPERNET_ORDER`` has a results path.
    """
    metrics, accuracy_ranges = load_metrics(metrics_json)
    active_supernets = [name for name in SUPERNET_ORDER if supernet_results.get(name)]
    if not active_supernets:
        raise ValueError("No results provided for comparison plot")

    # squeeze=False always yields a 2-D array, so a single panel needs no special case.
    fig, axes_grid = plt.subplots(1, len(active_supernets), figsize=figsize, sharey=True, squeeze=False)
    axes = list(axes_grid[0])
    try:
        for ax, supernet in zip(axes, active_supernets):
            _plot_panel(ax, supernet, metrics.get(supernet, []), accuracy_ranges, supernet_results[supernet])

        axes[0].set_ylabel("Accuracy (%)")
        # Merge entries from every panel; series present only on later panels
        # were previously missing from the shared legend.
        handles, labels = _collect_legend_entries(axes)
        if handles:
            fig.legend(handles, labels, loc="lower center", ncol=min(4, len(labels)))
        fig.tight_layout(rect=(0, 0.08, 1, 1))

        # A bare filename has an empty dirname, and os.makedirs("") raises.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(output_path, dpi=300)
        fig.savefig(os.path.splitext(output_path)[0] + ".pdf")
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)

    print(f"Wrote {output_path}")
    return output_path


def main(argv=None) -> None:
    """Command-line entry point: parse flags and render the comparison figure.

    Args:
        argv: Argument list to parse. ``None`` lets argparse read the
            process's command line.
    """
    parser = argparse.ArgumentParser(description="Plot Tiny/Small/Base AutoFormer search comparison")
    parser.add_argument(
        "--metrics-json",
        default="vit_autoformer_nas/data/autoformer_metrics.json",
        help="precomputed metrics JSON",
    )
    # One optional results file per supernet, registered in panel order.
    for flag, friendly in (
        ("--tiny-results", "Supernet-Tiny"),
        ("--small-results", "Supernet-Small"),
        ("--base-results", "Supernet-Base"),
    ):
        parser.add_argument(flag, type=str, default=None, help=f"results.json for {friendly}")
    parser.add_argument(
        "--output",
        type=str,
        default="outputs/vit_autoformer_nas/supernet_comparison.png",
        help="output figure path",
    )
    parser.add_argument(
        "--figsize",
        type=float,
        nargs=2,
        default=(14.0, 4.2),
        help="figure size (width height)",
    )
    args = parser.parse_args(argv)

    supplied = (
        ("tiny", args.tiny_results),
        ("small", args.small_results),
        ("base", args.base_results),
    )
    # Omit supernets whose flag was not given (None) or given as an empty string.
    results_map = {name: path for name, path in supplied if path}

    generate_supernet_comparison(
        metrics_json=args.metrics_json,
        supernet_results=results_map,
        output_path=args.output,
        figsize=tuple(args.figsize),
    )


# Script entry point. This module uses a relative import
# (`from .pareto import pareto_front`), so it must be run as part of its
# package (e.g. `python -m <package>.<module>`). Executing the file path
# directly would fail at that import.
if __name__ == "__main__":
    main()
