#!/usr/bin/env python3
"""Generate a lightweight AutoFormer metrics dataset for quick experiments."""

from __future__ import annotations

import json
import os
from itertools import product
from typing import Dict, Iterable, List, Tuple

from vit_autoformer_nas.categories import SUPERNET_SPECS, default_categories, get_supernet_spec

# Destination of the generated dataset: "autoformer_metrics.json" placed in the
# same directory as this script. main() overwrites the file on every run.
OUTPUT_PATH = os.path.join(os.path.dirname(__file__), "autoformer_metrics.json")


def _normalize(value: float, values: Iterable[float]) -> float:
    values = list(values)
    lo = min(values)
    hi = max(values)
    if abs(hi - lo) < 1e-12:
        return 0.0
    return (value - lo) / (hi - lo)


def _resolve_category(supernet: str, embed_dim: int, depth: int) -> str:
    """Look up the category label for an ``(embed_dim, depth)`` pair.

    The categories returned by ``default_categories(supernet)`` are scanned in
    iteration order. A category matches only when the *first* element of its
    ``embed_dim`` bounds and the *first* element of its ``depth`` bounds are
    both equal to the given values; the remaining elements are not consulted.

    Args:
        supernet: Supernet name forwarded to ``default_categories``.
        embed_dim: Embedding dimension of the architecture.
        depth: Depth of the architecture.

    Returns:
        The name of the first matching category, or ``"unknown"`` if none match.
    """
    matching_names = (
        label
        for label, limits in default_categories(supernet).items()
        if limits.embed_dim[0] == embed_dim and limits.depth[0] == depth
    )
    # The generator is lazy, so scanning stops at the first hit.
    return next(matching_names, "unknown")


# Per-supernet coefficients of the two linear models evaluated in
# generate_supernet_records(). Every feature is min-max normalised before it
# is multiplied by its coefficient.
#
#   acc_base / lat_base   -> constant term of the model
#   acc_embed / lat_embed -> weight of the normalised hidden (embedding) dim
#   acc_qkv / lat_qkv     -> weight of the normalised QKV dim
#   acc_depth / lat_depth -> weight of the normalised depth
#   acc_heads / lat_heads -> weight of the normalised head count
#   acc_mlp / lat_mlp     -> weight of the normalised MLP ratio
#
# The "acc_*" model yields a score that is clamped to [0, 1] by the caller, so
# coefficient sums above 1.0 (as for "small" and "base") saturate at 1.0.
# The "lat_*" model is emitted unclamped under the "latency_ms" key.
PARAM_WEIGHTS: Dict[str, Dict[str, float]] = {
    "tiny": {
        "acc_base": 0.45,
        "acc_embed": 0.25,
        "acc_qkv": 0.08,
        "acc_depth": 0.14,
        "acc_heads": 0.04,
        "acc_mlp": 0.04,
        "lat_base": 1.1,
        "lat_embed": 0.55,
        "lat_qkv": 0.35,
        "lat_depth": 0.45,
        "lat_heads": 0.28,
        "lat_mlp": 0.18,
    },
    "small": {
        "acc_base": 0.52,
        "acc_embed": 0.2,
        "acc_qkv": 0.08,
        "acc_depth": 0.16,
        "acc_heads": 0.04,
        "acc_mlp": 0.04,
        "lat_base": 2.0,
        "lat_embed": 0.75,
        "lat_qkv": 0.45,
        "lat_depth": 0.6,
        "lat_heads": 0.32,
        "lat_mlp": 0.22,
    },
    "base": {
        "acc_base": 0.60,
        "acc_embed": 0.18,
        "acc_qkv": 0.08,
        "acc_depth": 0.12,
        "acc_heads": 0.04,
        "acc_mlp": 0.04,
        "lat_base": 3.5,
        "lat_embed": 0.85,
        "lat_qkv": 0.5,
        "lat_depth": 0.65,
        "lat_heads": 0.35,
        "lat_mlp": 0.25,
    },
}

# Per-supernet (min, max) bounds of the reported "accuracy" value. The clamped
# [0, 1] accuracy score is mapped linearly onto this interval, so 0.0 yields
# the first element and 1.0 the second. The table is also written verbatim
# into the output file's metadata section.
# NOTE(review): the values look like accuracy percentages — confirm the unit.
ACCURACY_RANGES: Dict[str, Tuple[float, float]] = {
    "tiny": (68.0, 78.0),
    "small": (70.0, 82.0),
    "base": (72.0, 84.0),
}


