"""Pareto utilities for latency vs. proxy accuracy."""

from __future__ import annotations

import os
from typing import Dict, Iterable, List, Optional, Tuple


def pareto_front(points: Iterable[Dict], acc_key: str, lat_key: str) -> List[Dict]:
    """Return the non-dominated subset of *points*, ordered by ascending latency.

    Higher ``acc_key`` values are better and lower ``lat_key`` values are better.
    A point is discarded when some *other* point is no worse on both axes and
    strictly better on at least one. Points with identical values therefore do
    not eliminate each other, so duplicates all survive. The final ordering is
    a stable sort on ``lat_key``, so latency ties keep their input order.

    Args:
        points: Any iterable of dicts (it is materialized once, so generators
            are accepted).
        acc_key: Dict key holding the accuracy value.
        lat_key: Dict key holding the latency value.

    Returns:
        A new list of the surviving dicts (the same objects, not copies).
    """
    candidates = list(points)

    def beats(rival: Dict, target: Dict) -> bool:
        # True when `rival` dominates `target`.
        target_acc, rival_acc = target[acc_key], rival[acc_key]
        target_lat, rival_lat = target[lat_key], rival[lat_key]
        no_worse = rival_acc >= target_acc and rival_lat <= target_lat
        better_somewhere = rival_acc > target_acc or rival_lat < target_lat
        return no_worse and better_somewhere

    # Identity check (not equality) so a point never competes with itself,
    # while equal-valued duplicates still compare against each other.
    survivors = [
        cand
        for cand in candidates
        if not any(beats(rival, cand) for rival in candidates if rival is not cand)
    ]
    return sorted(survivors, key=lambda item: item[lat_key])


def _ensure_parent_dir(path: str) -> None:
    """Create the directory that will contain *path*, if *path* names one."""
    parent = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so a bare filename such as
    # "plot.png" (i.e. the current directory) must skip directory creation.
    if parent:
        os.makedirs(parent, exist_ok=True)


def _scatter_by_origin(plt, points: List[Dict], acc_key: str, lat_key: str) -> None:
    """Scatter *points* using one matplotlib call per ``origin`` style.

    ``"llm"`` and ``"random"`` origins get dedicated, labelled styles; any
    other (or missing) origin is drawn as an unlabelled grey square.
    """
    origin_styles = {
        "llm": {"color": "#1f77b4", "label": "LLM candidates", "marker": "o", "size": 26, "alpha": 0.85},
        "random": {"color": "#ff7f0e", "label": "Random samples", "marker": "^", "size": 22, "alpha": 0.75},
    }
    fallback_style = {"color": "#999999", "label": None, "marker": "s", "size": 18, "alpha": 0.6}

    # Group by style key in first-appearance order so each labelled origin
    # yields exactly one legend entry, in the order it first occurs. All
    # unknown origins share the fallback style and are grouped under None.
    groups: Dict[Optional[str], List[Dict]] = {}
    for point in points:
        origin = point.get("origin")
        style_key = origin if origin in origin_styles else None
        groups.setdefault(style_key, []).append(point)

    for style_key, members in groups.items():
        style = origin_styles.get(style_key, fallback_style)
        plt.scatter(
            [pt[lat_key] for pt in members],
            [pt[acc_key] for pt in members],
            c=style["color"],
            s=style["size"],
            alpha=style["alpha"],
            marker=style["marker"],
            label=style["label"],
        )


def save_pareto_plot(
    all_points: List[Dict],
    frontier: List[Dict],
    output_png: str,
    acc_key: str,
    lat_key: str,
    title: str,
    y_limits: Optional[Tuple[float, float]] = None,
    y_label: Optional[str] = None,
    background_points: Optional[List[Dict]] = None,
) -> None:
    """Render a latency-vs-accuracy scatter with Pareto frontier(s) to a PNG.

    Args:
        all_points: Evaluated points. Each needs ``lat_key`` and ``acc_key``
            and may carry an ``"origin"`` of ``"llm"`` or ``"random"`` that
            selects its marker style. Any other origin is drawn as a grey
            square. The ``"llm"`` subset also gets its own dashed frontier
            when that frontier has more than one point.
        frontier: Overall frontier points, joined by a red line in list order.
            They should already be sorted by latency, as ``pareto_front``
            returns them.
        output_png: Destination file path. Its parent directory, if any, is
            created.
        acc_key: Dict key of the y value (accuracy; higher is better).
        lat_key: Dict key of the x value (latency; lower is better).
        title: Plot title.
        y_limits: Optional ``(bottom, top)`` limits for the y axis.
        y_label: Optional y-axis label that overrides the default.
        background_points: Optional light-grey context points. These are read
            via the fixed keys ``"latency_ms"`` and ``"y"``, not
            ``lat_key``/``acc_key``.

    If matplotlib is not installed, a message is printed and nothing is
    written.
    """
    try:
        import matplotlib.pyplot as plt
    except ModuleNotFoundError:
        print("matplotlib not available; skipping Pareto plot generation.")
        return

    _ensure_parent_dir(output_png)
    fig = plt.figure(figsize=(6, 4))
    try:
        if background_points:
            # NOTE(review): background points use fixed "latency_ms"/"y" keys
            # rather than lat_key/acc_key. Presumably the caller pre-projects
            # them; confirm before changing.
            plt.scatter(
                [bp["latency_ms"] for bp in background_points],
                [bp["y"] for bp in background_points],
                s=10,
                c="#d0d0d0",
                alpha=0.35,
                label="Search space",
            )

        _scatter_by_origin(plt, all_points, acc_key=acc_key, lat_key=lat_key)

        plt.plot(
            [pt[lat_key] for pt in frontier],
            [pt[acc_key] for pt in frontier],
            "-o",
            c="#d62728",
            label="Pareto Front",
            markersize=4,
        )

        llm_points = [pt for pt in all_points if pt.get("origin") == "llm"]
        if llm_points:
            llm_front = pareto_front(llm_points, acc_key=acc_key, lat_key=lat_key)
            # A single-point frontier cannot form a line, so skip it.
            if len(llm_front) > 1:
                plt.plot(
                    [pt[lat_key] for pt in llm_front],
                    [pt[acc_key] for pt in llm_front],
                    "--",
                    c="#1f77b4",
                    linewidth=2.0,
                    label="LLM Frontier",
                )

        plt.xlabel("Latency (ms, batch=1)")
        default_ylabel = "Accuracy (%)" if acc_key == "est_acc" else "Accuracy Proxy (higher is better)"
        plt.ylabel(y_label or default_ylabel)
        plt.title(title)
        if y_limits is not None:
            plt.ylim(y_limits[0], y_limits[1])
        plt.legend()
        plt.tight_layout()
        plt.savefig(output_png, dpi=200)
    finally:
        # Close the figure even if plotting or saving raises, so repeated
        # calls do not accumulate open figures in pyplot's global state.
        plt.close(fig)
