"""Pre-computed latency and accuracy proxies for AutoFormer supernets."""

from __future__ import annotations

import json
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

# Dataset used when PrecomputedMetricsRepository gets no explicit path:
# ``data/autoformer_metrics.json`` located next to this module.
DEFAULT_DATASET_PATH = os.path.join(
    os.path.dirname(__file__),
    "data",
    "autoformer_metrics.json",
)


def _to_signature(arch: Dict) -> Tuple[int, int, int, int, float]:
    """Reduce an architecture description to a hashable lookup key.

    The key is ``(hidden_dim, qkv_dim, depth, num_heads, mlp_ratio)``:

    * ``qkv_dim`` defaults to ``hidden_dim`` when the key is absent.
    * ``num_heads`` and ``mlp_ratio`` may be scalars or lists/tuples; for a
      sequence only its first element is used.

    Raises:
        KeyError: ``hidden_dim`` or ``depth`` is missing.
        TypeError: ``num_heads`` or ``mlp_ratio`` is missing (``None`` cannot
            be converted).

    Note (review): the sequence form is presumably per-layer; keeping only
    element 0 assumes every layer shares the value — confirm with callers.
    """

    def leading(value):
        # Sequences collapse to their first element; scalars pass through.
        return value[0] if isinstance(value, (list, tuple)) else value

    hidden = int(arch["hidden_dim"])
    qkv = int(arch.get("qkv_dim", hidden))
    layers = int(arch["depth"])
    heads = int(leading(arch.get("num_heads")))
    ratio = float(leading(arch.get("mlp_ratio")))
    return hidden, qkv, layers, heads, ratio


@dataclass()
class MetricRecord:
    """Metrics stored for a single architecture signature of a supernet."""

    # Normalised accuracy value (scale not enforced here — TODO confirm range).
    accuracy_norm: float
    # Accuracy value as stored in the dataset.
    accuracy: float
    # Latency, in milliseconds per the field name.
    latency_ms: float
    # Free-form label taken from the dataset entry.
    category: str


class PrecomputedMetricsRepository:
    """Lookup table of precomputed metrics, keyed by supernet and architecture.

    The dataset is a JSON document whose ``"supernets"`` object maps each
    supernet name to a list of entries. Each entry supplies the architecture
    fields read by ``_to_signature`` (``hidden_dim``, optional ``qkv_dim``,
    ``depth``, ``num_heads``, ``mlp_ratio``), a required ``latency_ms``, and
    optional ``accuracy_norm``, ``accuracy`` and ``category`` fields.

    Dataset entries and query architectures are both reduced with
    ``_to_signature``. A query therefore matches an entry exactly when both
    produce the same signature tuple.
    """

    def __init__(self, path: Optional[str] = None, strict: bool = False):
        """Load the dataset at *path* (``DEFAULT_DATASET_PATH`` if *path* is falsy).

        Args:
            path: JSON file to read.
            strict: If true, a missing file raises ``FileNotFoundError``.
                Otherwise the repository is left empty.

        Raises:
            FileNotFoundError: *strict* is set and the file does not exist.
            KeyError, TypeError, ValueError: an entry lacks a required field
                or holds a value that cannot be converted to int/float.
        """
        dataset_path = path or DEFAULT_DATASET_PATH
        self.path = dataset_path
        self._data: Dict[str, Dict[Tuple[int, int, int, int, float], MetricRecord]] = {}
        self._accuracy_bounds: Dict[str, Tuple[float, float]] = {}

        if not os.path.isfile(dataset_path):
            if strict:
                raise FileNotFoundError(f"Precomputed metrics file not found: {dataset_path}")
            # Non-strict mode: behave as an empty repository.
            return

        # JSON text is UTF-8 (RFC 8259); don't depend on the locale default.
        with open(dataset_path, "r", encoding="utf-8") as handle:
            payload = json.load(handle)

        # ``or {}`` also tolerates an explicit ``"supernets": null``.
        supernets = payload.get("supernets") or {}
        for supernet, entries in supernets.items():
            store = self._build_store(entries)
            if not store:
                continue
            self._data[supernet] = store
            # Compute bounds from the records actually kept. A later entry
            # with the same signature replaces an earlier one, and the
            # replaced entry must not widen the bounds.
            accuracies = [record.accuracy for record in store.values()]
            self._accuracy_bounds[supernet] = (min(accuracies), max(accuracies))

    @staticmethod
    def _build_store(
        entries: List[Dict],
    ) -> Dict[Tuple[int, int, int, int, float], MetricRecord]:
        """Convert raw dataset entries into a ``signature -> MetricRecord`` map.

        Signatures use the same ``_to_signature`` helper as lookups, so
        list-valued ``num_heads``/``mlp_ratio`` fields are handled identically
        on both sides. A later entry with the same signature overwrites an
        earlier one.
        """
        store: Dict[Tuple[int, int, int, int, float], MetricRecord] = {}
        for entry in entries:
            signature = _to_signature(entry)
            acc_norm = float(entry.get("accuracy_norm", 0.0))
            store[signature] = MetricRecord(
                accuracy_norm=acc_norm,
                # Without an explicit accuracy, reuse the normalised value.
                accuracy=float(entry.get("accuracy", acc_norm)),
                latency_ms=float(entry["latency_ms"]),
                category=str(entry.get("category", "")),
            )
        return store

    def has_metrics(self, supernet: str, arch: Dict) -> bool:
        """Return True if a record exists for *arch* under *supernet*."""
        return self.get_metrics(supernet, arch) is not None

    def get_metrics(self, supernet: str, arch: Dict) -> Optional[MetricRecord]:
        """Return the record for *arch* under *supernet*, or None if absent.

        Returns None without inspecting *arch* when *supernet* is unknown.
        Otherwise *arch* is reduced with ``_to_signature``, which may raise on
        malformed input.
        """
        store = self._data.get(supernet)
        if store is None:
            return None
        return store.get(_to_signature(arch))

    def iter_metrics(self, supernet: str) -> List[MetricRecord]:
        """Return a new list of every record for *supernet* (empty if unknown)."""
        return list(self._data.get(supernet, {}).values())

    def get_accuracy_bounds(self, supernet: str) -> Optional[Tuple[float, float]]:
        """Return ``(min, max)`` of stored ``accuracy`` values, or None if unknown."""
        return self._accuracy_bounds.get(supernet)


# Names exported by ``from <module> import *``. Listing ``_to_signature``
# here makes star-imports export it despite its leading underscore.
__all__ = [
    "DEFAULT_DATASET_PATH", "MetricRecord",
    "PrecomputedMetricsRepository", "_to_signature",
]