def generate_supernet_records(supernet: str) -> List[Dict]:
    """Build one synthetic metrics record per architecture of ``supernet``.

    Every combination of hidden dim, depth, head count and MLP ratio listed in
    the supernet spec is enumerated. For each combination the features are
    min-max normalised and fed through the two linear models configured in
    ``PARAM_WEIGHTS``: one producing an accuracy score, one producing a
    latency value. Nothing is measured; both metrics come from the formulas.

    Args:
        supernet: Supernet name. It is passed to ``get_supernet_spec`` and
            used as the key into ``PARAM_WEIGHTS`` and ``ACCURACY_RANGES``.

    Returns:
        A list of dicts with the keys ``hidden_dim``, ``qkv_dim``, ``depth``,
        ``num_heads``, ``mlp_ratio``, ``accuracy_norm``, ``accuracy``,
        ``latency_ms`` and ``category``. Metric values are rounded to four
        decimals. The list is sorted ascending by
        ``(hidden_dim, qkv_dim, depth, num_heads, mlp_ratio)``.

    Raises:
        KeyError: If ``supernet`` has no entry in ``PARAM_WEIGHTS`` or
            ``ACCURACY_RANGES``, or a head count is missing from the spec's
            ``head_qkv_map``.
    """
    spec = get_supernet_spec(supernet)
    coeffs = PARAM_WEIGHTS[supernet]
    acc_floor, acc_ceiling = ACCURACY_RANGES[supernet]

    rows: List[Dict] = []
    search_space = product(spec.hidden_dim, spec.depth, spec.num_heads, spec.mlp_ratio)
    for embed_dim, depth, num_heads, mlp_ratio in search_space:
        # The QKV dimension is not a free axis: it is derived from the head count.
        qkv_dim = spec.head_qkv_map[num_heads]

        # Scale every feature against the full set of choices for its axis.
        f_embed = _normalize(embed_dim, spec.hidden_dim)
        f_qkv = _normalize(qkv_dim, spec.qkv_values)
        f_depth = _normalize(depth, spec.depth)
        f_heads = _normalize(num_heads, spec.num_heads)
        f_mlp = _normalize(mlp_ratio, spec.mlp_ratio)

        # Linear accuracy model, clamped to [0, 1] before being mapped onto
        # the supernet's accuracy interval.
        accuracy_score = (
            coeffs["acc_base"]
            + coeffs["acc_embed"] * f_embed
            + coeffs["acc_qkv"] * f_qkv
            + coeffs["acc_depth"] * f_depth
            + coeffs["acc_heads"] * f_heads
            + coeffs["acc_mlp"] * f_mlp
        )
        accuracy_norm = max(0.0, min(1.0, accuracy_score))
        accuracy = acc_floor + accuracy_norm * (acc_ceiling - acc_floor)

        # Linear latency model; intentionally left unclamped.
        latency_ms = (
            coeffs["lat_base"]
            + coeffs["lat_embed"] * f_embed
            + coeffs["lat_qkv"] * f_qkv
            + coeffs["lat_depth"] * f_depth
            + coeffs["lat_heads"] * f_heads
            + coeffs["lat_mlp"] * f_mlp
        )

        row = {
            "hidden_dim": embed_dim,
            "qkv_dim": qkv_dim,
            "depth": depth,
            "num_heads": num_heads,
            "mlp_ratio": float(mlp_ratio),
            "accuracy_norm": round(accuracy_norm, 4),
            "accuracy": round(accuracy, 4),
            "latency_ms": round(latency_ms, 4),
            "category": _resolve_category(supernet, embed_dim, depth),
        }
        rows.append(row)

    # Deterministic output order, independent of how the spec lists its choices.
    rows.sort(
        key=lambda row: (
            row["hidden_dim"],
            row["qkv_dim"],
            row["depth"],
            row["num_heads"],
            row["mlp_ratio"],
        )
    )
    return rows


def main():
    """Generate the records for every known supernet and dump them as JSON.

    The output document has two top-level keys: ``metadata`` (holding the
    ``ACCURACY_RANGES`` table) and ``supernets`` (mapping each supernet name
    in ``SUPERNET_SPECS`` to its list of records). The file at
    ``OUTPUT_PATH`` is overwritten, and the path is printed when done.
    """
    per_supernet = {}
    for supernet_name in SUPERNET_SPECS.keys():
        per_supernet[supernet_name] = generate_supernet_records(supernet_name)

    document = {
        "metadata": {"accuracy_ranges": ACCURACY_RANGES},
        "supernets": per_supernet,
    }

    with open(OUTPUT_PATH, "w") as out_file:
        json.dump(document, out_file, indent=2)
    print(f"Wrote {OUTPUT_PATH}")


if __name__ == "__main__":
    main()
