"""Optional LLM-driven proposal generation for ViT subnets."""

from __future__ import annotations

import json
import logging
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

from .categories import SupernetSpec, clamp_arch_to_allowed, get_supernet_spec


@dataclass
class LLMConfig:
    """Settings for the optional LLM proposal generator.

    The generator is considered active only when :attr:`enabled` is true,
    i.e. ``per_category`` is positive and ``api_key`` is non-empty.
    """

    # Presumably the number of proposals requested per category (confirm
    # against the caller); a value <= 0 keeps the feature disabled.
    per_category: int = 0
    # Model name forwarded to the chat-completions request.
    model: str = "gpt-4o-mini"
    # Credential handed to the API client. Excluded from the generated
    # ``repr`` so that logging or printing a config never leaks the secret.
    api_key: str = field(default="", repr=False)
    # Sampling temperature forwarded to the chat-completions request.
    temperature: float = 0.2
    # Completion token cap forwarded to the chat-completions request.
    max_tokens: int = 800

    @property
    def enabled(self) -> bool:
        """Return True when proposals are requested and a key is configured."""
        return self.per_category > 0 and bool(self.api_key)


def _robust_json_load(payload: str):
    try:
        return json.loads(payload)
    except Exception:
        if "```" in payload:
            for chunk in payload.split("```"):
                chunk = chunk.strip()
                if chunk.startswith("[") or chunk.startswith("{"):
                    try:
                        return json.loads(chunk)
                    except Exception:
                        continue
        start = payload.find("[")
        end = payload.rfind("]")
        if start != -1 and end != -1 and end > start:
            try:
                return json.loads(payload[start : end + 1])
            except Exception:
                pass
        return None


def _coerce_arch_entries(data) -> List:
    """Return the list of candidate entries held by a decoded LLM reply.

    Accepts a bare list, an object wrapping a list (preferably under the
    ``"subnets"`` key), or a single candidate object. Anything else yields
    an empty list.
    """
    if isinstance(data, list):
        return data
    if not isinstance(data, dict):
        return []
    # A lone candidate object: checked first so that a per-layer list such as
    # ``num_heads`` is not mistaken for the list of candidates.
    if "hidden_dim" in data and "depth" in data:
        return [data]
    wrapped = data.get("subnets")
    if isinstance(wrapped, list):
        return wrapped
    for value in data.values():
        if isinstance(value, list):
            return value
    return []


def llm_generate_arches(
    supernet: str,
    category: str,
    bounds: Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int], Tuple[float, float]],
    n: int,
    cfg: LLMConfig,
) -> List[Dict]:
    """Ask an LLM for up to ``n`` subnet configurations of ``supernet``.

    This is a best-effort helper: a disabled config, a non-positive ``n``, a
    missing ``openai`` package, a failed request or an unparseable reply all
    result in an empty list rather than an exception.

    Args:
        supernet: Name passed to ``get_supernet_spec`` and
            ``clamp_arch_to_allowed``.
        category: Category label interpolated into the prompt.
        bounds: Preferred ``(low, high)`` ranges, in the order
            hidden_dim, depth, num_heads, mlp_ratio.
        n: Maximum number of architectures to return.
        cfg: LLM settings; nothing is requested unless ``cfg.enabled``.

    Returns:
        At most ``n`` dicts, each the result of ``clamp_arch_to_allowed``
        applied to one entry of the reply. Entries that cannot be converted
        are skipped.
    """
    # Nothing to do (and no API call to pay for) when disabled or n <= 0.
    if not cfg.enabled or n <= 0:
        return []

    log = logging.getLogger(__name__)
    spec: SupernetSpec = get_supernet_spec(supernet)

    # Insertion order matters: it must line up with the order of ``bounds``.
    allowed = {
        "hidden_dim": list(spec.hidden_dim),
        "depth": list(spec.depth),
        "num_heads": list(spec.num_heads),
        "mlp_ratio": list(spec.mlp_ratio),
    }

    sys_prompt = (
        "You are a NAS expert generating Vision Transformer subnets for AutoFormerSub. "
        "Return ONLY valid JSON without commentary."
    )
    user_prompt = {
        "task": f"Generate {n} ViT candidate subnets for category {category}.",
        "supernet": supernet,
        "constraints": {
            key: {"allowed": vals, "preferred_range": list(bounds[idx])}
            for idx, (key, vals) in enumerate(allowed.items())
        },
        # The request below sets response_format=json_object, so the reply is
        # a top-level object; ask for the candidates under a known key.
        "format": {
            "type": "object",
            "shape": {"subnets": "list[object]"},
            "object_keys": ["hidden_dim", "depth", "num_heads", "mlp_ratio"],
            "notes": (
                "Wrap the candidates in a top-level JSON object under the key 'subnets'. "
                "num_heads and mlp_ratio may be scalars or length=depth lists"
            ),
        },
    }

    try:
        # Imported lazily so the package stays optional.
        from openai import OpenAI

        client = OpenAI(api_key=cfg.api_key)
        response = client.chat.completions.create(
            model=cfg.model,
            messages=[
                {"role": "system", "content": sys_prompt},
                {"role": "user", "content": json.dumps(user_prompt)},
            ],
            temperature=float(cfg.temperature),
            max_tokens=int(cfg.max_tokens),
            response_format={"type": "json_object"},
        )
        # Kept inside the try: ``choices`` may be empty.
        content = response.choices[0].message.content
    except Exception as exc:  # best-effort boundary: never break the caller
        log.warning("LLM proposal request failed: %s", exc)
        return []

    # ``content`` can be None (e.g. a refusal or a tool-call-only message).
    if not content:
        return []

    entries = _coerce_arch_entries(_robust_json_load(content.strip()))

    arches: List[Dict] = []
    for entry in entries:
        if len(arches) >= n:
            break
        if not isinstance(entry, dict):
            continue
        try:
            arch = {
                "hidden_dim": int(entry["hidden_dim"]),
                "depth": int(entry["depth"]),
                "num_heads": entry.get("num_heads"),
                "mlp_ratio": entry.get("mlp_ratio"),
            }
            arches.append(clamp_arch_to_allowed(supernet, arch))
        except Exception as exc:
            # Malformed candidates are expected from an LLM; skip them and
            # keep going so one bad entry does not discard the rest.
            log.debug("Skipping invalid LLM candidate %r: %s", entry, exc)
            continue
    return arches
