from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

from constants import (
    MODEL_FAMILY_LLAMA, MODEL_FAMILY_QWEN,
    LMEVAL_ADD_BOS_TOKEN, LMEVAL_MAX_LENGTH, LMEVAL_TASKS,
    LMEVAL_APPLY_CHAT_TEMPLATE, LMEVAL_NUM_FEWSHOT, LMEVAL_FEWSHOT_AS_MULTITURN,
    LMEVAL_GEN_KWARGS, LMEVAL_MAX_GEN_TOKS
)
from evaluate.benchmark_constants import MMLU_COT, ARC_CHALLENGE, GSM8K, HELLASWAG, WINOGRANDE, TRUTHFULQA_MC2, \
    OPENBOOKQA, ARC_EASY, PIQA, GPQA, LOGIQA, SCIQ, BOOLQ, COMMONSENSE_QA, COPA, SOCIAL_IQA, LAMBADA


@dataclass(frozen=True)
class LMEvalRunSpec:
    """Pair of keyword-argument dicts describing a single lm-eval run.

    Attributes:
        hflm_kwargs: keyword arguments you pass to ``HFLM(...)``.
        simple_evaluate_kwargs: keyword arguments you pass to
            ``evaluator.simple_evaluate(...)``.

    ``frozen=True`` only prevents rebinding the two attributes; the dicts
    they point at remain ordinary mutable dicts.
    """

    hflm_kwargs: Dict[str, Any]
    simple_evaluate_kwargs: Dict[str, Any]


# --- Model profile inference and helpers ---


@dataclass(frozen=True)
class ModelProfile:
    """Minimal, immutable description of a model used to pick eval settings.

    Attributes:
        family: model family label, e.g. "llama", "qwen", "other".
        is_instruct: whether the model is treated as an instruct/chat model.
    """

    family: str
    is_instruct: bool


def _infer_model_profile(model_name: str) -> ModelProfile:
    """Infer a minimal profile from the HF repo id / local name.

    Args:
        model_name: HF repo id or local model name. ``None`` or an empty
            string is treated as an unknown model.

    Returns:
        A ``ModelProfile`` whose ``family`` is ``MODEL_FAMILY_LLAMA``,
        ``MODEL_FAMILY_QWEN`` or ``"other"``, and whose ``is_instruct`` is a
        best-effort guess based on the name only.
    """
    # Function-scope import keeps this block self-contained.
    import re

    name = (model_name or "").lower()

    # Family detection. The comparison is against the lower-cased name, so the
    # family constants are presumably lower-case strings -- verify in constants.
    if MODEL_FAMILY_LLAMA in name:
        family = MODEL_FAMILY_LLAMA
    elif MODEL_FAMILY_QWEN in name:
        family = MODEL_FAMILY_QWEN
    else:
        family = "other"

    # Instruct/chat detection (best-effort).
    # "instruct" and "chat" are distinctive enough for substring matching, but
    # the short "it" suffix (as in "gemma-2b-it") must match a whole
    # delimiter-separated token; a plain substring test would also fire on
    # names such as "granite-3.0-8b-base" or "...-8bit".
    name_tokens = set(re.split(r"[^a-z0-9]+", name))
    is_instruct = "instruct" in name or "chat" in name or "it" in name_tokens

    return ModelProfile(family=family, is_instruct=is_instruct)


def _apply_chat_flags(spec: Dict[str, Any], profile: ModelProfile) -> Dict[str, Any]:
    """Return a copy of `spec` whose chat flags match the model profile.

    The input dict is not modified. Any existing values for the two chat
    flags are overwritten:

    - instruct models: chat template on; few-shot-as-multiturn on only when
      the spec's few-shot count is greater than zero.
    - all other models: both flags off.
    """
    use_chat_template = bool(profile.is_instruct)

    multiturn = False
    if use_chat_template:
        # A missing or falsy few-shot entry counts as 0 shots.
        shots = int(spec.get(LMEVAL_NUM_FEWSHOT, 0) or 0)
        multiturn = shots > 0

    return {
        **spec,
        LMEVAL_APPLY_CHAT_TEMPLATE: use_chat_template,
        LMEVAL_FEWSHOT_AS_MULTITURN: multiturn,
    }


