#!/usr/bin/env python3
"""Plot latency vs. accuracy for three AutoFormer supernets, using only the metrics-dataset accuracy."""

from __future__ import annotations

import argparse
import json
import os
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt

from .pareto import pareto_front

# Left-to-right panel order for the triptych; supernets absent from the
# metrics file are simply skipped when plotting.
SUPERNET_ORDER = ["tiny", "small", "base"]

# Human-readable panel titles keyed by supernet name.
FRIENDLY_NAMES = {
    "tiny": "Supernet-Tiny",
    "small": "Supernet-Small",
    "base": "Supernet-Base",
}

# Per-supernet plot colors (the first three entries of matplotlib's tab10 cycle).
COLORS = {
    "tiny": "#1f77b4",
    "small": "#ff7f0e",
    "base": "#2ca02c",
}


def load_metrics(path: str) -> Dict[str, List[Dict]]:
    """Load the per-supernet metric entries from a metrics JSON file.

    The file must contain a JSON object. Its optional ``"supernets"`` key
    maps supernet names to lists of entry dicts.

    Args:
        path: Path to the metrics JSON file (read as UTF-8).

    Returns:
        The ``"supernets"`` mapping. The result is an empty dict when the key
        is missing or ``null``.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        ValueError: if the top-level value or ``"supernets"`` is not a JSON object.
    """
    # Explicit encoding: JSON is UTF-8 by spec; the platform default may not be.
    with open(path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)

    if not isinstance(payload, dict):
        raise ValueError(
            f"Expected a JSON object at the top level of {path}, "
            f"got {type(payload).__name__}"
        )

    supernets = payload.get("supernets")
    # Treat a missing key and an explicit null the same way, so callers can
    # always call .get() on the result.
    if supernets is None:
        return {}
    if not isinstance(supernets, dict):
        raise ValueError(
            f'Expected "supernets" in {path} to be a JSON object, '
            f"got {type(supernets).__name__}"
        )
    return supernets


def _plot_supernet(ax, supernet: str, entries: List[Dict]) -> None:
    """Draw one supernet's sampled points and its Pareto front onto ``ax``."""
    color = COLORS.get(supernet, "#888888")
    title = FRIENDLY_NAMES.get(supernet, supernet.title())
    xs = [entry["latency_ms"] for entry in entries]
    ys = [entry["accuracy"] for entry in entries]
    ax.scatter(xs, ys, s=20, c=color, alpha=0.35)

    front = pareto_front(
        [{"latency_ms": x, "accuracy": y} for x, y in zip(xs, ys)],
        acc_key="accuracy",
        lat_key="latency_ms",
    )
    # Sort by latency so the connecting line never doubles back, whatever
    # order pareto_front returns its points in.
    front = sorted(front, key=lambda point: point["latency_ms"])
    ax.plot(
        [point["latency_ms"] for point in front],
        [point["accuracy"] for point in front],
        "-",
        color=color,
        linewidth=2.2,
        # Name the supernet so the shared figure legend can tell panels apart.
        label=f"{title} Pareto front",
    )

    ax.set_title(title)
    ax.set_xlabel("Latency (ms)")
    ax.grid(True, linestyle="--", linewidth=0.3, alpha=0.35)


def _save_figure(fig, output: str) -> None:
    """Write ``fig`` to ``output`` (300 dpi) and to a sibling ``.pdf`` file."""
    parent = os.path.dirname(output)
    # os.path.dirname("plot.png") is "", and os.makedirs("") raises, so only
    # create a parent directory when the path actually has one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    fig.savefig(output, dpi=300)
    pdf_path = os.path.splitext(output)[0] + ".pdf"
    # Avoid a redundant second write when the primary output is already a PDF.
    if pdf_path != output:
        fig.savefig(pdf_path)


def plot_triptych(metrics: Dict[str, List[Dict]], output: str) -> None:
    """Render one latency/accuracy panel per supernet and save the figure.

    Panels follow ``SUPERNET_ORDER``. Supernets that are missing or have no
    entries in ``metrics`` are skipped. Each entry must provide
    ``"latency_ms"`` and ``"accuracy"`` values. Each panel shows all entries
    as a scatter plus the front returned by ``pareto_front``.

    The figure is saved to ``output`` at 300 dpi. A PDF with the same stem is
    saved alongside it.

    Args:
        metrics: Mapping of supernet name to its list of entry dicts.
        output: Destination image path; missing parent directories are created.

    Raises:
        ValueError: if none of the supernets in ``SUPERNET_ORDER`` has data.
    """
    active = [name for name in SUPERNET_ORDER if metrics.get(name)]
    if not active:
        raise ValueError("No supernet data found in metrics JSON")

    # squeeze=False always returns a 2-D array of axes, so a single active
    # supernet needs no special case.
    fig, axes = plt.subplots(
        1, len(active), figsize=(16.0, 4.8), sharey=True, squeeze=False
    )
    try:
        row = axes[0]
        for ax, supernet in zip(row, active):
            _plot_supernet(ax, supernet, metrics[supernet])

        row[0].set_ylabel("Accuracy (%)")
        # One legend entry per panel, laid out in a single row under the plots.
        fig.legend(loc="lower center", ncol=len(active))
        # Reserve the bottom 8% of the figure for the legend.
        fig.tight_layout(rect=(0, 0.08, 1, 1))
        _save_figure(fig, output)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)
    print(f"Wrote {output}")


def main(argv=None) -> None:
    """Command-line entry point: load the metrics JSON and render the triptych.

    Args:
        argv: Argument list to parse. ``None`` lets argparse read ``sys.argv[1:]``.
    """
    cli = argparse.ArgumentParser(description="Plot AutoFormer metrics triptych")
    cli.add_argument(
        "--metrics-json",
        help="path to autoformer_metrics.json",
        default="vit_autoformer_nas/data/autoformer_metrics.json",
    )
    cli.add_argument(
        "--output",
        help="output image path",
        default="outputs/vit_autoformer_nas/metrics_triptych.png",
    )
    options = cli.parse_args(argv)

    # Load first, then plot: a bad metrics file fails before any figure is made.
    plot_triptych(load_metrics(options.metrics_json), options.output)


# NOTE: because of the relative import ``from .pareto import pareto_front``,
# this file must be run as a module (``python -m <package>.<module>``).
# Executing it directly as a script fails at import time, before this guard
# is reached.
if __name__ == "__main__":
    main()

