"""Category definitions and helpers for AutoFormer supernet subspaces."""

from __future__ import annotations

import random
from dataclasses import dataclass
from typing import Dict, Iterable, List, Sequence, Tuple


@dataclass(frozen=True)
class SupernetSpec:
    """Allowed values of each searchable dimension for one supernet."""

    # Candidate values per dimension; the helpers in this module read the
    # first, middle and last entries as low/mid/high levels.
    hidden_dim: Sequence[int]
    depth: Sequence[int]
    num_heads: Sequence[int]
    mlp_ratio: Sequence[float]
    # Maps a head count to the QKV dimension used with that head count.
    head_qkv_map: Dict[int, int]

    @property
    def qkv_values(self) -> List[int]:
        """Return the distinct QKV dimensions in ascending order."""
        distinct = {*self.head_qkv_map.values()}
        return sorted(distinct)


@dataclass(frozen=True)
class CategoryBounds:
    """Per-dimension bounds that define one architecture category.

    Each field is a two-element tuple of bounds. ``sample_arch_from_category``
    passes every tuple to ``_values_within``, which treats both ends as
    inclusive and swaps them when they are given in descending order.
    """

    # Matched against ``SupernetSpec.hidden_dim`` when sampling.
    embed_dim: Tuple[int, int]
    # Matched against ``SupernetSpec.depth`` when sampling.
    depth: Tuple[int, int]
    # Matched against ``SupernetSpec.num_heads`` when sampling.
    num_heads: Tuple[int, int]
    # Matched against ``SupernetSpec.mlp_ratio`` when sampling.
    mlp_ratio: Tuple[float, float]


def _make_range(start: int, stop: int, step: int) -> List[int]:
    return list(range(start, stop + step, step))


# Registry of the supported supernet search spaces, keyed by lowercase name
# (``get_supernet_spec`` lowercases the requested name before looking it up).
# In every entry, each ``head_qkv_map`` value equals 64 * the head count.
SUPERNET_SPECS: Dict[str, SupernetSpec] = {
    "tiny": SupernetSpec(
        hidden_dim=[192, 216, 240],
        depth=[12, 13, 14],
        num_heads=[3, 4],
        mlp_ratio=[3.5, 4.0],
        head_qkv_map={3: 192, 4: 256},
    ),
    "small": SupernetSpec(
        hidden_dim=[320, 384, 448],
        depth=[12, 13, 14],
        num_heads=[5, 6, 7],
        mlp_ratio=[3.0, 3.5, 4.0],
        head_qkv_map={5: 320, 6: 384, 7: 448},
    ),
    "base": SupernetSpec(
        hidden_dim=[528, 576, 624],
        depth=[14, 15, 16],
        num_heads=[8, 9, 10],
        mlp_ratio=[3.0, 3.5, 4.0],
        head_qkv_map={8: 512, 9: 576, 10: 640},
    ),
}

