#!/usr/bin/env python3
"""Plot up to three AutoFormer supernets side by side using only run results (no search-space background)."""

from __future__ import annotations

import argparse
import json
import os
from typing import Dict, List, Optional, Tuple

import random

import matplotlib.pyplot as plt

from .pareto import pareto_front

# Left-to-right panel order; also the set of supernet keys the CLI understands.
SUPERNET_ORDER = ["tiny", "small", "base"]
# Display name per supernet key, e.g. "tiny" -> "Supernet-Tiny".
FRIENDLY_NAMES = {key: f"Supernet-{key.capitalize()}" for key in SUPERNET_ORDER}
# Marker/line colour per supernet key, paired positionally with SUPERNET_ORDER.
COLORS = dict(zip(SUPERNET_ORDER, ("#1f77b4", "#ff7f0e", "#2ca02c")))
# Font sizes handed to matplotlib for titles, axis labels, tick labels and legends.
TITLE_FONT = 20
LABEL_FONT = TICK_FONT = LEGEND_FONT = 15


def load_results(path: str) -> Tuple[List[Dict], str, Optional[Tuple[float, float]]]:
    """Load one run's ``results.json``.

    Args:
        path: Filesystem path of the JSON results file.

    Returns:
        ``(points, y_key, bounds)``:
        ``points`` is the payload's ``"all"`` list. It is ``[]`` when the key is
        absent or explicitly null.
        ``y_key`` is the payload's ``"y_key"``, defaulting to ``"acc_proxy_norm"``.
        ``bounds`` is ``"accuracy_bounds"`` converted to a tuple, or ``None`` when
        absent.
    """
    # JSON text is UTF-8 by specification; do not depend on the platform locale.
    with open(path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)
    points = payload.get("all")
    if points is None:
        # An explicit `"all": null` would otherwise make callers iterate over None.
        points = []
    bounds = payload.get("accuracy_bounds")
    if bounds is not None:
        bounds = tuple(bounds)
    return points, payload.get("y_key", "acc_proxy_norm"), bounds


def select_points(points: List[Dict], y_key: str) -> List[Dict]:
    """Keep only entries that carry both a latency and a ``y_key`` value.

    An entry is dropped when ``"latency_ms"`` or ``y_key`` is missing, ``None``,
    or a float NaN. ``json.load`` parses the non-standard ``NaN`` token into
    ``float("nan")``. A NaN latency must be rejected too: every comparison
    against NaN is False, so such a point could never be dominated and would
    end up on the Pareto front.

    Args:
        points: Raw result entries (dicts).
        y_key: Name of the objective field to require alongside ``"latency_ms"``.

    Returns:
        The usable entries, in their original order (same dict objects).
    """
    usable: List[Dict] = []
    for entry in points:
        y_val = entry.get(y_key)
        x_val = entry.get("latency_ms")
        if y_val is None or x_val is None:
            continue
        # NaN is the only float that compares unequal to itself.
        if any(isinstance(val, float) and val != val for val in (x_val, y_val)):
            continue
        usable.append(entry)
    return usable


def ensure_accuracy(points: List[Dict], y_key: str, bounds: Optional[Tuple[float, float]]) -> Tuple[List[Dict], str]:
    """Map normalized objective values onto an absolute accuracy scale when possible.

    If ``y_key`` is already ``"est_acc"``, or ``bounds`` is missing/empty, or
    ``bounds`` is not strictly increasing, ``(points, y_key)`` is returned unchanged.
    Otherwise every entry with a non-None ``y_key`` value gets
    ``entry["est_acc"] = lo + value * (hi - lo)``, written in place. The same
    list is then returned together with ``"est_acc"`` as the new key.
    """
    if y_key == "est_acc" or not bounds:
        return points, y_key
    lo, hi = bounds
    if hi <= lo:
        # Degenerate or inverted range: leave the data untouched.
        return points, y_key
    scale = hi - lo
    for item in points:
        raw = item.get(y_key)
        if raw is not None:
            item["est_acc"] = lo + raw * scale
    return points, "est_acc"


def pareto_entries(entries: List[Dict], y_key: str) -> List[Dict]:
    """Return the entries not dominated by any other entry, sorted by latency.

    Entry ``a`` dominates ``b`` when ``a`` is no worse on both objectives
    (``y_key`` higher-or-equal, ``"latency_ms"`` lower-or-equal) and strictly
    better on at least one. Exact duplicates therefore do not dominate each
    other, and both are kept. The sort is stable, so entries with equal
    latency keep their input order. The check compares every pair, O(n^2).
    """

    def _dominates(rival: Dict, target: Dict) -> bool:
        no_worse = rival[y_key] >= target[y_key] and rival["latency_ms"] <= target["latency_ms"]
        better_somewhere = rival[y_key] > target[y_key] or rival["latency_ms"] < target["latency_ms"]
        return no_worse and better_somewhere

    front = [
        cand
        for cand in entries
        if not any(_dominates(rival, cand) for rival in entries if rival is not cand)
    ]
    return sorted(front, key=lambda e: e["latency_ms"])


