#!/usr/bin/env python3
"""Generate aggregated proxy plots for AutoFormer supernets."""

from __future__ import annotations

import argparse
import json
import math
import os
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

from .autoformer import build_autoformer_model, ensure_autoprox_on_path, make_arch_config
from .categories import SUPERNET_SPECS
from .datasets import RandomImageDataset
from .metrics_dataset import DEFAULT_DATASET_PATH
from .pareto import pareto_front
from .zero_cost import (
    aggregate_scores,
    build_measure_stats,
    compute_zero_cost_scores,
    orient_measure,
)

# Fallback zero-cost measures, used by ``main`` when ``--measures`` parses to
# an empty list.
MEASURES = (
    "jacov",
    "jacobian_trace",
    "grad_norm",
)

# Plot colour per supernet name; unknown supernets fall back to grey at the
# call sites.
SUPERNET_COLORS = dict(
    tiny="#1f77b4",
    small="#ff7f0e",
    base="#2ca02c",
)


def load_metrics(path: str) -> Tuple[Dict[str, List[Dict]], Dict[str, Tuple[float, float]]]:
    """Load the supernet metrics dataset from a JSON file.

    Args:
        path: Path to a UTF-8 JSON document whose top-level object may contain
            a ``"supernets"`` mapping and ``"metadata" -> "accuracy_ranges"``.

    Returns:
        ``(supernets, accuracy_ranges)``. ``supernets`` maps a supernet name to
        its list of architecture entries. ``accuracy_ranges`` maps a supernet
        name to its accuracy bounds as a tuple. A ``null`` bound is kept as
        ``None``. Missing (or ``null``) sections yield ``{}``.
    """
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open(path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)
    supernets = payload.get("supernets") or {}
    metadata = payload.get("metadata") or {}
    raw_ranges = metadata.get("accuracy_ranges") or {}
    # JSON has no tuple type, so bounds arrive as lists; convert them to match
    # the annotated return type.
    accuracy_ranges = {
        name: tuple(bounds) if bounds is not None else None
        for name, bounds in raw_ranges.items()
    }
    return supernets, accuracy_ranges


def compute_proxy_scores(
    supernet: str,
    entries: List[Dict],
    ap_root: str,
    measures: Tuple[str, ...],
) -> List[Dict]:
    """Score every architecture of one supernet with zero-cost proxies.

    Each entry is built as an AutoFormer model with ``dataset="cifar100"``,
    moved to the CPU, and evaluated on a small ``RandomImageDataset``. That
    dataset has 64 samples at 32x32 and 100 classes, served in batches of 32.
    Raw measures are passed through ``orient_measure``. The results are
    combined into a single ``proxy_score`` using statistics built from *all*
    entries in this call, so scores are only comparable within one call.

    Args:
        supernet: Supernet name; not used by the computation (kept for API
            compatibility).
        entries: Architecture dicts. Each needs ``hidden_dim``, ``depth``,
            ``num_heads``, ``mlp_ratio``, ``latency_ms`` and ``accuracy``;
            ``qkv_dim`` is optional.
        ap_root: Root path forwarded to ``build_autoformer_model``.
        measures: Names of the zero-cost measures to compute.

    Returns:
        One record per entry, in input order, with keys ``arch``,
        ``latency_ms``, ``accuracy``, ``measures``, ``measures_oriented`` and
        ``proxy_score``. Returns ``[]`` for empty ``entries``.
    """
    if not entries:
        # Nothing to score; also avoids building stats from zero records.
        return []

    num_classes = 100  # CIFAR-100 proxy
    img_size = 32
    dataset = RandomImageDataset(num_samples=64, img_size=img_size, num_classes=num_classes)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

    records: List[Dict] = []
    oriented_records: List[Dict[str, float]] = []

    for entry in entries:
        arch = make_arch_config(
            hidden_dim=entry["hidden_dim"],
            depth=entry["depth"],
            num_heads=entry["num_heads"],
            mlp_ratio=entry["mlp_ratio"],
            qkv_dim=entry.get("qkv_dim"),
        )
        model = build_autoformer_model(ap_root, arch, dataset="cifar100").cpu()
        # Deliberately NOT wrapped in torch.no_grad(): gradient-based measures
        # (grad_norm and, presumably, the Jacobian measures) need autograd, and
        # no_grad disables graph construction. If compute_zero_cost_scores
        # already manages grad mode itself, this is neutral.
        scores = compute_zero_cost_scores(model, dataloader, num_classes=num_classes, measures=measures)
        del model  # drop the reference before building the next candidate
        oriented = {name: orient_measure(name, val) for name, val in scores.items()}
        oriented_records.append(oriented)
        records.append(
            {
                "arch": entry,
                "latency_ms": entry["latency_ms"],
                "accuracy": entry["accuracy"],
                "measures": scores,
                "measures_oriented": oriented,
            }
        )

    # Aggregation needs statistics over the whole population, so it happens
    # in a second pass once every architecture has been scored.
    stats = build_measure_stats(oriented_records)
    for record in records:
        record["proxy_score"] = aggregate_scores(record["measures_oriented"], stats)
    return records


