"""Core search loop for ViT NAS experiments."""

from __future__ import annotations

import csv
import json
import math
import os
import random
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

from .autoformer import build_autoformer_model, ensure_autoprox_on_path
from .categories import CategoryBounds, default_categories, sample_arch_from_category
from .llm import LLMConfig, llm_generate_arches
from .metrics_dataset import DEFAULT_DATASET_PATH, PrecomputedMetricsRepository
from .pareto import pareto_front, save_pareto_plot
from .zero_cost import (
    aggregate_scores,
    build_measure_stats,
    compute_zero_cost_scores,
    orient_measure,
)


@dataclass
class SearchOutputs:
    """Summary of a finished search, as returned by ``run_search``."""

    # Run directory that results.json, results.csv and pareto.png were written to.
    output_dir: str
    # Successfully evaluated candidate records (those that carry a ``latency_ms`` key).
    results: List[Dict]
    # Records selected by ``pareto_front`` over the records with a finite ``y_key`` value.
    pareto: List[Dict]
    # Record key used as the y-axis: ``"est_acc"`` when accuracy bounds are known,
    # otherwise ``"acc_proxy_norm"``.
    y_key: str


@dataclass
class SearchSettings:
    """Configuration consumed by ``run_search``."""

    # Repository root; made absolute and handed to ``ensure_autoprox_on_path``.
    repo_root: str
    # "cifar100" selects 100 classes, any other value 10; also forwarded to model building.
    dataset: str = "cifar100"
    # Supernet name used for categories, random sampling and metrics lookups.
    supernet: str = "tiny"
    # Minimum candidates per category per iteration (random samples top up LLM proposals).
    per_category: int = 6
    # Number of passes over all categories.
    iterations: int = 1
    # Zero-cost measures passed to ``compute_zero_cost_scores``.
    measures: Tuple[str, ...] = ("jacov", "jacobian_trace")
    # Stored verbatim in each record's "measure" field.
    primary_measure: str = "jacov"
    # Seed for ``random`` and, when installed, ``torch``.
    seed: int = 42
    # When False, CUDA_VISIBLE_DEVICES is set to "" for the process.
    use_cuda: bool = True
    # CPU thread count; None or <= 0 leaves threading untouched.
    cpu_threads: Optional[int] = None
    # Warm-up, timed and discarded iterations for ``measure_latency_ms``.
    lat_warmup: int = 8
    lat_iters: int = 40
    lat_burn: int = 8
    # Number of random images and batch size used for zero-cost scoring.
    zc_samples: int = 64
    zc_batch: int = 32
    # Parent directory under which each run's directory is created.
    output_root: str = "outputs/vit_autoformer_nas"
    # Accuracy range used to map the normalized proxy to an estimated accuracy;
    # a None end is filled from the metrics repository's bounds when available.
    acc_min: Optional[float] = None
    acc_max: Optional[float] = None
    # Margin (accuracy units) trimmed from both ends of the range for clamping/plot limits.
    acc_margin: float = 0.8
    # When True, estimated accuracy is not clamped and plot y-limits are left automatic.
    free_y: bool = False
    # Optional config override forwarded to ``build_autoformer_model``.
    cfg_override: Optional[str] = None
    # Path to the precomputed metrics dataset; a falsy value disables the lookup.
    metrics_path: Optional[str] = DEFAULT_DATASET_PATH