def _apply_supernet_filter(supernet: str, usable: List[Dict], y_key: str) -> List[Dict]:
    """Apply the hand-tuned per-supernet latency/objective cut-offs used for the figure."""
    # NOTE(review): the caps 79.75 / 82.5 look like accuracy-in-percent values; when
    # ensure_accuracy could not convert (no bounds), y_key is presumably a normalized
    # proxy and the caps would then keep every point -- confirm this is intended.
    if supernet == "tiny":
        return [p for p in usable if p["latency_ms"] >= 2.5]
    if supernet == "small":
        return [p for p in usable if p["latency_ms"] >= 4.0 and p[y_key] <= 79.75]
    if supernet == "base":
        return [p for p in usable if p["latency_ms"] >= 4.75 and p[y_key] <= 82.5]
    return usable


def _split_highlights(
    usable: List[Dict], y_key: str, supernet: str, rng: random.Random
) -> Tuple[List[Dict], List[Dict]]:
    """Randomly split points into (highlights, background), favouring high ``y_key``.

    Each point is highlighted with probability 0.02 + norm**1.6 * (0.45 - 0.02),
    where ``norm`` is its min-max-normalized ``y_key`` within this panel. The
    highlighted set is then topped up or trimmed to round(20% of the points)
    (at least 1), preferring the highest values. For ``"tiny"``, when the highlights
    do not exceed half the points, the background is randomly down-sampled so
    that highlights plus background total about half the points.
    """
    if not usable:
        return [], []

    highlight_fraction = 0.20
    min_prob = 0.02
    max_prob = 0.45
    gamma = 1.6

    y_values = [entry[y_key] for entry in usable]
    y_min = min(y_values)
    y_max = max(y_values)

    highlights: List[Dict] = []
    background: List[Dict] = []
    for entry in usable:
        norm = 0.0 if y_max <= y_min else (entry[y_key] - y_min) / (y_max - y_min)
        prob = min_prob + (norm ** gamma) * (max_prob - min_prob)
        if rng.random() < prob:
            highlights.append(entry)
        else:
            background.append(entry)

    target_count = max(1, int(round(len(usable) * highlight_fraction)))
    if len(highlights) < target_count:
        need = target_count - len(highlights)
        candidates = sorted(background, key=lambda e: e[y_key], reverse=True)
        highlights.extend(candidates[:need])
        background = candidates[need:]
    elif len(highlights) > target_count:
        highlights = sorted(highlights, key=lambda e: e[y_key], reverse=True)
        background.extend(highlights[target_count:])
        highlights = highlights[:target_count]

    if supernet == "tiny":
        usable_sorted = sorted(usable, key=lambda e: e[y_key], reverse=True)
        desired_total = max(1, int(round(len(usable_sorted) * 0.5)))
        if len(highlights) > desired_total:
            highlights = sorted(highlights, key=lambda e: e[y_key], reverse=True)[:desired_total]
            kept_ids = {id(entry) for entry in highlights}
            background = [entry for entry in usable_sorted if id(entry) not in kept_ids]
        else:
            keep_background = max(desired_total - len(highlights), 0)
            if len(background) > keep_background:
                rng.shuffle(background)
                background = background[:keep_background]

    return highlights, background


def _promote_front(
    highlights: List[Dict], background: List[Dict], front: List[Dict]
) -> Tuple[List[Dict], List[Dict]]:
    """Ensure every Pareto-front entry is highlighted, moving it out of the background.

    Membership is tracked by object identity, so distinct entries that happen to
    compare equal are never confused with one another.
    """
    highlights = list(highlights)
    highlight_ids = {id(entry) for entry in highlights}
    moved_ids = set()
    for entry in front:
        if id(entry) in highlight_ids:
            continue
        highlights.append(entry)
        highlight_ids.add(id(entry))
        moved_ids.add(id(entry))
    background = [entry for entry in background if id(entry) not in moved_ids]
    return highlights, background