def normalize_proxy(records: List[Dict], acc_bounds: Tuple[float, float]) -> None:
    """Map aggregated proxy scores onto an accuracy scale, in place.

    Finite ``proxy_score`` values are min-max normalised to ``[0, 1]`` across
    ``records`` (stored as ``proxy_norm``). They are then mapped linearly onto
    ``acc_bounds`` (stored as ``proxy_accuracy``). A record whose score is
    missing, ``None`` or non-finite (NaN/inf) gets ``None`` for both keys.
    When every finite score is identical, all of them map to ``acc_min``.

    Args:
        records: Dicts that may carry a ``proxy_score``; mutated in place.
        acc_bounds: ``(acc_min, acc_max)`` target accuracy range.
    """

    def _usable(score) -> bool:
        # NaN makes min()/max() order-dependent (it can poison every result),
        # so non-finite scores are treated as missing.
        return score is not None and math.isfinite(score)

    scores = [r["proxy_score"] for r in records if _usable(r.get("proxy_score"))]
    if not scores:
        for record in records:
            record["proxy_norm"] = None
            record["proxy_accuracy"] = None
        return

    s_min = min(scores)
    s_max = max(scores)
    # Avoid division by zero when all scores coincide.
    span = s_max - s_min if s_max > s_min else 1.0
    acc_min, acc_max = acc_bounds
    for record in records:
        proxy = record.get("proxy_score")
        if not _usable(proxy):
            record["proxy_norm"] = None
            record["proxy_accuracy"] = None
            continue
        norm = (proxy - s_min) / span
        record["proxy_norm"] = norm
        record["proxy_accuracy"] = acc_min + norm * (acc_max - acc_min)


