"""Command-line interface for ViT AutoFormer NAS experiments."""

from __future__ import annotations

import argparse
import os
import time
from typing import Dict, List, Optional, Tuple

from .llm import LLMConfig
from .plot_supernet_comparison import generate_supernet_comparison
from .runner import SearchOutputs, SearchSettings, run_search


def _parse_measures(measures: str, fallback: str) -> Tuple[str, ...]:
    if measures:
        cleaned = [m.strip() for m in measures.split(",") if m.strip()]
        if cleaned:
            return tuple(cleaned)
    return (fallback,)


def _parse_supernets(arg: str) -> List[str]:
    if not arg:
        return ["tiny"]
    items = [item.strip().lower() for item in arg.split(",") if item.strip()]
    if not items:
        return ["tiny"]
    if len(items) == 1 and items[0] == "all":
        return ["tiny", "small", "base"]
    valid = {"tiny", "small", "base"}
    ordered: List[str] = []
    for item in items:
        if item not in valid:
            raise ValueError(f"Unsupported supernet '{item}'. Choose tiny, small, base, or 'all'.")
        if item not in ordered:
            ordered.append(item)
    return ordered


def build_arg_parser() -> argparse.ArgumentParser:
    """Return the argument parser for the ViT AutoFormer NAS command line.

    Each option is declared below as a ``(flag, add_argument keyword arguments)``
    pair. Options are registered in declaration order, which is also the order
    in which they appear in ``--help``.
    """
    option_specs = (
        ("--dataset", dict(default="cifar100", choices=["cifar10", "cifar100"], help="target dataset")),
        (
            "--supernet",
            dict(
                default="base",
                help="AutoFormer supernet to search (tiny/small/base, comma-separated, or 'all' for all three)",
            ),
        ),
        ("--per-cat", dict(type=int, default=12, help="architectures sampled per category per iteration")),
        ("--iterations", dict(type=int, default=5, help="number of search iterations")),
        ("--measure", dict(default="jacov", help="primary zero-cost measure name")),
        ("--measures", dict(default="jacov,jacobian_trace", help="comma-separated list of zero-cost measures")),
        ("--free-y", dict(action="store_true", help="do not clamp estimated accuracy to the provided range")),
        ("--output", dict(default="outputs/vit_autoformer_nas", help="output directory root")),
        (
            "--comparison-output",
            dict(
                default=None,
                help="optional path for combined comparison plot when multiple supernets are searched",
            ),
        ),
        ("--seed", dict(type=int, default=42, help="base random seed")),
        ("--no-cuda", dict(action="store_true", help="force CPU execution")),
        ("--acc-min", dict(type=float, default=None, help="lower bound for mapping proxy to accuracy")),
        ("--acc-max", dict(type=float, default=None, help="upper bound for mapping proxy to accuracy")),
        ("--acc-y-margin", dict(type=float, default=0.8, help="margin when clamping estimated accuracy")),
        ("--lat-burn", dict(type=int, default=8, help="measured latency iterations to drop after warmup")),
        ("--lat-warmup", dict(type=int, default=8, help="number of warmup forwards")),
        ("--lat-iters", dict(type=int, default=40, help="number of measured forwards for latency")),
        ("--zc-samples", dict(type=int, default=64, help="random samples for zero-cost dataloader")),
        ("--zc-batch", dict(type=int, default=32, help="batch size for zero-cost dataloader")),
        ("--cpu-threads", dict(type=int, default=None, help="limit CPU threads used by PyTorch/BLAS")),
        ("--llm-per-cat", dict(type=int, default=3, help="LLM-generated candidates per category per iteration")),
        ("--llm-model", dict(type=str, default="gpt-4.1", help="override LLM model identifier")),
        ("--llm-api-key", dict(type=str, default=None, help="explicit OpenAI-compatible API key")),
        ("--llm-temperature", dict(type=float, default=0.2, help="LLM sampling temperature")),
        ("--llm-max-tokens", dict(type=int, default=800, help="LLM response token limit")),
        (
            "--supernet-config",
            dict(
                type=str,
                default=None,
                help="override AutoFormer YAML config (relative to Auto-Prox root or absolute path)",
            ),
        ),
        (
            "--metrics-path",
            dict(
                type=str,
                default="vit_autoformer_nas/data/autoformer_metrics.json",
                help="precomputed metrics JSON (empty string to disable)",
            ),
        ),
    )

    parser = argparse.ArgumentParser(description="ViT NAS experiment using Auto-Prox zero-cost proxies.")
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser


def main(argv=None) -> Optional[SearchOutputs]:
    """Run the NAS search for every requested supernet.

    Parses ``argv`` (``None`` lets argparse read the process arguments), then
    calls ``run_search`` once per requested supernet. Each run's seed is
    ``--seed`` plus the run's index. When more than one supernet is searched,
    a combined comparison plot is also written.

    Returns:
        The ``SearchOutputs`` of the last supernet searched, or ``None`` when
        no search was run.

    Raises:
        ValueError: propagated from ``_parse_supernets`` for an unknown
            ``--supernet`` entry.
    """
    args = build_arg_parser().parse_args(argv)

    measures = _parse_measures(args.measures, args.measure)
    supernets = _parse_supernets(args.supernet)

    # The API key falls back to the OPENAI_API_KEY environment variable.
    llm_cfg = LLMConfig(
        per_category=args.llm_per_cat,
        model=args.llm_model or "gpt-4o-mini",
        api_key=args.llm_api_key or os.getenv("OPENAI_API_KEY", ""),
        temperature=args.llm_temperature,
        max_tokens=args.llm_max_tokens,
    )

    # A blank (or whitespace-only) --metrics-path disables precomputed metrics.
    metrics_path = None
    if args.metrics_path is not None:
        metrics_path = args.metrics_path.strip() or None

    # Settings shared by every supernet run; supernet and seed vary per run.
    shared_settings = {
        "repo_root": os.getcwd(),
        "dataset": args.dataset,
        "per_category": args.per_cat,
        "iterations": args.iterations,
        "measures": measures,
        "primary_measure": args.measure,
        "use_cuda": not args.no_cuda,
        "cpu_threads": args.cpu_threads,
        "lat_warmup": args.lat_warmup,
        "lat_iters": args.lat_iters,
        "lat_burn": args.lat_burn,
        "zc_samples": args.zc_samples,
        "zc_batch": args.zc_batch,
        "output_root": args.output,
        "acc_min": args.acc_min,
        "acc_max": args.acc_max,
        "acc_margin": args.acc_y_margin,
        "free_y": args.free_y,
        "cfg_override": args.supernet_config,
        "metrics_path": metrics_path,
    }

    results_paths: Dict[str, str] = {}
    last_output: Optional[SearchOutputs] = None

    for offset, supernet in enumerate(supernets):
        settings = SearchSettings(**shared_settings, supernet=supernet, seed=args.seed + offset)
        last_output = run_search(settings, llm_cfg)
        output_dir = last_output.output_dir
        results_paths[supernet] = os.path.join(output_dir, "results.json")
        print(f"[{supernet}] y-axis: {last_output.y_key}")
        print(f"[{supernet}] results saved to {output_dir}")

    if len(supernets) > 1 and results_paths:
        comparison_output = args.comparison_output
        if not comparison_output:
            stamp = time.strftime("%Y%m%d-%H%M%S")
            comparison_output = os.path.join(args.output, f"supernet_comparison_{stamp}.png")
        # NOTE(review): the comparison plot falls back to the default metrics JSON
        # even when --metrics-path "" disabled metrics for the searches. Confirm
        # this is intended.
        generate_supernet_comparison(
            metrics_json=metrics_path or "vit_autoformer_nas/data/autoformer_metrics.json",
            supernet_results=results_paths,
            output_path=comparison_output,
        )
        print(f"Combined comparison plot saved to {comparison_output}")

    return last_output


if __name__ == "__main__":
    # Script entry point: argv=None makes argparse read the process arguments.
    main(argv=None)

