"""Evaluate LongBench-style results produced by this project's benchmark runner.

Usage examples (--path must point to a .jsonl file; directories are rejected):
  python3 evaluate.py \
    --path results/llama-3.1/trec/my_exp/baseline.jsonl \
    --dataset trec

  python3 evaluate.py \
    --path results/llama-3.1/narrativeqa/my_exp/baseline.jsonl \
    --dataset narrativeqa

Input format (each line in the .jsonl file):
  {
    "pred": str,                 # model's generated answer
    "answers": [str, ...] | null # ground-truth answers (list or single string)
    ...                          # other fields ignored
  }

This script mirrors the logic of the root-level evaluate_longbench.py but adapts to the
field names used by long_context_eval's BenchmarkEvaluator outputs (pred/answers).

Metric functions are imported from src.evaluation.metrics and selected per dataset through
configs/dataset2metric.json. Datasets without a usable mapping are scored with qa_f1_score.

Time complexity: O(N * A * L) where N is number of examples, A is number of ground
truth answers per example, and L is the number of tokens in the strings.
Space complexity: O(1) auxiliary besides input lines processed incrementally.
"""

from __future__ import annotations

import argparse
import json
import os
import string
from typing import Callable, Dict, List, Optional


from src.evaluation.metrics import (
    qa_f1_score,
    rouge_zh_score,
    qa_f1_zh_score,
    rouge_score,
    classification_score,
    retrieval_score,
    retrieval_zh_score,
    count_score,
    code_sim_score,
)


def _normalize(text: str) -> List[str]:
    """Lowercase, strip punctuation, and split on whitespace into tokens."""
    text = text.lower()
    text = text.translate(str.maketrans("", "", string.punctuation))
    tokens = text.split()
    return tokens


def _qa_f1_fallback(prediction: str, ground_truth: str) -> float:
    """Return the token-level F1 between a prediction and one ground-truth string.

    Both strings are tokenized with ``_normalize``. Two empty token lists score
    1.0; exactly one empty list, or no shared tokens, scores 0.0. Otherwise the
    result is the harmonic mean of precision and recall over the multiset
    overlap of the tokens.
    """
    from collections import Counter

    hyp_tokens = _normalize(prediction)
    ref_tokens = _normalize(ground_truth)

    if not hyp_tokens or not ref_tokens:
        # Only "both empty" counts as a perfect match.
        return 1.0 if not hyp_tokens and not ref_tokens else 0.0

    # Multiset intersection: a repeated token only counts as often as it
    # appears on both sides.
    overlap = sum((Counter(hyp_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0

    # Both lists are non-empty here, so the divisions are safe.
    precision = overlap / len(hyp_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


def _get_dataset2metric(cfg_dir: str) -> Dict[str, Callable[[str, str], float]]:
    """Build the dataset -> metric-function table from ``dataset2metric.json``.

    The JSON file inside *cfg_dir* maps dataset names to metric names. Each
    metric name is resolved to one of the imported metric callables; datasets
    whose metric name is not recognised are left out of the returned dict.
    """
    with open(os.path.join(cfg_dir, "dataset2metric.json"), "r", encoding="utf-8") as handle:
        metric_names: Dict[str, str] = json.load(handle)

    # Metric names as they appear in the JSON config, keyed to their callables.
    known_metrics: Dict[str, Callable[[str, str], float]] = {
        "qa_f1_score": qa_f1_score,
        "rouge_zh_score": rouge_zh_score,
        "qa_f1_zh_score": qa_f1_zh_score,
        "rouge_score": rouge_score,
        "classification_score": classification_score,
        "retrieval_score": retrieval_score,
        "retrieval_zh_score": retrieval_zh_score,
        "count_score": count_score,
        "code_sim_score": code_sim_score,
    }

    resolved = (
        (dataset, known_metrics.get(metric_name))
        for dataset, metric_name in metric_names.items()
    )
    # Unknown metric names resolve to None and are skipped.
    return {dataset: fn for dataset, fn in resolved if fn is not None}


def _coerce_answers(raw: Optional[object]) -> Optional[List[str]]:
    """Coerce the stored answers field to a list of strings or None."""
    if raw is None:
        return None
    if isinstance(raw, str):
        return [raw]
    if isinstance(raw, list):
        # Keep only string-like answers
        return [str(x) for x in raw]
    # Unknown type; ignore
    return None


def evaluate_file(file_path: str, dataset: str) -> float:
    """Score one results ``.jsonl`` file and return the mean score scaled to [0, 100].

    The metric for *dataset* is looked up in ``configs/dataset2metric.json``
    next to this file; datasets without an entry are scored with
    ``qa_f1_score``. Each non-blank line is parsed as a JSON record. Its
    ``pred`` is scored against every entry of ``answers`` and the best score is
    kept (never below 0.0). Records with missing or empty ``answers`` add
    nothing to the total but still count in the denominator.

    Returns 0.0 when the file has no non-blank lines; otherwise the mean
    per-record score times 100, rounded to two decimals.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    metric_table = _get_dataset2metric(os.path.join(here, "configs"))
    metric_fn: Callable[[str, str], float] = metric_table.get(dataset, qa_f1_score)

    # For these datasets only the first line of the prediction is scored.
    keep_first_line = dataset in {"trec", "triviaqa", "samsum", "lsht"}

    record_count = 0
    score_sum = 0.0

    with open(file_path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            if raw_line.strip() == "":
                continue
            record_count += 1
            record = json.loads(raw_line)

            pred: str = str(record.get("pred", ""))
            if keep_first_line:
                pred = pred.lstrip("\n").split("\n")[0]

            answers = _coerce_answers(record.get("answers"))
            if not answers:
                # Counted in record_count but contributes no score.
                continue

            # all_classes is forwarded to the metric only when the record
            # carries a non-null value; it is the same for every answer.
            all_classes = record.get("all_classes")
            extra = {} if all_classes is None else {"all_classes": all_classes}

            # Best score over all ground-truth variants, floored at 0.0.
            score_sum += max(
                [0.0, *(float(metric_fn(pred, gold, **extra)) for gold in answers)]
            )

    if record_count == 0:
        return 0.0
    return round(100.0 * score_sum / record_count, 2)


def parse_args() -> argparse.Namespace:
    """Parse the command line and return the resulting namespace.

    Two string options are defined, both required: ``--path`` and ``--dataset``.
    """
    # (flag, help text) for each required string option.
    required_options = (
        ("--path", "Path to a results .jsonl file (file only, no directories)"),
        ("--dataset", "LongBench subset name (e.g., narrativeqa, qmsum, trec)"),
    )

    cli = argparse.ArgumentParser(description="Evaluate results produced by long_context_eval runner.")
    for flag, help_text in required_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    return cli.parse_args()


def main() -> None:
    """CLI entry point: score the file given by ``--path`` and print a short report.

    When ``os.path.isfile`` is false for ``--path``, a message is printed and
    nothing is evaluated.
    """
    args = parse_args()
    target = args.path

    if os.path.isfile(target):
        score = evaluate_file(target, args.dataset)
        report = (
            f"Evaluated: {target}",
            f"Dataset: {args.dataset}",
            f"Final Score: {score}",
        )
        for report_line in report:
            print(report_line)
    else:
        print(f"Path is not a file: {target}")


if __name__ == "__main__":
    main()


