#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any, Iterable

from bd_mcts.tasks.math import (
    FastMathTask,
    _load_hendrycks_math_samples,
    extract_final_answer,
)


def _iter_response_objects(path: Path) -> Iterable[Any]:
    """Yield raw response entries read from a JSON or line-oriented file.

    Files whose suffix is ``.json`` (compared case-insensitively) are parsed
    as a single JSON document:

    * dict payload -- one entry per key. Dict values are shallow-copied and
      receive a ``"sample_id"`` key set to the mapping key unless they already
      have one; any other value is wrapped as
      ``{"sample_id": key, "response": value}``.
    * list payload -- items are yielded unchanged.

    Every other file is read line by line: blank lines are skipped, each
    remaining line is JSON-decoded, and lines that are not valid JSON are
    yielded as their stripped text.

    Args:
        path: Location of the responses file (read as UTF-8).

    Yields:
        One entry per response, typically a dict or a string.

    Raises:
        ValueError: When a ``.json`` payload is neither a list nor a dict.
            Because this is a generator, the error surfaces on iteration.
    """
    # Lower-case the suffix so "responses.JSON" is not misread as JSON lines.
    if path.suffix.lower() == ".json":
        payload = json.loads(path.read_text(encoding="utf-8"))
        if isinstance(payload, dict):
            for key, value in payload.items():
                if isinstance(value, dict):
                    # Copy so the parsed payload itself is never mutated.
                    obj = dict(value)
                    obj.setdefault("sample_id", key)
                    yield obj
                else:
                    yield {"sample_id": key, "response": value}
            return
        if isinstance(payload, list):
            yield from payload
            return
        raise ValueError("JSON payload must be a list or dict.")

    with path.open("r", encoding="utf-8") as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue
            try:
                yield json.loads(line)
            except json.JSONDecodeError:
                # Best effort: treat a non-JSON line as a plain-text response.
                yield line


def _get_response(obj: Any, response_field: str, line_no: int) -> str:
    """Pull the response text out of a single loaded entry.

    Args:
        obj: The entry; either a plain string or a dict.
        response_field: Preferred dict key to read the response from.
        line_no: 1-based entry position, used only in error messages.

    Returns:
        ``obj`` itself when it is a string; otherwise ``str()`` of the value
        under the first key present among ``response_field``, ``"response"``,
        ``"answer"``, ``"prediction"`` and ``"output"`` (in that order).

    Raises:
        ValueError: If ``obj`` is a dict with none of those keys, or is
            neither a string nor a dict.
    """
    if isinstance(obj, str):
        return obj
    if not isinstance(obj, dict):
        raise ValueError(f"Unsupported response entry at {line_no}: {type(obj)}")

    candidate_keys = (response_field, "response", "answer", "prediction", "output")
    # A private sentinel keeps "no key found" distinct from any real key.
    absent = object()
    chosen = next((name for name in candidate_keys if name in obj), absent)
    if chosen is absent:
        raise ValueError(f"Missing response field at entry {line_no}.")
    return str(obj[chosen])


def _coerce_sample_id(value: Any, fallback: int, line_no: int) -> int:
    """Convert a raw sample id to ``int``, or use ``fallback`` when absent.

    Args:
        value: Raw id taken from a response entry; ``None`` means "not given".
        fallback: Id to return when ``value`` is ``None``.
        line_no: 1-based entry position, used only in the error message.

    Returns:
        ``fallback`` if ``value`` is ``None``, otherwise ``value`` as an int.
        Integral floats such as ``3.0`` and numeric strings such as ``"3"``
        are accepted.

    Raises:
        ValueError: If ``value`` cannot be interpreted as an exact integer.
            This includes booleans, non-integral floats, and NaN/infinity.
    """
    if value is None:
        return fallback
    try:
        # bool is an int subclass; accepting it would silently map
        # true/false onto samples 1/0.
        if isinstance(value, bool):
            raise TypeError("boolean is not a valid sample_id")
        coerced = int(value)
        # int() truncates floats; refuse ids that would lose their fraction.
        if isinstance(value, float) and coerced != value:
            raise ValueError("sample_id is not an integral number")
    except (TypeError, ValueError, OverflowError) as exc:
        # OverflowError comes from int(float("inf")); json.loads accepts
        # "Infinity", so it can reach this point from a responses file.
        raise ValueError(f"Invalid sample_id at entry {line_no}: {value}") from exc
    return coerced


def _load_responses(
    path: Path, response_field: str, id_field: str
) -> dict[int, str]:
    """Read a responses file into a ``{sample_id: response_text}`` mapping.

    Entries are taken from ``_iter_response_objects(path)`` and numbered from
    1 for error reporting. A dict entry supplies its id via ``id_field``;
    entries without an id (including plain-string entries) are assigned the
    next value of a counter that starts at 0 and advances only for such
    entries. When two entries resolve to the same id, the later one replaces
    the earlier one.

    Args:
        path: Responses file to read.
        response_field: Preferred key holding the response text.
        id_field: Key holding the sample id in dict entries.

    Returns:
        Mapping from resolved integer sample id to response text.
    """
    collected: dict[int, str] = {}
    next_positional_id = 0
    entries = _iter_response_objects(path)
    for position, entry in enumerate(entries, start=1):
        text = _get_response(entry, response_field, position)
        explicit_id = entry.get(id_field) if isinstance(entry, dict) else None
        key = _coerce_sample_id(explicit_id, next_positional_id, position)
        collected[key] = text
        if explicit_id is None:
            # Only id-less entries consume a positional id.
            next_positional_id += 1
    return collected


