#!/usr/bin/env python3
"""Plot helper for ViT NAS runs."""

import argparse
import json
import math
import os
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch

from .autoformer import ensure_autoprox_on_path


def load_results(path: str) -> List[Dict]:
    """Load the candidate records stored under the ``"all"`` key of a results file.

    Args:
        path: Location of the JSON results file.

    Returns:
        The value stored under the top-level ``"all"`` key.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If the document has no top-level ``"all"`` entry.
    """
    # JSON text is UTF-8 by specification; pin the encoding so decoding does
    # not depend on the platform's locale default.
    with open(path, "r", encoding="utf-8") as handle:
        data = json.load(handle)
    return data["all"]


def pareto_front(points: List[Dict], y_key: str, x_key: str) -> List[Dict]:
    """Return the non-dominated records, ordered by ascending ``x_key``.

    A record is dominated when some *other* record (a different object) has a
    ``y_key`` value at least as large and an ``x_key`` value at most as large,
    with a strict improvement in at least one of the two.

    Args:
        points: Records indexable by ``y_key`` and ``x_key``.
        y_key: Key of the quantity to maximise.
        x_key: Key of the quantity to minimise.

    Returns:
        A new list holding the surviving records (the same objects as in
        ``points``), sorted by their ``x_key`` value.
    """

    def outperforms(rival, candidate) -> bool:
        # The rival must be no worse on both axes ...
        if not (rival[y_key] >= candidate[y_key]):
            return False
        if not (rival[x_key] <= candidate[x_key]):
            return False
        # ... and strictly better on at least one of them.
        return rival[y_key] > candidate[y_key] or rival[x_key] < candidate[x_key]

    survivors = []
    for candidate in points:
        beaten = False
        for rival in points:
            # Identity check: equal-valued duplicates still compete with each other.
            if rival is candidate:
                continue
            if outperforms(rival, candidate):
                beaten = True
                break
        if not beaten:
            survivors.append(candidate)
    return sorted(survivors, key=lambda record: record[x_key])


def rescale_accuracy(values: List[float], target_max: float) -> Tuple[List[float], float]:
    """Scale ``values`` linearly so that their maximum maps onto ``target_max``.

    When ``values`` is empty, or its maximum is zero or negative, the reference
    maximum falls back to 1.0, so the scale factor equals ``target_max``.

    Args:
        values: Raw metric values.
        target_max: Value the largest input should be mapped to.

    Returns:
        A ``(scaled_values, scale)`` tuple, where ``scale`` is the factor
        applied to every input.
    """
    peak = 1.0
    if values:
        peak = max(values)
    if peak <= 0:
        # A non-positive peak cannot serve as a reference; use unity instead.
        peak = 1.0
    factor = target_max / peak
    rescaled = []
    for value in values:
        rescaled.append(value * factor)
    return rescaled, factor


def build_model_and_count_params(repo_root: str, arch: Dict, dataset: str = "cifar100") -> float:
    """Build the ``AutoFormerSub`` model for ``arch`` and count its trainable parameters.

    Args:
        repo_root: Passed to ``ensure_autoprox_on_path``; its return value is
            used as the directory containing the base YAML config.
            (Presumably this call also makes ``pycls`` importable -- confirm
            against the helper's implementation.)
        arch: Architecture description forwarded as ``arch_config``.
        dataset: ``"cifar100"`` selects 100 classes; any other value selects 10.

    Returns:
        Number of parameters with ``requires_grad`` set, in millions.
    """
    autoprox_dir = ensure_autoprox_on_path(repo_root)
    # Imported only after the call above, which is expected to prepare the import path.
    from pycls.core.config import cfg as vit_cfg
    import pycls.core.config as vit_config
    from pycls.models.build import MODEL

    config_file = os.path.join(autoprox_dir, "configs/auto/autoformer/autoformer-ti-subnet_c100_base.yaml")
    vit_config.load_cfg(config_file)
    # Override the class count after the base config has been loaded.
    if dataset == "cifar100":
        vit_cfg.MODEL.NUM_CLASSES = 100
    else:
        vit_cfg.MODEL.NUM_CLASSES = 10

    build = MODEL.get("AutoFormerSub")
    network = build(arch_config=arch, num_classes=vit_cfg.MODEL.NUM_CLASSES)

    trainable = 0
    for tensor in network.parameters():
        if tensor.requires_grad:
            trainable += tensor.numel()
    return trainable / 1e6


def _is_finite_number(value) -> bool:
    """Return True for real, finite ints/floats; reject bool, None, NaN and infinities."""
    if isinstance(value, bool):
        return False
    if isinstance(value, int):
        return True
    return isinstance(value, float) and math.isfinite(value)