def default_categories(supernet: str) -> Dict[str, CategoryBounds]:
    """Build a 3x3 grid of categories over embed-dimension and depth levels.

    The low/mid/high levels are the first, middle (index ``len // 2``) and
    last entries of the spec's ``hidden_dim`` and ``depth`` sequences. Each
    category pins embed dimension and depth to a single value, while
    ``num_heads`` and ``mlp_ratio`` span from the first to the last allowed
    value.

    Args:
        supernet: Name of the supernet, resolved via ``get_supernet_spec``.

    Returns:
        A dict with keys of the form ``embed_<level>_depth_<level>``, ordered
        with the embed level varying slowest.
    """
    spec = get_supernet_spec(supernet)

    def low_mid_high(values):
        # First, middle and last entry of the allowed values.
        return values[0], values[len(values) // 2], values[-1]

    tiers = ("low", "mid", "high")
    embed_tiers = list(zip(tiers, low_mid_high(spec.hidden_dim)))
    depth_tiers = list(zip(tiers, low_mid_high(spec.depth)))

    # These two spans are identical for every category of the supernet.
    heads_span = (spec.num_heads[0], spec.num_heads[-1])
    ratio_span = (spec.mlp_ratio[0], spec.mlp_ratio[-1])

    return {
        f"embed_{embed_tier}_depth_{depth_tier}": CategoryBounds(
            embed_dim=(embed_value, embed_value),
            depth=(depth_value, depth_value),
            num_heads=heads_span,
            mlp_ratio=ratio_span,
        )
        for embed_tier, embed_value in embed_tiers
        for depth_tier, depth_value in depth_tiers
    }


def get_supernet_spec(name: str) -> SupernetSpec:
    """Look up a supernet specification by case-insensitive name.

    Args:
        name: Supernet name; it is lowercased before the lookup.

    Returns:
        The matching entry of ``SUPERNET_SPECS``.

    Raises:
        KeyError: If the lowercased name is not a key of ``SUPERNET_SPECS``.
    """
    spec = SUPERNET_SPECS.get(name.lower())
    if spec is None:
        raise KeyError(f"Unsupported supernet '{name}'. Available: {list(SUPERNET_SPECS)}")
    return spec


def _values_within(bounds: Tuple, allowed: Sequence) -> List:
    lo, hi = bounds
    if lo > hi:
        lo, hi = hi, lo
    return [v for v in allowed if lo <= v <= hi]


def sample_arch_from_category(supernet: str, bounds: CategoryBounds) -> Dict:
    """Draw a random architecture config that lies inside ``bounds``.

    For each dimension, the allowed values of the supernet are filtered to
    those within the corresponding bounds and one is picked with
    ``random.choice``. The QKV dimension is looked up from the sampled head
    count.

    Args:
        supernet: Name of the supernet, resolved via ``get_supernet_spec``.
        bounds: Category bounds restricting each dimension.

    Returns:
        The config built by ``make_arch_config`` from the sampled values.

    Raises:
        ValueError: If any dimension has no allowed value within its bounds.
    """
    from .autoformer import make_arch_config

    spec = get_supernet_spec(supernet)

    # One candidate pool per dimension, in the order the values are drawn:
    # hidden_dim, depth, num_heads, mlp_ratio.
    pools = [
        _values_within(limits, allowed)
        for limits, allowed in (
            (bounds.embed_dim, spec.hidden_dim),
            (bounds.depth, spec.depth),
            (bounds.num_heads, spec.num_heads),
            (bounds.mlp_ratio, spec.mlp_ratio),
        )
    ]
    if any(not pool for pool in pools):
        raise ValueError(f"Bounds {bounds} produce empty choice set for supernet '{supernet}'.")

    # The draw order is kept stable so seeded runs stay reproducible.
    hidden_dim, depth, num_heads, mlp_ratio = (random.choice(pool) for pool in pools)
    return make_arch_config(
        hidden_dim,
        depth,
        num_heads,
        mlp_ratio,
        qkv_dim=spec.head_qkv_map[int(num_heads)],
    )


def clamp_arch_to_allowed(supernet: str, arch: Dict) -> Dict:
    """Snap an architecture config to the nearest values the supernet allows.

    ``hidden_dim`` and ``depth`` are read from ``arch`` directly. For
    ``num_heads`` and ``mlp_ratio`` only the first element of the stored
    sequence is used, falling back to the first allowed value when the key is
    absent.
    NOTE(review): presumably these are per-layer lists — confirm against
    ``make_arch_config``.

    Args:
        supernet: Name of the supernet, resolved via ``get_supernet_spec``.
        arch: Architecture config to adjust.

    Returns:
        A config built by ``make_arch_config`` from the snapped values, with
        the QKV dimension looked up from the snapped head count.
    """
    from .autoformer import make_arch_config

    spec = get_supernet_spec(supernet)

    def nearest(target, options: Sequence):
        def distance(option):
            return abs(option - target)

        # ``min`` returns the first option among equally distant ones.
        return min(options, key=distance)

    hidden_dim = nearest(arch.get("hidden_dim"), spec.hidden_dim)
    depth = nearest(arch.get("depth"), spec.depth)

    heads_entry = arch.get("num_heads", [spec.num_heads[0]])
    num_heads = nearest(heads_entry[0], spec.num_heads)

    ratio_entry = arch.get("mlp_ratio", [spec.mlp_ratio[0]])
    mlp_ratio = nearest(ratio_entry[0], spec.mlp_ratio)

    return make_arch_config(
        hidden_dim,
        depth,
        num_heads,
        mlp_ratio,
        qkv_dim=spec.head_qkv_map[int(num_heads)],
    )