def plot_combined(
    supernet_data: Dict[str, List[Dict]],
    output: str,
) -> None:
    """Overlay every supernet on a single latency-vs-accuracy plot.

    For each supernet, the figure shows three things in that supernet's
    colour:
      * all search-space points as a faint scatter,
      * the Pareto front of measured ``accuracy`` as a solid line,
      * the Pareto front of ``proxy_accuracy`` as a dashed line, when any
        record has one.
    Supernets with no records are skipped. The figure is saved to ``output``
    at 300 dpi and also as a ``.pdf`` next to it.

    Args:
        supernet_data: Supernet name -> records with ``latency_ms``,
            ``accuracy`` and optionally ``proxy_accuracy``.
        output: Destination image path; its directory is created if needed.
    """
    fig, ax = plt.subplots(figsize=(9.0, 6.0))
    legend_seen = set()

    for supernet, records in supernet_data.items():
        if not records:
            continue  # nothing to draw; avoid empty legend entries
        color = SUPERNET_COLORS.get(supernet, "#888888")
        xs = [r["latency_ms"] for r in records]
        ys = [r["accuracy"] for r in records]
        label = f"{supernet.title()} search space"
        # Guard against two keys that title-case to the same legend label.
        if label not in legend_seen:
            ax.scatter(xs, ys, s=14, c=color, alpha=0.25, label=label)
            legend_seen.add(label)
        else:
            ax.scatter(xs, ys, s=14, c=color, alpha=0.25)

        actual_front = pareto_front(
            [{"latency_ms": r["latency_ms"], "acc": r["accuracy"]} for r in records],
            acc_key="acc",
            lat_key="latency_ms",
        )
        ax.plot(
            [p["latency_ms"] for p in actual_front],
            [p["acc"] for p in actual_front],
            "-",
            color=color,
            linewidth=2.0,
            label=f"{supernet.title()} actual front",
        )

        proxy_points = [r for r in records if r.get("proxy_accuracy") is not None]
        if proxy_points:
            proxy_front = pareto_front(
                [
                    {"latency_ms": r["latency_ms"], "proxy_acc": r["proxy_accuracy"]}
                    for r in proxy_points
                ],
                acc_key="proxy_acc",
                lat_key="latency_ms",
            )
            ax.plot(
                [p["latency_ms"] for p in proxy_front],
                [p["proxy_acc"] for p in proxy_front],
                "--",
                color=color,
                linewidth=1.8,
                label=f"{supernet.title()} proxy front",
            )

    ax.set_xlabel("Latency (ms, batch=1)")
    ax.set_ylabel("Accuracy (%)")
    ax.set_title("AutoFormer Supernets: Search Space vs. Proxy Fronts")
    ax.grid(True, linestyle="--", linewidth=0.3, alpha=0.4)
    ax.legend(loc="lower left", fontsize=8)
    fig.tight_layout()
    out_dir = os.path.dirname(output)
    if out_dir:
        # os.makedirs("") raises, so only create a directory when one is given.
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(output, dpi=300)
    fig.savefig(os.path.splitext(output)[0] + ".pdf")
    plt.close(fig)
    print(f"Wrote {output}")


def plot_triptych(
    supernet_data: Dict[str, List[Dict]],
    output: str,
) -> None:
    """Draw one side-by-side panel per supernet, sharing the accuracy axis.

    Each panel shows the search-space scatter, the Pareto front of measured
    accuracy (solid) and, when available, the Pareto front of
    ``proxy_accuracy`` (dashed). Panels are ordered by ``SUPERNET_SPECS``.
    Supernets not listed there are appended after the known ones. A single
    shared legend is placed below the panels, holding each entry type that
    appears in *any* panel. The figure is saved to ``output`` at 300 dpi and
    also as a ``.pdf`` next to it.

    Args:
        supernet_data: Supernet name -> records with ``latency_ms``,
            ``accuracy`` and optionally ``proxy_accuracy``.
        output: Destination image path; its directory is created if needed.

    Raises:
        ValueError: If ``supernet_data`` is empty.
    """
    active_supernets = [name for name in SUPERNET_SPECS.keys() if name in supernet_data]
    # Don't silently drop supernets missing from SUPERNET_SPECS.
    active_supernets += [name for name in supernet_data if name not in SUPERNET_SPECS]
    if not active_supernets:
        raise ValueError("No supernet data provided")

    fig, axes = plt.subplots(
        1, len(active_supernets), figsize=(16.0, 4.8), sharey=True, squeeze=False
    )
    axes = axes[0]  # squeeze=False always returns a 2-D grid; take its only row

    # Legend label -> first artist seen with that label. Insertion order fixes
    # the legend order, and a label only present in a later panel (e.g. the
    # proxy front) is still picked up.
    legend_handles: Dict[str, object] = {}

    for ax, supernet in zip(axes, active_supernets):
        records = supernet_data[supernet]
        color = SUPERNET_COLORS.get(supernet, "#888888")
        xs = [r["latency_ms"] for r in records]
        ys = [r["accuracy"] for r in records]
        scatter = ax.scatter(xs, ys, s=20, c=color, alpha=0.35, label="Search space")
        legend_handles.setdefault("Search space", scatter)

        actual_front = pareto_front(
            [{"latency_ms": r["latency_ms"], "accuracy": r["accuracy"]} for r in records],
            acc_key="accuracy",
            lat_key="latency_ms",
        )
        actual_line, = ax.plot(
            [p["latency_ms"] for p in actual_front],
            [p["accuracy"] for p in actual_front],
            "-",
            color=color,
            linewidth=2.2,
            label="Actual Pareto",
        )
        legend_handles.setdefault("Actual Pareto", actual_line)

        proxy_points = [r for r in records if r.get("proxy_accuracy") is not None]
        if proxy_points:
            proxy_front = pareto_front(
                [
                    {"latency_ms": r["latency_ms"], "proxy_acc": r["proxy_accuracy"]}
                    for r in proxy_points
                ],
                acc_key="proxy_acc",
                lat_key="latency_ms",
            )
            proxy_line, = ax.plot(
                [p["latency_ms"] for p in proxy_front],
                [p["proxy_acc"] for p in proxy_front],
                "--",
                color=color,
                linewidth=2.0,
                label="Proxy Pareto",
            )
            legend_handles.setdefault("Proxy Pareto", proxy_line)

        ax.set_title(f"{supernet.title()} Supernet")
        ax.set_xlabel("Latency (ms)")
        ax.grid(True, linestyle="--", linewidth=0.3, alpha=0.35)

    axes[0].set_ylabel("Accuracy (%)")
    fig.legend(
        list(legend_handles.values()),
        list(legend_handles.keys()),
        loc="lower center",
        ncol=len(legend_handles),
    )
    # Reserve the bottom 10% of the figure for the shared legend.
    fig.tight_layout(rect=(0, 0.1, 1, 1))
    out_dir = os.path.dirname(output)
    if out_dir:
        # os.makedirs("") raises, so only create a directory when one is given.
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(output, dpi=300)
    fig.savefig(os.path.splitext(output)[0] + ".pdf")
    plt.close(fig)
    print(f"Wrote {output}")