def _thin_out(items: List[Dict], limit: int) -> List[Dict]:
    """Return at most ``limit`` evenly strided items (an empty list when ``limit`` <= 0)."""
    if limit <= 0:
        return []
    if len(items) <= limit:
        return items
    # Ceiling division: a floor-based stride could keep almost 2x ``limit`` items.
    step = -(-len(items) // limit)
    return items[::step]


def main(argv=None):
    """Plot the latency/accuracy Pareto front described by a results file.

    Reads the records from ``--results``, keeps those with a finite
    ``latency_ms`` and a finite accuracy metric (``est_acc`` if any record has
    it, otherwise ``acc_proxy_norm``), rescales the metric so its maximum equals
    ``--target-max-acc``, and draws the Pareto front over a thinned-out
    background of the remaining candidates. The figure is written both as PNG
    and as PDF, and the frontier points are printed to stdout.

    Args:
        argv: Argument list for the parser; ``None`` makes argparse read
            ``sys.argv``.
    """
    import sys  # local import: only needed for the stderr diagnostics below

    parser = argparse.ArgumentParser(description="Plot Pareto with annotations from results.json")
    parser.add_argument("--results", required=True, help="path to results.json")
    parser.add_argument("--output", default=None, help="output PNG path (default next to results)")
    parser.add_argument("--target-max-acc", type=float, default=78.0, help="rescale current max to this value")
    parser.add_argument("--bg-max", type=int, default=200, help="max background points plotted")
    parser.add_argument("--no-param-annot", action="store_true", help="skip parameter annotations on the frontier")
    parser.add_argument("--repo-root", type=str, default=os.getcwd(), help="repository root for Auto-Prox lookup")
    args = parser.parse_args(argv)

    results = load_results(args.results)
    x_key = "latency_ms"
    candidates = [r for r in results if x_key in r and ("est_acc" in r or "acc_proxy_norm" in r)]
    y_key = "est_acc" if any("est_acc" in r for r in candidates) else "acc_proxy_norm"

    # Keep only records that can actually be placed on the plot. ``.get`` is
    # required because a record may carry the *other* metric key only.
    plottable = [
        p for p in candidates
        if _is_finite_number(p.get(y_key)) and _is_finite_number(p[x_key])
    ]
    skipped = len(candidates) - len(plottable)
    if skipped:
        print(f"warning: skipped {skipped} record(s) without a finite {x_key}/{y_key}", file=sys.stderr)

    # ``plottable`` and ``y_scaled`` are built from the same sequence, so they
    # stay aligned element by element.
    y_scaled, _ = rescale_accuracy([p[y_key] for p in plottable], args.target_max_acc)
    for point, scaled in zip(plottable, y_scaled):
        point["y_plot"] = scaled

    frontier = pareto_front(plottable, y_key="y_plot", x_key=x_key)

    # pareto_front returns the same objects it was given, so identity is enough
    # to separate the frontier from the background.
    frontier_ids = {id(p) for p in frontier}
    bg_points = _thin_out([p for p in plottable if id(p) not in frontier_ids], args.bg_max)

    repo_root = os.path.abspath(args.repo_root)
    param_labels = []
    for point in frontier:
        arch = point.get("arch")
        if not arch:
            param_labels.append(None)
            continue
        try:
            param_labels.append(f"{build_model_and_count_params(repo_root, arch):.2f}M")
        except Exception as exc:
            # Best effort: the plot is still useful without parameter counts,
            # but report the reason instead of hiding it.
            print(f"warning: could not count parameters for a frontier point: {exc}", file=sys.stderr)
            param_labels.append(None)

    fig, ax = plt.subplots(figsize=(6.4, 4.2))
    ax.scatter([p[x_key] for p in bg_points], [p["y_plot"] for p in bg_points], s=10, c="#999999", alpha=0.35, label="Candidates")
    xs_f = [p[x_key] for p in frontier]
    ys_f = [p["y_plot"] for p in frontier]
    ax.plot(xs_f, ys_f, "-o", c="#d62728", linewidth=2.0, markersize=4.5, label="Pareto Front", zorder=4)
    if not args.no_param_annot:
        for x, y, label in zip(xs_f, ys_f, param_labels):
            if label:
                ax.annotate(label, (x, y), textcoords="offset points", xytext=(4, 4), fontsize=8, color="#d62728")
    ax.set_xlabel("Latency (ms)")
    ylabel = "Estimated Accuracy (%)" if y_key == "est_acc" else "Auto-Prox Proxy (rescaled)"
    ax.set_ylabel(ylabel)
    ax.legend(frameon=True, loc="lower left")
    fig.tight_layout()

    out_path = args.output or os.path.join(os.path.dirname(args.results), "pareto_plot.png")
    root_noext, ext = os.path.splitext(out_path)
    try:
        # Always emit the requested file plus its PNG/PDF counterpart.
        if ext.lower() == ".pdf":
            fig.savefig(out_path)
            fig.savefig(root_noext + ".png", dpi=300)
        else:
            fig.savefig(out_path, dpi=300)
            fig.savefig(root_noext + ".pdf")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)

    print("Pareto frontier points:")
    for x, y, label in zip(xs_f, ys_f, param_labels):
        print(f" - latency={x:.4f} ms, metric={y:.3f}, params={label or 'N/A'}")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()

