from typing import Optional, Any, Dict, Union, List

import torch
from torch import nn

from evaluate.benchmark_constants import WIKITEXT2, ALL, ALL_LM_EVAL_TASKS
from evaluate.language.wikitext2_eval import evaluate_language_model_wikitext2
from evaluate.language.lm_eval_harness import (
    evaluate_with_lm_eval_harness,
    LMEvalHarnessConfig,
)
from project_utils.timers import SegmentTimer


def _evaluate_model(
    model: nn.Module,
    tokenizer: Optional[Any],
    device: torch.device,
    batch_size: int = 200,
    max_samples: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Run the single built-in evaluation and record how long it took.

    This helper prints a banner, delegates to
    `evaluate_language_model_wikitext2`, and then adds the elapsed time to
    the dictionary returned by that call.

    Args:
        model: Model to evaluate.
        tokenizer: Tokenizer forwarded unchanged to the evaluator (may be None).
        device: Device forwarded unchanged to the evaluator.
        batch_size: Batch size forwarded to the evaluator.
        max_samples: Sample cap forwarded to the evaluator (None is passed
            through as-is).

    Returns:
        The evaluator's result dictionary with an extra
        `evaluation_time_seconds` entry taken from the timer segment.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("Running Evaluation")
    print(rule)

    # Start timing only after the banner so it covers just the evaluation.
    stopwatch = SegmentTimer()

    eval_kwargs = dict(
        model=model,
        tokenizer=tokenizer,
        batch_size=batch_size,
        max_samples=max_samples,
        device=device,
    )
    report = evaluate_language_model_wikitext2(**eval_kwargs)

    elapsed = stopwatch.segment("evaluation", print_time=False)
    report['evaluation_time_seconds'] = elapsed
    return report



def _normalize_tasks(tasks: Optional[Union[str, List[str]]]) -> List[str]:
    """
    Turn a user-supplied task selection into a clean list of task names.

    Rules, applied identically to a bare string and to a list of names:
      - `None` selects the default, `[WIKITEXT2]`.
      - `None` entries inside a list are dropped; every other entry is
        converted with `str()` and stripped of surrounding whitespace.
      - If any name's lower-cased form equals `ALL`, the full
        `ALL_LM_EVAL_TASKS` selection is returned.
      - Empty names are dropped and duplicates are removed
        case-insensitively, keeping the first spelling and the original order.
      - If nothing is left, the default `[WIKITEXT2]` is returned.

    Args:
        tasks: A single task name, an iterable of task names, or None.

    Returns:
        A new list of task names. The list is never the module-level
        `ALL_LM_EVAL_TASKS` object itself, so callers may mutate it freely.
    """
    if tasks is None:
        return [WIKITEXT2]

    # Treat a bare string as a one-element selection so that both input
    # shapes share the same cleaning rules (e.g. "" and [""] both fall back
    # to the default instead of only the list form doing so).
    candidates = [tasks] if isinstance(tasks, str) else tasks
    cleaned: List[str] = [str(t).strip() for t in candidates if t is not None]

    if any(name.lower() == ALL for name in cleaned):
        # Return a copy: handing out the shared constant would let a caller's
        # in-place edit silently change the task set for everyone else.
        return list(ALL_LM_EVAL_TASKS)

    # De-duplicate case-insensitively while preserving first-seen order.
    seen = set()
    deduped: List[str] = []
    for name in cleaned:
        key = name.lower()
        if key and key not in seen:
            seen.add(key)
            deduped.append(name)

    return deduped or [WIKITEXT2]


def evaluate_model(
        model: nn.Module,
        model_name: str,
        tokenizer: Optional[Any],
        device: torch.device,
        disable_thinking: bool,
        batch_size: int = 200,
        num_fewshot: int = 0,
        max_samples: Optional[int] = None,
        tasks: Optional[Union[str, List[str]]] = None,
) -> Dict[str, Any]:
    """
    Evaluate a model on the selected tasks.

    Task selection (after `_normalize_tasks`):
      - `WIKITEXT2` is evaluated with `evaluate_language_model_wikitext2`.
      - Any name listed in `ALL_LM_EVAL_TASKS` (matched case-insensitively)
        is collected and evaluated in a single call to
        `evaluate_with_lm_eval_harness`.
      - Anything else is rejected.

    The whole selection is validated before any evaluation starts, so an
    unknown task name or a missing tokenizer fails immediately rather than
    after a completed (and potentially long) evaluation run.

    Args:
        model: Model to evaluate; it is switched to eval mode via `model.eval()`.
        model_name: Model identifier forwarded to the harness wrapper.
        tokenizer: Tokenizer forwarded to both evaluators. Required (not None)
            when at least one harness task is selected.
        device: Device forwarded to both evaluators.
        disable_thinking: Flag forwarded unchanged to the harness wrapper.
        batch_size: Batch size used for both evaluators.
        num_fewshot: Few-shot count placed in the harness config.
        max_samples: Sample cap; passed as `max_samples` to the WikiText-2
            evaluator and as `limit` in the harness config.
        tasks: Task name, list of task names, or None (see `_normalize_tasks`).

    Returns:
        Results dictionary containing:
          - `WIKITEXT2` -> the `perplexity` value from the WikiText-2
            evaluator, if that task was selected;
          - every key/value returned by `evaluate_with_lm_eval_harness`
            (merged in with `dict.update`), if harness tasks were selected;
          - `evaluation_time_seconds` -> elapsed time from the timer segment.

    Raises:
        ValueError: If a task name is not recognised, or if harness tasks
            were selected while `tokenizer` is None.
    """
    print(f"\n{'='*80}")
    print("Running Evaluation")
    print(f"{'='*80}")

    model.eval()
    timer = SegmentTimer()
    task_list = _normalize_tasks(tasks)

    # Lower-cased name -> spelling used in ALL_LM_EVAL_TASKS. Matching is
    # case-insensitive, but the harness must receive the canonical spelling,
    # not whatever casing the caller happened to type.
    canonical_harness = {str(name).lower(): name for name in ALL_LM_EVAL_TASKS}

    # First pass: classify and validate every task without running anything.
    run_wikitext2 = False
    # Harness tasks are collected so the harness is only initialized once.
    harness_tasks: List[str] = []
    for task_raw in task_list:
        task = task_raw.strip()
        task_l = task.lower()

        if task_l == WIKITEXT2:
            run_wikitext2 = True
        elif task_l in canonical_harness:
            harness_tasks.append(canonical_harness[task_l])
        else:
            # WIKITEXT2 is supported too, so list it alongside the harness
            # tasks (dict.fromkeys drops a duplicate if it is in both).
            supported = list(dict.fromkeys([WIKITEXT2, *ALL_LM_EVAL_TASKS]))
            raise ValueError(
                f"Unknown task: {task}. Supported: {supported}, all"
            )

    if harness_tasks and tokenizer is None:
        raise ValueError("`tokenizer` is required for lm_eval_harness evaluation")

    results: Dict[str, Any] = {}

    if run_wikitext2:
        out = evaluate_language_model_wikitext2(
            model=model,
            tokenizer=tokenizer,
            batch_size=batch_size,
            max_samples=max_samples,
            device=device,
        )
        # Only the perplexity is surfaced for this task.
        results[WIKITEXT2] = out['perplexity']

    if harness_tasks:
        harness_out = evaluate_with_lm_eval_harness(
            model=model,
            model_name=model_name,
            tokenizer=tokenizer,
            device_torch=device,
            disable_thinking=disable_thinking,
            cfg=LMEvalHarnessConfig(
                tasks=harness_tasks,
                num_fewshot=num_fewshot,
                limit=max_samples,
                batch_size=batch_size,
            ),
        )
        results.update(harness_out)

    results["evaluation_time_seconds"] = timer.segment("evaluation", print_time=False)
    return results