def _task_ids_for_family(task_key: str, family: str, task_eval_name: list[str] | str) -> list[str]:
    """Return candidate lm-eval task IDs for a given logical task key.

    Args:
        task_key: logical task key (one of the benchmark constants).
        family: model family label; only ``MODEL_FAMILY_LLAMA`` gets its own
            mapping, every other family shares the generic one.
        task_eval_name: fallback task ID(s) used when the family mapping has
            no entry for `task_key`. May be a single ID or a list of IDs.

    Returns:
        A list of candidate task IDs. The fallback is always returned as a
        fresh list, so callers never alias the list they passed in.
    """
    # Llama-3.x instruct tuned variants (preferred for llama family).
    # An "arc_challenge_llama" entry for ARC_CHALLENGE was disabled in the
    # original table, so ARC_CHALLENGE uses the fallback for every family.
    llama_map: Dict[str, list[str]] = {
        MMLU_COT: ["mmlu_cot_llama"],
        GSM8K: ["gsm8k_llama"],
    }

    # Qwen and unknown families: prefer generic task IDs (chat formatting
    # comes from the chat template).
    # NOTE(review): there is no generic GSM8K entry, so non-llama families
    # get whatever fallback the caller supplies for GSM8K -- confirm that the
    # fallback is a real task ID and not a placeholder.
    generic_map: Dict[str, list[str]] = {
        MMLU_COT: ["mmlu_flan_cot_zeroshot"],
    }

    mapping = llama_map if family == MODEL_FAMILY_LLAMA else generic_map
    if task_key in mapping:
        return mapping[task_key]

    # Normalise the fallback: a bare string must become a one-element list,
    # otherwise downstream indexing would pick its first character.
    if isinstance(task_eval_name, str):
        return [task_eval_name]
    return list(task_eval_name)


def _pick_task_id(candidates: list[str], available_task_ids: Optional[set[str]] = None) -> str:
    """Pick the first candidate present in available_task_ids if provided; else pick the first."""
    if not candidates:
        raise ValueError("No task candidates provided")

    if available_task_ids is None:
        return candidates[0]

    for t in candidates:
        if t in available_task_ids:
            return t

    # If nothing matched, return the first candidate (will error later in lm-eval, but preserves intent)
    return candidates[0]


# Baseline per-task specs, copied from the model card “Reproduction” commands.
# Note: max_model_len is a vLLM concern; for HFLM/Transformers it is usually governed by
# tokenizer.model_max_length / model config, so it’s not included here.


def _chat_generation_spec(task_id: str, max_gen_toks: int) -> LMEvalRunSpec:
    """Build a spec with the chat template on and a generation-token cap."""
    return LMEvalRunSpec(
        hflm_kwargs={LMEVAL_ADD_BOS_TOKEN: True},
        simple_evaluate_kwargs={
            LMEVAL_TASKS: [task_id],
            LMEVAL_APPLY_CHAT_TEMPLATE: True,
            LMEVAL_GEN_KWARGS: {LMEVAL_MAX_GEN_TOKS: max_gen_toks},
        },
    )


def _plain_spec(task_id: str, num_fewshot: Optional[int] = None) -> LMEvalRunSpec:
    """Build a spec with the chat template off and empty generation kwargs.

    The few-shot key is only present when `num_fewshot` is given, so an
    omitted count stays distinguishable from an explicit 0.
    """
    evaluate_kwargs: Dict[str, Any] = {LMEVAL_TASKS: [task_id]}
    if num_fewshot is not None:
        evaluate_kwargs[LMEVAL_NUM_FEWSHOT] = num_fewshot
    evaluate_kwargs[LMEVAL_APPLY_CHAT_TEMPLATE] = False
    evaluate_kwargs[LMEVAL_GEN_KWARGS] = {}
    return LMEvalRunSpec(
        hflm_kwargs={LMEVAL_ADD_BOS_TOKEN: True},
        simple_evaluate_kwargs=evaluate_kwargs,
    )