def _configure_threads(num_threads: Optional[int], torch_module=None) -> None:
    if not num_threads or num_threads <= 0:
        return
    val = str(num_threads)
    os.environ["OMP_NUM_THREADS"] = val
    os.environ["MKL_NUM_THREADS"] = val
    os.environ["NUMEXPR_NUM_THREADS"] = val
    if torch_module is None:
        return
    try:
        torch_module.set_num_threads(num_threads)
        torch_module.set_num_interop_threads(max(1, num_threads // 2))
    except Exception:
        pass


def _prepare_llm_config(llm_cfg: LLMConfig) -> LLMConfig:
    """Return a config that is safe to use for LLM architecture proposals.

    A disabled config is returned as-is. An enabled config is also returned as-is
    when the ``openai`` package imports cleanly; otherwise a fresh config with
    ``per_category=0`` is returned so no LLM proposals are requested.
    """
    if llm_cfg.enabled:
        try:
            import openai  # noqa: F401
        except Exception:
            # Client library missing or broken: fall back to zero LLM proposals.
            return LLMConfig(per_category=0)
    return llm_cfg


def measure_latency_ms(
    model,
    img_size: int,
    batch_size: int = 1,
    warmup: int = 10,
    iters: int = 50,
    burn: int = 5,
) -> float:
    """Measure mean forward-pass latency of *model* in milliseconds.

    A random ``(batch_size, 3, img_size, img_size)`` input is created on the device
    of the model's first parameter. ``warmup`` untimed passes run first, then
    ``iters`` timed passes. The first ``burn`` timings are discarded when fewer than
    ``iters`` would remain otherwise; all timings are kept if ``burn >= iters``.

    Args:
        model: torch module with at least one parameter.
        img_size: spatial size of the square dummy input.
        batch_size: batch dimension of the dummy input.
        warmup: untimed passes before measuring.
        iters: timed passes.
        burn: leading timed samples to drop (negative values are treated as 0).

    Returns:
        Mean latency in milliseconds, or 0.0 when ``iters`` is 0.
    """
    import torch

    device = next(model.parameters()).device
    on_cuda = device.type == "cuda"
    x = torch.randn(batch_size, 3, img_size, img_size, device=device)
    model.eval()
    timings: List[float] = []
    with torch.no_grad():
        for _ in range(warmup):
            model(x)
        if on_cuda:
            # CUDA launches are asynchronous: drain the warm-up work so it is not
            # billed to the first timed iteration.
            torch.cuda.synchronize()
        for _ in range(iters):
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start = time.perf_counter()
            model(x)
            if on_cuda:
                torch.cuda.synchronize()
            timings.append((time.perf_counter() - start) * 1000.0)
    burn = max(0, burn)
    kept = timings[burn:] if burn < len(timings) else timings
    return float(sum(kept) / max(1, len(kept)))


def _serialize_arch(arch: Dict) -> Dict:
    return {
        "hidden_dim": int(arch["hidden_dim"]),
        "depth": int(arch["depth"]),
        "num_heads": list(map(int, arch["num_heads"])),
        "mlp_ratio": [float(v) for v in arch["mlp_ratio"]],
    }


def _clean_float(value: Optional[float]) -> Optional[float]:
    if value is None:
        return None
    if isinstance(value, float) and (math.isnan(value) or math.isinf(value)):
        return None
    return float(value)


def _create_run_dir(output_root: str) -> str:
    """Create and return a new ``run_<timestamp>`` directory under *output_root*.

    Two searches started in the same second would otherwise share one directory and
    overwrite each other's artifacts, so a ``_<n>`` suffix is appended until an
    unused name is found.
    """
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    base = os.path.join(output_root, f"run_{timestamp}")
    candidate = base
    suffix = 1
    while True:
        try:
            os.makedirs(candidate)
            return candidate
        except FileExistsError:
            candidate = f"{base}_{suffix}"
            suffix += 1


def _write_results_csv(path: str, records: List[Dict]) -> None:
    """Write one CSV row per evaluated record; non-finite/missing floats become empty cells."""
    fieldnames = [
        "iteration",
        "category",
        "latency_ms",
        "acc_proxy_norm",
        "est_acc",
        "measure",
        "arch",
    ]
    with open(path, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for record in records:
            writer.writerow(
                {
                    "iteration": record.get("iteration"),
                    "category": record.get("category"),
                    "latency_ms": _clean_float(record.get("latency_ms")),
                    "acc_proxy_norm": _clean_float(record.get("acc_proxy_norm")),
                    "est_acc": _clean_float(record.get("est_acc")),
                    "measure": record.get("measure"),
                    "arch": json.dumps(record.get("arch")),
                }
            )


def run_search(settings: SearchSettings, llm_cfg: LLMConfig) -> SearchOutputs:
    """Run the category-driven ViT architecture search and persist its artifacts.

    For each of ``settings.iterations`` passes and each category returned by
    ``default_categories``, a candidate pool is built from optional LLM proposals,
    topped up with random samples to ``settings.per_category`` entries. Candidates
    found in the precomputed metrics repository reuse those metrics. Every other
    candidate is built, has its latency measured and its zero-cost scores computed.
    Per-candidate failures are recorded instead of raised.

    Raw scores are oriented and aggregated into ``acc_proxy_norm``. When accuracy
    bounds are known they are also mapped to ``est_acc``. A Pareto front over the
    chosen y key and latency is computed. ``results.json`` (including failed
    candidates under ``"errors"``), ``results.csv`` and ``pareto.png`` are written
    to a fresh run directory.

    Args:
        settings: search configuration.
        llm_cfg: LLM proposal configuration. No LLM proposals are requested if the
            ``openai`` package cannot be imported.

    Returns:
        SearchOutputs holding the run directory, evaluated records, Pareto front and
        the y key used.

    Side effects:
        Seeds ``random`` and (when installed) ``torch``. May set ``CUDA_VISIBLE_DEVICES``
        and thread-count environment variables for the whole process.
    """
    random.seed(settings.seed)

    # Apply process-level knobs before importing torch: CUDA_VISIBLE_DEVICES and the
    # OpenMP/MKL thread variables are generally only honoured if they are set before
    # the corresponding runtime initializes.
    if not settings.use_cuda:
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    _configure_threads(settings.cpu_threads)

    try:
        import torch as torch_module
    except ModuleNotFoundError:
        torch_module = None

    if torch_module is not None:
        torch_module.manual_seed(settings.seed)
        # torch is loaded now, so its intra-/inter-op thread counts can be applied too.
        _configure_threads(settings.cpu_threads, torch_module=torch_module)

    repo_root = os.path.abspath(settings.repo_root)
    ap_root = ensure_autoprox_on_path(repo_root)

    metrics_repo: Optional[PrecomputedMetricsRepository] = None
    if settings.metrics_path:
        try:
            metrics_repo = PrecomputedMetricsRepository(settings.metrics_path, strict=False)
        except Exception:
            # Precomputed metrics are optional; fall back to live zero-cost evaluation.
            metrics_repo = None

    acc_min = settings.acc_min
    acc_max = settings.acc_max
    if metrics_repo:
        acc_bounds = metrics_repo.get_accuracy_bounds(settings.supernet)
        if acc_bounds:
            # Explicit settings win; repository bounds only fill in missing ends.
            if acc_min is None:
                acc_min = acc_bounds[0]
            if acc_max is None:
                acc_max = acc_bounds[1]

    llm_cfg = _prepare_llm_config(llm_cfg)

    categories = default_categories(settings.supernet)

    os.makedirs(settings.output_root, exist_ok=True)

    num_classes = 100 if settings.dataset.lower() == "cifar100" else 10
    img_size = 32
    # Built lazily so runs served entirely from precomputed metrics never need torch.
    dataloader = None

    results: List[Dict] = []

    for iteration in range(settings.iterations):
        for cat_name, bounds in categories.items():
            pool: List[Tuple[Dict, str]] = []
            if llm_cfg.enabled and llm_cfg.per_category > 0:
                tuple_bounds = (
                    bounds.embed_dim,
                    bounds.depth,
                    bounds.num_heads,
                    bounds.mlp_ratio,
                )
                pool.extend(
                    (arch, "llm")
                    for arch in llm_generate_arches(
                        settings.supernet,
                        cat_name,
                        tuple_bounds,
                        llm_cfg.per_category,
                        llm_cfg,
                    )
                )
            while len(pool) < settings.per_category:
                pool.append((sample_arch_from_category(settings.supernet, bounds), "random"))

            for arch, origin in pool:
                if metrics_repo and metrics_repo.has_metrics(settings.supernet, arch):
                    metrics = metrics_repo.get_metrics(settings.supernet, arch)
                    if metrics is not None:
                        results.append(
                            {
                                "iteration": iteration,
                                "category": cat_name,
                                "arch": _serialize_arch(arch),
                                "latency_ms": metrics.latency_ms,
                                "measures_raw": {},
                                "measure": settings.primary_measure,
                                "acc_proxy_norm": metrics.accuracy_norm,
                                "source": "precomputed",
                                "origin": origin,
                            }
                        )
                        continue
                try:
                    if dataloader is None:
                        if torch_module is None:
                            raise RuntimeError(
                                "PyTorch is required for zero-cost evaluation. Install torch or provide a metrics dataset."
                            )
                        from torch.utils.data import DataLoader

                        from .datasets import RandomImageDataset

                        dataset = RandomImageDataset(settings.zc_samples, img_size, num_classes)
                        dataloader = DataLoader(dataset, batch_size=settings.zc_batch, shuffle=False)
                    model = build_autoformer_model(
                        ap_root,
                        arch,
                        dataset=settings.dataset,
                        cfg_override=settings.cfg_override,
                    )
                    latency_ms = measure_latency_ms(
                        model,
                        img_size=img_size,
                        batch_size=1,
                        warmup=settings.lat_warmup,
                        iters=settings.lat_iters,
                        burn=settings.lat_burn,
                    )
                    scores = compute_zero_cost_scores(
                        model,
                        dataloader,
                        num_classes=num_classes,
                        measures=settings.measures,
                    )
                    results.append(
                        {
                            "iteration": iteration,
                            "category": cat_name,
                            "arch": _serialize_arch(arch),
                            "latency_ms": latency_ms,
                            "measures_raw": scores,
                            "measure": settings.primary_measure,
                            "source": "zero_cost",
                            "origin": origin,
                        }
                    )
                except Exception as exc:
                    # One bad candidate must not abort the whole search; keep a record.
                    results.append(
                        {
                            "iteration": iteration,
                            "category": cat_name,
                            "arch": _serialize_arch(arch),
                            "error": str(exc),
                            "measure": settings.primary_measure,
                            "source": "error",
                            "origin": origin,
                        }
                    )

    valid = [r for r in results if "latency_ms" in r]
    # Candidates without a latency are evaluation failures; they are reported separately.
    errors = [r for r in results if "latency_ms" not in r]

    oriented_records: List[Dict[str, Optional[float]]] = []
    for record in valid:
        oriented = {}
        for name, value in (record.get("measures_raw") or {}).items():
            oriented[name] = orient_measure(name, value)
        record["measures_oriented"] = oriented
        oriented_records.append(oriented)

    stats = build_measure_stats(oriented_records)
    for record in valid:
        agg = aggregate_scores(record["measures_oriented"], stats)
        # Precomputed records have no raw measures; keep their stored accuracy_norm.
        if agg is not None:
            record["acc_proxy_norm"] = agg

    def map_accuracy_norm(norm_val: Optional[float]) -> Optional[float]:
        """Map a normalized proxy onto [acc_min, acc_max], clamped inside the margin."""
        if norm_val is None:
            return None
        if acc_min is None or acc_max is None or acc_max <= acc_min:
            return None
        span = acc_max - acc_min
        est = acc_min + norm_val * span
        if settings.free_y:
            return est
        margin = max(0.0, settings.acc_margin)
        lo_bound = acc_min + margin
        hi_bound = acc_max - margin
        if hi_bound <= lo_bound:
            # Margin wider than the range: clamp to the raw bounds instead.
            hi_bound = acc_max
            lo_bound = acc_min
        return max(min(est, hi_bound), lo_bound)

    use_accuracy = acc_min is not None and acc_max is not None and acc_max > acc_min
    if use_accuracy:
        for record in valid:
            record["est_acc"] = map_accuracy_norm(record.get("acc_proxy_norm"))

    y_key = "est_acc" if use_accuracy else "acc_proxy_norm"

    filtered = [r for r in valid if _clean_float(r.get(y_key)) is not None]
    front = pareto_front(filtered, acc_key=y_key, lat_key="latency_ms")

    output_dir = _create_run_dir(settings.output_root)

    json_ready = []
    for record in valid:
        entry = dict(record)
        entry[y_key] = _clean_float(entry.get(y_key))
        entry["acc_proxy_norm"] = _clean_float(entry.get("acc_proxy_norm"))
        entry["latency_ms"] = _clean_float(entry.get("latency_ms"))
        json_ready.append(entry)

    summary_payload = {"all": json_ready, "pareto": front, "y_key": y_key}
    if use_accuracy:
        summary_payload["accuracy_bounds"] = [acc_min, acc_max]
    # Persist failed candidates so build/scoring failures remain diagnosable after the run.
    summary_payload["errors"] = errors
    with open(os.path.join(output_dir, "results.json"), "w", encoding="utf-8") as handle:
        json.dump(summary_payload, handle, indent=2)

    _write_results_csv(os.path.join(output_dir, "results.csv"), valid)

    y_label = "Accuracy (%)" if use_accuracy else "Auto-Prox Proxy (normalized)"
    y_limits = None
    if use_accuracy and not settings.free_y:
        # Same non-negative margin clamp as map_accuracy_norm, so the plot window and
        # the clamped estimates agree.
        margin = max(0.0, settings.acc_margin)
        lo = acc_min + margin
        hi = acc_max - margin
        if hi > lo:
            y_limits = (lo, hi)

    background_points = None
    if metrics_repo:
        entries = metrics_repo.iter_metrics(settings.supernet)
        if entries:
            background_points = [
                {
                    "latency_ms": metric.latency_ms,
                    "y": metric.accuracy if use_accuracy else metric.accuracy_norm,
                }
                for metric in entries
            ]

    save_pareto_plot(
        filtered,
        front,
        os.path.join(output_dir, "pareto.png"),
        acc_key=y_key,
        lat_key="latency_ms",
        title=f"ViT NAS Pareto ({', '.join(settings.measures)} vs latency)",
        y_limits=y_limits,
        y_label=y_label,
        background_points=background_points,
    )

    return SearchOutputs(output_dir=output_dir, results=valid, pareto=front, y_key=y_key)