def main(argv=None) -> None:
    """Command-line entry point: score every supernet and write the plot.

    Loads the metrics JSON and computes zero-cost proxy scores for each
    supernet's architectures. Each supernet's scores are mapped onto its
    accuracy range: the range stored in the JSON metadata when present,
    otherwise the min/max of its measured accuracies. The chosen plot
    layout is then written.

    Args:
        argv: Arguments to parse instead of the process command line
            (``None`` lets argparse read ``sys.argv``).
    """
    parser = argparse.ArgumentParser(description="Plot proxy performance across all supernets")
    parser.add_argument("--metrics-json", default=DEFAULT_DATASET_PATH, help="path to autoformer_metrics.json")
    parser.add_argument("--output", default="outputs/vit_autoformer_nas/global_proxy.png", help="output image path")
    parser.add_argument("--measures", default="jacov,jacobian_trace,grad_norm", help="comma separated zero-cost measures")
    parser.add_argument("--mode", choices=["combined", "triptych"], default="combined", help="plot layout")
    args = parser.parse_args(argv)

    metrics, accuracy_ranges = load_metrics(args.metrics_json)
    accuracy_ranges = accuracy_ranges or {}
    # Fall back to the default measure set if the option parses to nothing
    # (e.g. ``--measures ","``).
    measures = tuple(m.strip() for m in args.measures.split(",") if m.strip()) or MEASURES

    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    ap_root = ensure_autoprox_on_path(project_root)

    combined_records: Dict[str, List[Dict]] = {}
    for supernet, entries in metrics.items():
        if not entries:
            # min()/max() of the accuracies below would raise on an empty list.
            print(f"Skipping {supernet}: no architectures")
            continue
        print(f"Processing {supernet} ({len(entries)} architectures)...")
        records = compute_proxy_scores(supernet, entries, ap_root, measures)
        acc_bounds = accuracy_ranges.get(supernet)
        if acc_bounds is None:
            accs = [entry["accuracy"] for entry in entries]
            acc_bounds = (min(accs), max(accs))
        normalize_proxy(records, acc_bounds)
        combined_records[supernet] = records

    if args.mode == "combined":
        plot_combined(combined_records, args.output)
    else:
        plot_triptych(combined_records, args.output)


if "__main__" == __name__:
    # Script entry. The module uses relative imports, so it must be launched
    # as part of its package (``python -m <package>.<module>``).
    main(argv=None)

