"""Zero-cost proxy helpers using Auto-Prox utilities."""

from __future__ import annotations

import math
from typing import Any, Dict, Iterable, List, Optional


def compute_zero_cost_scores(
    model: Any,
    dataloader: Any,
    num_classes: int,
    measures: Iterable[str],
) -> Dict[str, Optional[float]]:
    """Score ``model`` with each requested zero-cost proxy.

    Every measure is computed by its own call to Auto-Prox's
    ``find_measures`` (one name per call), with cross-entropy as the loss and
    the device chosen as CUDA when available, otherwise CPU.

    Args:
        model: Network forwarded unchanged to ``find_measures``.
        dataloader: Data loader forwarded unchanged to ``find_measures``.
        num_classes: Number of output classes, forwarded inside ``data_info``.
        measures: Names of the proxies to compute.

    Returns:
        Mapping from each requested measure name to its score as a ``float``,
        or ``None`` when computing that measure raised. Failures are logged
        (with traceback) instead of propagated, so one broken proxy does not
        abort the remaining ones.
    """
    import logging

    import torch
    from pycls.predictor.pruners.predictive import find_measures

    logger = logging.getLogger(__name__)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): tuple layout presumably (dataload mode, number of
    # batches/images, num_classes) per find_measures' contract — confirm in pycls.
    data_info = ("random", 1, int(num_classes))

    scores: Dict[str, Optional[float]] = {}
    for name in measures:
        try:
            result = find_measures(
                model,
                dataloader,
                data_info,
                device,
                torch.nn.functional.cross_entropy,
                measure_names=[name],
            )
            # Some find_measures variants return {measure_name: score} rather
            # than a bare scalar; accept both shapes.
            if isinstance(result, dict):
                result = result[name]
            scores[name] = float(result)
        except Exception:  # deliberate best-effort boundary: record and move on
            logger.warning(
                "Zero-cost measure %r failed; recording None.", name, exc_info=True
            )
            scores[name] = None
    return scores


def orient_measure(name: str, value: Optional[float]) -> Optional[float]:
    """Return ``value`` as a plain float, sign-flipped for selected measures.

    Measures named ``jacov``, ``logits_entropy`` or ``entropy`` (matched
    case-insensitively) are negated, presumably so that all oriented scores
    share one "larger is better" direction. Every other measure passes through
    unchanged.

    Args:
        name: Measure name; ``None``/empty is treated as an unknown measure.
        value: Raw measure value (any float-convertible number) or ``None``.

    Returns:
        The oriented float, or ``None`` when ``value`` is ``None``, NaN or +/-inf.
    """
    if value is None:
        return None
    # Convert before checking: NumPy float32 (and similar) are not ``float``
    # subclasses, so an isinstance-gated NaN/inf check would let them through.
    number = float(value)
    if not math.isfinite(number):
        return None
    if (name or "").lower() in {"jacov", "logits_entropy", "entropy"}:
        return -number
    return number


def _linear_percentile(sorted_vals: List[float], q: float) -> float:
    """Return the ``q``-th percentile of non-empty ascending ``sorted_vals``.

    Uses linear interpolation between the two order statistics around rank
    ``q / 100 * (n - 1)``, the same rule as NumPy's default percentile method.
    The rank is clamped to the valid index range.
    """
    last = len(sorted_vals) - 1
    pos = min(max((q / 100.0) * last, 0.0), float(last))
    below = int(pos)  # pos >= 0, so int() is floor
    above = min(below + 1, last)
    frac = pos - below
    return sorted_vals[below] + (sorted_vals[above] - sorted_vals[below]) * frac


def build_measure_stats(
    records: List[Dict[str, Optional[float]]],
    percentile_clip: float = 5.0,
) -> Dict[str, tuple]:
    """Compute per-measure ``(lo, hi)`` normalisation bounds from score records.

    For each measure name seen across ``records``, ``lo``/``hi`` are the
    ``percentile_clip``-th and ``(100 - percentile_clip)``-th percentiles of its
    values (linear interpolation, matching NumPy's default). The result is the
    same whether or not NumPy is installed.

    ``None`` and non-finite (NaN/inf) values are ignored. A measure with no
    usable values gets no entry.

    If the percentile bounds are closer than ``1e-9``, they are replaced by the
    min/max of the values, and ``hi`` is widened by an epsilon relative to its
    magnitude. This keeps ``hi > lo`` so callers can safely divide by
    ``hi - lo`` even for large-magnitude constant values.

    Args:
        records: One mapping of measure name -> value per evaluated model.
        percentile_clip: Percentile (0-100) trimmed from each tail.

    Returns:
        Mapping of measure name -> ``(lo, hi)`` as floats.
    """
    per_measure: Dict[str, List[float]] = {}
    for record in records:
        for key, val in record.items():
            if val is None:
                continue
            number = float(val)
            if not math.isfinite(number):  # NaN/inf would poison the bounds
                continue
            per_measure.setdefault(key, []).append(number)

    stats: Dict[str, tuple] = {}
    for name, values in per_measure.items():
        ordered = sorted(values)
        lo = _linear_percentile(ordered, percentile_clip)
        hi = _linear_percentile(ordered, 100.0 - percentile_clip)
        if hi - lo < 1e-9:
            lo, hi = ordered[0], ordered[-1]
            # A fixed +1e-9 is lost to rounding once |hi| is large (e.g. 1e8),
            # leaving hi == lo; scale the epsilon so the gap always survives.
            hi += 1e-9 * max(1.0, abs(hi))
        stats[name] = (float(lo), float(hi))
    return stats


def aggregate_scores(
    oriented_scores: Dict[str, Optional[float]],
    stats: Dict[str, tuple],
) -> Optional[float]:
    """Average the clipped, min-max normalised scores of the measures in ``stats``.

    For every ``name -> (lo, hi)`` in ``stats``, the matching score from
    ``oriented_scores`` is clipped to ``[lo, hi]`` and rescaled to ``[0, 1]``.
    A measure is skipped when any of these hold:

    * its score is missing, ``None`` or non-finite;
    * its bounds are degenerate (``hi - lo`` not a positive finite number),
      which would otherwise yield NaN or a ZeroDivisionError.

    Args:
        oriented_scores: Measure name -> oriented score (or ``None``).
        stats: Measure name -> ``(lo, hi)`` bounds, e.g. from ``build_measure_stats``.

    Returns:
        The mean of the normalised components, or ``None`` if none contributed.
    """
    components: List[float] = []
    for name, (lo, hi) in stats.items():
        value = oriented_scores.get(name)
        if value is None or not math.isfinite(value):
            continue
        span = hi - lo
        if not 0.0 < span < math.inf:  # rejects zero, negative and NaN spans
            continue
        clipped = max(min(value, hi), lo)
        components.append((clipped - lo) / span)
    if not components:
        return None
    return float(sum(components) / len(components))