def _load_samples(
    dataset: str, split: str, config: str | None, limit: int | None
) -> list[dict[str, Any]]:
    """Load dataset samples, optionally keeping only the first ``limit``.

    Delegates to ``_load_hendrycks_math_samples`` with the per-category
    limit, shuffle and seed options left off (``None`` / ``False`` / ``None``).

    Args:
        dataset: Dataset identifier forwarded to the loader.
        split: Dataset split forwarded to the loader.
        config: Optional dataset configuration, forwarded as ``dataset_config``.
        limit: Maximum number of samples to keep; ``None`` keeps all.

    Returns:
        All loaded samples when ``limit`` is ``None``, an empty list when
        ``limit`` is below 1, otherwise the leading ``limit`` samples (or
        fewer when not that many were loaded).
    """
    loaded = _load_hendrycks_math_samples(
        dataset=dataset,
        split=split,
        dataset_config=config,
        per_category_limit=None,
        per_category_shuffle=False,
        per_category_seed=None,
    )
    if limit is None:
        return loaded
    if limit < 1:
        return []
    keep = min(limit, len(loaded))
    return loaded[:keep]


def main() -> int:
    """Command-line entry point: score stored responses with ``FastMathTask``.

    Steps, in order:

    1. Parse the command-line options.
    2. Load samples via ``_load_samples`` and build a ``FastMathTask``.
    3. Load responses from ``--responses`` via ``_load_responses``.
    4. For each response, in ascending sample-id order, run
       ``task.parse_answer``, ``extract_final_answer`` and ``task.submit``,
       plus ``task.evaluate`` when ``--prm`` is given.
    5. Print the number of evaluated responses and the mean ``submit_metric``
       (and the mean ``prm_metric`` with ``--prm``).
    6. With ``--output``, write one JSON object per result line to that path,
       creating parent directories as needed.

    Returns:
        ``0`` on success.

    Raises:
        SystemExit: If no samples or no responses were loaded, or if a
            response's sample id is outside ``0 <= id < len(samples)``.
    """
    cli = argparse.ArgumentParser(
        description="Evaluate Hendrycks MATH responses with FastMathTask."
    )
    # Options are registered from a table, in order, so --help is unchanged.
    option_table = (
        ("--dataset", {"default": "hendrycks/math"}),
        ("--split", {"default": "test"}),
        ("--dataset-config", {"default": None}),
        ("--responses", {"required": True}),
        ("--response-field", {"default": "response"}),
        ("--id-field", {"default": "sample_id"}),
        ("--question-key", {"default": "problem"}),
        ("--answer-key", {"default": "solution"}),
        ("--limit", {"type": int, "default": None}),
        ("--no-math-verify", {"action": "store_true"}),
        ("--prm", {"action": "store_true"}),
        ("--output", {"default": None}),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    opts = cli.parse_args()

    samples = _load_samples(
        opts.dataset, opts.split, opts.dataset_config, opts.limit
    )
    if not samples:
        raise SystemExit("No samples loaded.")

    task = FastMathTask(
        samples,
        question_key=opts.question_key,
        answer_key=opts.answer_key,
        use_math_verify=not opts.no_math_verify,
        parse_mode="response",
    )

    responses = _load_responses(
        Path(opts.responses), opts.response_field, opts.id_field
    )
    if not responses:
        raise SystemExit("No responses loaded.")

    rows: list[dict[str, Any]] = []
    # Keys are unique ints, so sorting them matches sorting the items.
    for sample_id in sorted(responses):
        raw_response = responses[sample_id]
        if not 0 <= sample_id < len(samples):
            raise SystemExit(
                f"sample_id {sample_id} out of range for {len(samples)} samples."
            )
        cleaned = task.parse_answer(sample_id, raw_response)
        final_answer = extract_final_answer(cleaned)

        submitted = task.submit(sample_id, cleaned)
        row: dict[str, Any] = {
            "sample_id": sample_id,
            "parsed_answer": final_answer,
            "submit_metric": submitted.metric,
            "submit_detail": submitted.sample_detail,
        }
        if opts.prm:
            scored = task.evaluate(sample_id, cleaned)
            row["prm_metric"] = scored.metric
            row["prm_detail"] = scored.sample_detail
        rows.append(row)

    evaluated = len(rows)
    accuracy = sum(row["submit_metric"] for row in rows) / evaluated
    print(f"samples: {evaluated}")
    print(f"accuracy: {accuracy:.4f}")
    if opts.prm:
        prm_average = sum(row["prm_metric"] for row in rows) / evaluated
        print(f"prm_mean: {prm_average:.4f}")

    if opts.output:
        destination = Path(opts.output)
        destination.parent.mkdir(parents=True, exist_ok=True)
        with destination.open("w", encoding="utf-8") as sink:
            for row in rows:
                sink.write(f"{json.dumps(row, ensure_ascii=True)}\n")

    return 0


# When executed as a script, run the CLI and exit with main()'s return value.
if __name__ == "__main__":
    raise SystemExit(main())