def _draw_panel(
    ax,
    supernet: str,
    y_key: str,
    highlights: List[Dict],
    background: List[Dict],
    front: List[Dict],
) -> None:
    """Draw one supernet panel: grey background, coloured highlights, Pareto line, legend."""
    color = COLORS.get(supernet, "#888888")
    legend_handles = []
    legend_labels = []

    if background:
        bg_scatter = ax.scatter(
            [entry["latency_ms"] for entry in background],
            [entry[y_key] for entry in background],
            s=24,
            c="#b3b3b3",
            alpha=0.30,
            label="Search space",
        )
        legend_handles.append(bg_scatter)
        legend_labels.append("Search space")
    if highlights:
        highlight_scatter = ax.scatter(
            [entry["latency_ms"] for entry in highlights],
            [entry[y_key] for entry in highlights],
            s=32,
            c=color,
            alpha=0.65,
            label="Run samples",
        )
        legend_handles.append(highlight_scatter)
        legend_labels.append("Run samples")

    pareto_line, = ax.plot(
        [p["latency_ms"] for p in front],
        [p[y_key] for p in front],
        "-",
        color=color,
        linewidth=2.2,
        label="Pareto front",
    )
    legend_handles.append(pareto_line)
    legend_labels.append("Pareto front")

    panel_titles = {
        "tiny": "PEL-NAS-ViT-Tiny",
        "small": "PEL-NAS-ViT-Small",
        "base": "PEL-NAS-ViT-Base",
    }
    ax.set_title(panel_titles.get(supernet, FRIENDLY_NAMES.get(supernet, supernet.title())), fontsize=TITLE_FONT)
    ax.set_xlabel("Latency (ms)", fontsize=LABEL_FONT)
    ax.grid(True, linestyle="--", linewidth=0.3, alpha=0.35)
    ax.legend(handles=legend_handles, labels=legend_labels, loc="lower right", fontsize=LEGEND_FONT)
    ax.tick_params(axis="both", labelsize=TICK_FONT)


def plot_triptych(supernet_results: Dict[str, str], output: str) -> None:
    """Render one scatter + Pareto-front panel per supernet and save PNG and PDF.

    Args:
        supernet_results: Maps a supernet key ("tiny", "small", "base") to a
            results.json path. Keys with falsy paths are skipped, and panels
            follow SUPERNET_ORDER.
        output: Path of the raster image, written at dpi=300. A sibling with
            the same stem and a ".pdf" suffix is also written. Missing parent
            directories are created.

    Raises:
        ValueError: If no supernet has a results path.
    """
    active = [name for name in SUPERNET_ORDER if supernet_results.get(name)]
    if not active:
        raise ValueError("Provide at least one results.json path")

    # A single seeded generator shared by all panels keeps the figure reproducible.
    rng = random.Random(42)

    fig, axes = plt.subplots(1, len(active), figsize=(14.5, 4.8), sharey=False)
    if len(active) == 1:
        axes = [axes]

    try:
        for ax, supernet in zip(axes, active):
            results, y_key, bounds = load_results(supernet_results[supernet])
            usable = select_points(results, y_key)
            usable, y_key = ensure_accuracy(usable, y_key, bounds)
            usable = _apply_supernet_filter(supernet, usable, y_key)

            highlights, background = _split_highlights(usable, y_key, supernet, rng)
            # Compute the front once: it is both forced into the highlights and drawn.
            front = pareto_entries(usable, y_key)
            highlights, background = _promote_front(highlights, background, front)
            _draw_panel(ax, supernet, y_key, highlights, background, front)

        for idx, ax in enumerate(axes):
            ax.set_ylabel("Accuracy (%)" if idx == 0 else "", fontsize=LABEL_FONT)
        fig.tight_layout()

        # os.makedirs("") raises, so only create a directory when there is one.
        out_dir = os.path.dirname(output)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(output, dpi=300)
        pdf_path = os.path.splitext(output)[0] + ".pdf"
        if pdf_path != output:
            # Avoid overwriting the dpi=300 file when `output` is itself a PDF.
            fig.savefig(pdf_path)
    finally:
        # Release the figure even if loading or plotting a panel fails.
        plt.close(fig)
    print(f"Wrote {output}")


def main(argv=None) -> None:
    """Command-line entry point: parse arguments and render the triptych.

    Args:
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]`` (argparse default).
    """
    parser = argparse.ArgumentParser(description="Plot triptych from run results")
    for key, label in (("tiny", "Tiny"), ("small", "Small"), ("base", "Base")):
        parser.add_argument(
            f"--{key}-results",
            type=str,
            default=None,
            help=f"results.json for Supernet-{label}",
        )
    parser.add_argument(
        "--output",
        default="outputs/vit_autoformer_nas/run_triptych.png",
        help="output image path",
    )
    args = parser.parse_args(argv)

    # Forward only the supernets whose results path was actually supplied.
    selected = {}
    for key in ("tiny", "small", "base"):
        path = getattr(args, f"{key}_results")
        if path:
            selected[key] = path
    plot_triptych(selected, args.output)


# NOTE(review): this module uses a relative import (`from .pareto import ...`), so it
# must be run as part of its package (e.g. `python -m <package>.<module>`); executing
# the file directly as a script will fail at that import despite the shebang.
if __name__ == "__main__":
    main()
