"""Utilities for building AutoFormer subnets via the Auto-Prox submodule."""

from __future__ import annotations

import os
import sys
from typing import Any, Dict, Optional

# Directory name of the Auto-Prox submodule; it is joined onto the repository
# root by ``ensure_autoprox_on_path``.
AUTO_PROX_DIR = "Auto-Prox-AAAI24"

# Default base YAML config per lower-case dataset key. Paths are relative to
# the Auto-Prox root and are joined onto it by ``resolve_base_config``.
# NOTE(review): "cifar10" points at the same c100 base config as "cifar100";
# ``build_autoformer_model`` overrides NUM_CLASSES and IMG_SIZE after loading,
# so this looks intentional -- confirm.
DEFAULT_CFG_BY_DATASET = {
    "cifar10": "configs/auto/autoformer/autoformer-ti-subnet_c100_base.yaml",
    "cifar100": "configs/auto/autoformer/autoformer-ti-subnet_c100_base.yaml",
}


class AutoProxNotFoundError(FileNotFoundError):
    """Raised when the Auto-Prox submodule directory is missing.

    Subclasses ``FileNotFoundError``, so handlers written for missing files
    catch it as well.
    """


def ensure_autoprox_on_path(repo_root: str) -> str:
    """Inject Auto-Prox into ``sys.path`` and return its absolute path.

    Parameters
    ----------
    repo_root: str
        Repository root that contains the ``AUTO_PROX_DIR`` directory. May be
        relative; it is resolved against the current working directory.

    Returns
    -------
    str
        Absolute, normalised path of the Auto-Prox directory.

    Raises
    ------
    AutoProxNotFoundError
        If the Auto-Prox directory does not exist under ``repo_root``.
    """
    # Resolve to an absolute path up front: a relative ``sys.path`` entry would
    # silently stop working as soon as the process changes its working
    # directory, and the documented contract is to return an absolute path.
    ap_root = os.path.abspath(os.path.join(repo_root, AUTO_PROX_DIR))
    if not os.path.isdir(ap_root):
        raise AutoProxNotFoundError(
            f"{AUTO_PROX_DIR} not found at {ap_root}. Run `git submodule update --init --recursive`."
        )
    # Prepend (only once) so Auto-Prox modules are found before other entries.
    if ap_root not in sys.path:
        sys.path.insert(0, ap_root)
    return ap_root


def resolve_base_config(ap_root: str, dataset: str, override_cfg: Optional[str] = None) -> str:
    """Return the path of the AutoFormer YAML config that should be loaded.

    Parameters
    ----------
    ap_root: str
        Path to the Auto-Prox submodule; relative config paths are joined
        onto it.
    dataset: str
        Dataset key; it is lower-cased before the default lookup and is only
        consulted when ``override_cfg`` is empty.
    override_cfg: Optional[str]
        Config path to use instead of the per-dataset default. An absolute
        path is used as-is; a relative one is taken relative to ``ap_root``.

    Returns
    -------
    str
        Path to an existing config file.

    Raises
    ------
    ValueError
        If no override is given and the dataset has no registered default.
    FileNotFoundError
        If the resolved path is not an existing file.
    """
    # ``or`` short-circuits, so the default table is consulted only when no
    # (non-empty) override was supplied.
    chosen = override_cfg or DEFAULT_CFG_BY_DATASET.get(dataset.lower())
    if not chosen:
        raise ValueError(f"No default AutoFormer config registered for dataset '{dataset}'.")

    if os.path.isabs(chosen):
        resolved = chosen
    else:
        resolved = os.path.join(ap_root, chosen)

    if os.path.isfile(resolved):
        return resolved
    raise FileNotFoundError(f"AutoFormer base config not found: {resolved}")


def make_arch_config(
    hidden_dim: int,
    depth: int,
    num_heads: int,
    mlp_ratio: float,
    qkv_dim: Optional[int] = None,
) -> Dict:
    """Build the ``arch_config`` mapping passed to ``AutoFormerSub``.

    Parameters
    ----------
    hidden_dim: int
        Stored under ``"hidden_dim"`` after conversion to ``int``.
    depth: int
        Stored under ``"depth"``; also the length of the per-layer lists.
    num_heads: int
        Repeated ``depth`` times under ``"num_heads"``.
    mlp_ratio: float
        Repeated ``depth`` times under ``"mlp_ratio"``.
    qkv_dim: Optional[int]
        When not ``None``, stored under ``"qkv_dim"`` after conversion to
        ``int``; otherwise the key is omitted.

    Returns
    -------
    Dict
        The architecture description with uniform per-layer settings.
    """
    # Coerce each value once, in the order the keys are emitted.
    width = int(hidden_dim)
    n_layers = int(depth)
    heads = int(num_heads)
    ratio = float(mlp_ratio)

    arch = {
        "hidden_dim": width,
        "depth": n_layers,
        "num_heads": [heads for _ in range(n_layers)],
        "mlp_ratio": [ratio for _ in range(n_layers)],
    }
    if qkv_dim is None:
        return arch
    arch["qkv_dim"] = int(qkv_dim)
    return arch


def build_autoformer_model(
    ap_root: str,
    arch: Dict,
    dataset: str = "cifar100",
    cfg_override: Optional[str] = None,
) -> Any:
    """Create an ``AutoFormerSub`` model for ``arch`` on the given dataset.

    Parameters
    ----------
    ap_root: str
        Path to the Auto-Prox submodule, used to resolve the base config.
    arch: Dict
        Architecture description passed to the model as ``arch_config``.
    dataset: str
        Dataset key. ``"cifar100"`` and ``"cifar10"`` (case-insensitive)
        override the class count and image size on the loaded config; any
        other value keeps what the YAML provides.
    cfg_override: Optional[str]
        Config path forwarded to ``resolve_base_config``.

    Returns
    -------
    Any
        The model, moved to ``"cuda"`` when available and ``"cpu"`` otherwise.
    """
    # These imports are deliberately local: the ``pycls`` package comes from
    # Auto-Prox and is importable only once its path has been injected.
    import torch
    from pycls.core.config import cfg as vit_cfg
    import pycls.core.config as vit_config
    from pycls.models.build import MODEL

    # Load the base YAML first; the dataset-specific overrides come after it.
    vit_config.load_cfg(resolve_base_config(ap_root, dataset, cfg_override))

    cifar_class_counts = {"cifar100": 100, "cifar10": 10}
    n_classes = cifar_class_counts.get(dataset.lower())
    if n_classes is not None:
        vit_cfg.MODEL.NUM_CLASSES = n_classes
        vit_cfg.MODEL.IMG_SIZE = 32
    # Datasets outside the table keep the YAML defaults.

    build_fn = MODEL.get("AutoFormerSub")
    net = build_fn(arch_config=arch, num_classes=vit_cfg.MODEL.NUM_CLASSES)

    if torch.cuda.is_available():
        target = "cuda"
    else:
        target = "cpu"
    return net.to(target)