_TASK_SPECS: Dict[str, LMEvalRunSpec] = {
    # "__AUTO__" is a placeholder task ID that is meant to be replaced by a
    # family-specific task ID when the spec is resolved.
    MMLU_COT: _chat_generation_spec("__AUTO__", 1024),
    ARC_CHALLENGE: _chat_generation_spec("arc_challenge", 100),
    GSM8K: _chat_generation_spec("__AUTO__", 1024),
    # Explicit few-shot counts.
    HELLASWAG: _plain_spec("hellaswag", num_fewshot=10),
    WINOGRANDE: _plain_spec("winogrande", num_fewshot=5),
    # No few-shot key at all for these tasks.
    TRUTHFULQA_MC2: _plain_spec("truthfulqa_mc2"),
    PIQA: _plain_spec("piqa"),
    OPENBOOKQA: _plain_spec("openbookqa"),
    ARC_EASY: _plain_spec("arc_easy"),
    # Often reported 0-shot for reading comp
    BOOLQ: _plain_spec("boolq", num_fewshot=0),
    # NOTE(review): configured 0-shot, although the previous comment here said
    # 7-shot is standard in many tech reports -- confirm which is intended.
    COMMONSENSE_QA: _plain_spec("commonsense_qa", num_fewshot=0),
    COPA: _plain_spec("copa", num_fewshot=0),
    SOCIAL_IQA: _plain_spec("social_iqa", num_fewshot=0),
    SCIQ: _plain_spec("sciq", num_fewshot=0),
    # Use the OpenAI version for research alignment
    LAMBADA: _plain_spec("lambada_openai", num_fewshot=0),
    LOGIQA: _plain_spec("logiqa", num_fewshot=0),
    # 'main' is the expert-verified subset
    GPQA: _plain_spec("gpqa_main_zeroshot", num_fewshot=0),
}



def get_lm_eval_run_spec(
    task_name: str,
    model_name: Optional[str] = None,
    *,
    available_task_ids: Optional[set[str]] = None,
) -> LMEvalRunSpec:
    """Return a resolved run spec for a task, optionally conditioned on the model name.

    The returned spec owns its dicts: they are deep copies of the baseline
    entry in `_TASK_SPECS`, so callers may mutate them (including the nested
    generation kwargs) without affecting later calls.

    Args:
        task_name: logical task key; must be a key of `_TASK_SPECS`.
        model_name: HF repo id or any identifier used to infer family/instruct.
            ``None`` is treated as an unknown, non-instruct model.
        available_task_ids: optional set of task IDs present in the installed
            lm-eval; if provided, we pick the first matching candidate for the
            model family.

    Returns:
        A new `LMEvalRunSpec` with exactly one task ID, chat flags set from the
        inferred model profile, and a default max length filled in.

    Raises:
        KeyError: if `task_name` is not a known task.
        ValueError: if no candidate task ID can be determined for the task.
    """
    # Function-scope import keeps this block self-contained.
    import copy

    if task_name not in _TASK_SPECS:
        raise KeyError(f"Unknown task_name={task_name!r}. Known: {sorted(_TASK_SPECS.keys())}")

    base = _TASK_SPECS[task_name]
    task_eval_name = base.simple_evaluate_kwargs[LMEVAL_TASKS]
    profile = _infer_model_profile(model_name or "")

    # Resolve task id(s)
    candidates = _task_ids_for_family(task_name, profile.family, task_eval_name)
    task_id = _pick_task_id(candidates, available_task_ids=available_task_ids)

    # Deep copy: a shallow dict() copy would still share the nested
    # generation-kwargs dict with the module-level table.
    simple_kwargs = copy.deepcopy(base.simple_evaluate_kwargs)
    simple_kwargs[LMEVAL_TASKS] = [task_id]
    # The chat flags stored in the baseline spec are overridden here according
    # to the inferred model profile.
    simple_kwargs = _apply_chat_flags(simple_kwargs, profile)

    # Resolve HFLM kwargs
    hflm_kwargs = copy.deepcopy(base.hflm_kwargs)

    # Llama family commonly needs BOS matching with some harness configs; Qwen generally does not.
    # setdefault only takes effect when the baseline spec leaves the key unset.
    if profile.family == MODEL_FAMILY_LLAMA:
        hflm_kwargs.setdefault(LMEVAL_ADD_BOS_TOKEN, True)

    hflm_kwargs.setdefault(LMEVAL_MAX_LENGTH, 4096)
    return LMEvalRunSpec(hflm_kwargs=hflm_kwargs, simple_evaluate_kwargs=simple_kwargs)
