from __future__ import annotations

from dataclasses import dataclass
from functools import partial
from typing import Any, Dict, Optional, List

import torch
from torch import nn

from constants import MODEL_FAMILY_QWEN
from evaluate.language.get_task_params import get_lm_eval_run_spec
from evaluate.benchmark_constants import main_benchmark


@dataclass(frozen=True, slots=True)
class LMEvalHarnessConfig:
    """Configuration for EleutherAI lm-evaluation-harness integration."""

    tasks: List[str]
    num_fewshot: int = 0
    limit: Optional[int] = None
    batch_size: Optional[int] = None
    max_batch_size: Optional[int] = None
    device: Optional[str] = None  # e.g. "cuda", "cpu", "mps"; if None, inferred from `device_torch`


def extract_accuracy_like_metrics(lm_eval_out: Optional[Dict[str, Any]], task: str) -> Dict[str, float]:
    """Flatten float metrics from an lm-eval ``simple_evaluate`` output.

    Every float metric ``k`` found under ``lm_eval_out["results"][<entry>]`` is
    emitted as ``"<task>_<k with ',' replaced by '_'>"`` (e.g. ``"acc,none"``
    becomes ``"<task>_acc_none"``). Metrics whose key is listed in
    ``main_benchmark[task]`` are additionally emitted as ``"main_result/<task>"``.

    All entries share the same ``<task>_`` prefix, so when the output holds
    several entries (e.g. a group and its subtasks) later entries overwrite
    earlier ones. The entry whose name equals ``task`` is processed last, so
    its values win over any subtask values.

    Args:
        lm_eval_out: Raw output of ``evaluator.simple_evaluate``. ``None``
            or an empty mapping yields an empty result.
        task: Name used as the key prefix and to look up ``main_benchmark``.

    Returns:
        Mapping from flattened metric name to float value.
    """
    if not lm_eval_out:
        return {}
    results = lm_eval_out.get("results") or {}
    if not isinstance(results, dict):
        return {}

    # Tasks without a configured main metric still get their per-metric keys.
    main_metric_keys = main_benchmark.get(task, ())

    out: Dict[str, float] = {}
    # Stable sort on a bool key: other entries (False) keep their order and
    # come first; the entry named exactly `task` (True) is applied last.
    ordered = sorted(results.items(), key=lambda item: item[0] == task)
    for _entry_name, metrics in ordered:
        if not isinstance(metrics, dict):
            continue
        for metric_key, value in metrics.items():
            # Non-float values (aliases, "N/A" stderrs, ...) are skipped.
            if not isinstance(value, float):
                continue
            if metric_key in main_metric_keys:
                out["main_result/" + task] = float(value)
            out[str(task) + "_" + metric_key.replace(",", "_")] = float(value)
    return out


def evaluate_with_lm_eval_harness(
        model: nn.Module,
        model_name: str,
        tokenizer: Any,
        device_torch: torch.device,
        cfg: LMEvalHarnessConfig,
        disable_thinking: bool,
) -> Dict[str, Any]:
    """Run EleutherAI `lm-evaluation-harness` tasks on an in-memory model.

    Uses the official `HFLM` adapter, passing the already-loaded model and
    tokenizer directly to it. Each task in ``cfg.tasks`` is resolved through
    ``get_lm_eval_run_spec`` and gets its own ``HFLM`` instance, because the
    spec may supply task-specific ``HFLM`` kwargs.

    Args:
        model: Loaded model, passed to ``HFLM`` as ``pretrained``.
        model_name: Model identifier. Used to resolve each task's run spec
            and to detect the Qwen family.
        tokenizer: Tokenizer passed to ``HFLM``. When ``disable_thinking`` is
            set for a Qwen model, its ``apply_chat_template`` is wrapped with
            ``enable_thinking=False`` for the duration of this call only.
        device_torch: Device used when ``cfg.device`` is None.
        cfg: Harness configuration (tasks, limit, batch sizing, device).
            NOTE(review): ``cfg.num_fewshot`` is not forwarded here. The
            few-shot count presumably comes from the run spec; confirm.
        disable_thinking: Disable Qwen "thinking" mode in chat templates.

    Returns:
        A flat dict merging ``extract_accuracy_like_metrics`` output for every
        task (``"<task>_<metric>"`` and ``"main_result/<task>"`` keys).

    Raises:
        RuntimeError: If ``lm_eval`` cannot be imported.
    """

    try:
        from lm_eval import evaluator
        from lm_eval.models.huggingface import HFLM
    except ImportError as e:  # pragma: no cover
        raise RuntimeError(
            "`lm-evaluation-harness` (package: lm_eval) is not installed. "
            "Install it, e.g. `pip install lm-eval`."
        ) from e

    # HFLM passes any batch size not starting with "auto" through int(), so
    # None would raise. Fall back to automatic detection instead.
    # (`simple_evaluate(batch_size=...)` is ignored for a pre-built LM, so the
    # batch size has to be set on the adapter.)
    hflm_batch_kwargs: Dict[str, Any] = {
        "batch_size": cfg.batch_size if cfg.batch_size is not None else "auto",
    }
    if cfg.max_batch_size is not None:
        # When unset, keep HFLM's own cap for automatic batch-size detection.
        hflm_batch_kwargs["max_batch_size"] = cfg.max_batch_size
    device = cfg.device if cfg.device is not None else str(device_torch)

    # Patch once (not per task) and restore afterwards, so the caller's
    # tokenizer is not permanently modified.
    patch_template = disable_thinking and MODEL_FAMILY_QWEN in model_name
    if patch_template:
        original_template = tokenizer.apply_chat_template
        tokenizer.apply_chat_template = partial(
            original_template,
            enable_thinking=False,
        )

    results: Dict[str, Any] = {}
    try:
        for task in cfg.tasks:
            spec = get_lm_eval_run_spec(task, model_name)
            lm = HFLM(
                pretrained=model,  # type: ignore[arg-type]
                tokenizer=tokenizer,
                device=device,
                **hflm_batch_kwargs,
                **spec.hflm_kwargs,
            )
            out = evaluator.simple_evaluate(
                model=lm,
                limit=cfg.limit,
                **spec.simple_evaluate_kwargs,
            )
            results.update(extract_accuracy_like_metrics(out, task))
    finally:
        if patch_template:
            tokenizer.apply_chat_template = original_template
    return results
