#!/usr/bin/env python3
"""
从训练过程里的 valid jsonl 读取样本，按 AIME24 原题匹配，再统计按难度的准确率。

默认假设 valid 文件的 item 至少有:
- input:  prompt 文本
- output: 模型输出

示例:
python eval_scripts/analysis/acc_by_difficulty_valid.py \
    /mnt/shared-storage-gpfs2/p1-shared-2/wangfuting/LIE/models/verl-qwen3-4b-oct/sft-gspo-ours-dapo-math-max12k/valid/290_16384.jsonl
"""

from __future__ import annotations

import argparse
import csv
import json
import re
import sys
from collections import defaultdict
from difflib import SequenceMatcher
from pathlib import Path
from typing import Any


# Put the parent of this script's directory on sys.path so sibling modules
# there become importable (presumably where `oat_math_grader` lives — TODO confirm).
sys.path.append(str(Path(__file__).resolve().parents[1]))
try:
    from oat_math_grader import boxed_reward_fn
    BOXED_GRADER_IMPORT_ERROR = None
except Exception as e:  # pragma: no cover
    # The external grader is optional. On any import failure keep the error so
    # main() can print a warning; is_correct() then uses the local fallback grader.
    boxed_reward_fn = None
    BOXED_GRADER_IMPORT_ERROR = e


# Default AIME24 reference file (JSON Lines, one problem per record).
DEFAULT_AIME24_REFERENCE_PATH = (
    "/mnt/shared-storage-user/p1-shared/wangfuting/codes/"
    "project_tts_extrapolation/data/aime24_nofigures.jsonl"
)

# Hard-coded difficulty buckets. The integers are 0-based record indices into
# the reference jsonl (blank lines skipped); together the four lists cover
# indices 0-29 exactly once.
DIFFICULTY_CLASSES = {
    "Easy": [9, 0, 7, 12],
    "Medium": [24, 11, 8, 15, 26, 6, 19, 18, 23, 22],
    "Hard": [10, 14, 17, 27, 16, 25, 4, 5, 1, 20, 28, 13],
    "Extremely Hard": [29, 2, 21, 3],
}
# Order in which buckets are reported and sorted.
DIFFICULTY_ORDER = ["Easy", "Medium", "Hard", "Extremely Hard"]
# Inverse lookup: reference index -> difficulty bucket name.
QUESTION_TO_DIFFICULTY = {
    question_idx: difficulty
    for difficulty, question_indices in DIFFICULTY_CLASSES.items()
    for question_idx in question_indices
}


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="从训练 valid 文件中匹配 AIME24 题目，并按难度统计准确率。"
    )
    parser.add_argument("valid_file", help="训练过程里的 valid jsonl 文件路径")
    parser.add_argument(
        "--reference-path",
        default=DEFAULT_AIME24_REFERENCE_PATH,
        help="AIME24 参考 jsonl 路径",
    )
    parser.add_argument(
        "--match-threshold",
        type=float,
        default=0.45,
        help="模糊匹配阈值，默认 0.90",
    )
    parser.add_argument(
        "--output-csv",
        default="acc_by_difficulty_valid.csv",
        help="可选：导出每道题统计结果到 csv",
    )
    parser.add_argument(
        "--show-unmatched",
        type=int,
        default=5,
        help="打印多少个未匹配题目组，默认 5",
    )
    parser.add_argument(
        "--top-candidates",
        type=int,
        default=3,
        help="每个未匹配题目显示几个最相近的参考题，默认 3",
    )
    parser.add_argument(
        "--dump-unmatched-json",
        default=None,
        help="可选：将未匹配诊断信息保存成 json",
    )
    return parser.parse_args()


def read_jsonl(path: str | Path) -> list[dict[str, Any]]:
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def normalize_text(value: Any) -> str:
    """Convert ``value`` to a trimmed string with Windows line endings turned into ``\\n``.

    ``None`` becomes the empty string; any other value goes through ``str()``.
    """
    text = "" if value is None else str(value)
    return text.replace("\r\n", "\n").strip()


def remove_asy_blocks(text: str) -> str:
    """Replace each ``[asy]...[/asy]`` figure block with a single space.

    Tag matching ignores case, a block may span several lines, and the match
    is non-greedy, so consecutive blocks are removed independently.
    """
    asy_block = re.compile(r"\[asy\].*?\[/asy\]", re.IGNORECASE | re.DOTALL)
    return asy_block.sub(" ", text)


def normalize_problem_text(text: str) -> str:
    """Canonicalize a problem statement so equivalent texts compare equal.

    Steps, in order:
    - trim, and convert CRLF to LF;
    - replace the chat special tokens ``<|im_end|>`` / ``<|endoftext|>`` with spaces;
    - drop ``[asy]`` figure blocks;
    - collapse whitespace runs to one space;
    - remove spaces before ``, . ; : ! ?``, and spaces just inside ``( [ {`` / ``) ] }``.
    """
    cleaned = normalize_text(text)
    for special_token in ("<|im_end|>", "<|endoftext|>"):
        cleaned = cleaned.replace(special_token, " ")
    cleaned = remove_asy_blocks(cleaned)

    # Order matters: whitespace is collapsed before punctuation/bracket spacing is fixed.
    rewrite_rules = (
        (r"\s+", " "),
        (r"\s+([,.;:!?])", r"\1"),
        (r"([(\[{])\s+", r"\1"),
        (r"\s+([)\]}])", r"\1"),
    )
    for pattern, replacement in rewrite_rules:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()


def build_problem_match_keys(text: str) -> list[str]:
    """Return the exact-match lookup keys for a problem text.

    The first key is the normalized text. If it contains any whitespace, a
    second key with all whitespace removed is added. Empty input yields ``[]``.
    """
    normalized = normalize_problem_text(text)
    if not normalized:
        return []
    compact = re.sub(r"\s+", "", normalized)
    return [normalized] if compact == normalized else [normalized, compact]


def extract_problem_from_input(raw_text: Any) -> str:
    r"""Recover the bare problem statement from a rendered prompt string.

    Steps:
    1. Keep only the text after the last user-turn marker
       (``<|im_start|>user\n``, falling back to a bare ``user\n``).
    2. Cut everything from the first assistant-turn marker onward.
    3. Strip trailing end-of-turn special tokens (``<|im_end|>``,
       ``<|endoftext|>``).
    4. Strip the trailing "Let's think step by step and output the final
       answer within \boxed{}." instruction.

    Step 3 must run before step 4. In a ChatML prompt the user turn ends with
    ``<|im_end|>``, so the ``$``-anchored instruction regex never matched.
    The instruction then leaked into the text used for matching against
    reference problems, defeating exact lookup and lowering fuzzy scores.
    """
    text = normalize_text(raw_text)

    for marker in ("<|im_start|>user\n", "user\n"):
        if marker in text:
            text = text.rsplit(marker, 1)[-1]
            break

    for marker in (
        "\n<|im_start|>assistant",
        "\nassistant\n",
        "\nassistant",
    ):
        if marker in text:
            text = text.split(marker, 1)[0]
            break

    # Drop any trailing end-of-turn tokens (possibly repeated or separated by
    # whitespace) so the instruction suffix below is really at the end.
    text = re.sub(r"(?:\s*(?:<\|im_end\|>|<\|endoftext\|>))+\s*$", "", text)

    text = re.sub(
        r"\s*Let's think step by step and output the final answer within\s*\\boxed\{\}\.?\s*$",
        "",
        text,
    )
    return text.strip()


def extract_prompt(item: dict[str, Any]) -> str:
    """Return the problem text of a valid-file record, or ``""`` if it has none.

    The first of ``input`` / ``prompt`` / ``question`` whose value is not
    ``None`` is passed through ``extract_problem_from_input``.
    """
    raw_prompt = next(
        (
            item[field]
            for field in ("input", "prompt", "question")
            if field in item and item[field] is not None
        ),
        None,
    )
    if raw_prompt is None:
        return ""
    return extract_problem_from_input(raw_prompt)


def extract_outputs(item: dict[str, Any]) -> list[str]:
    """Collect the model responses stored in a valid-file record.

    The fields ``output``, ``response``, ``generated_text`` and ``responses``
    are tried in that order. A list value contributes each element; any other
    value counts as a single response. Values are normalized and blanks are
    dropped. The first field that yields at least one non-empty response wins.
    If none does, ``[]`` is returned.
    """
    for field in ("output", "response", "generated_text", "responses"):
        if field not in item or item[field] is None:
            continue
        raw_value = item[field]
        candidates = raw_value if isinstance(raw_value, list) else [raw_value]
        # normalize_text is pure, so calling it once per element is equivalent to the double call.
        cleaned = [text for text in map(normalize_text, candidates) if text]
        if cleaned:
            return cleaned
    return []


def extract_reference_problem(item: dict[str, Any]) -> str:
    """Return the normalized problem text of a reference record.

    Looks at ``problem``, ``question``, ``prompt``, ``input`` in that order and
    uses the first one whose value is not ``None``.

    Raises:
        KeyError: if none of those fields is present with a non-None value.
    """
    available = [
        field
        for field in ("problem", "question", "prompt", "input")
        if field in item and item[field] is not None
    ]
    if not available:
        raise KeyError(f"参考数据里没有题目字段: {item.keys()}")
    return normalize_text(item[available[0]])


def extract_reference_answer(item: dict[str, Any]) -> Any:
    """Return the raw gold answer of a reference record, unchanged.

    Looks at ``answer``, ``ground_truth``, ``target``, ``solution`` in that
    order and returns the first value that is neither ``None`` nor ``""``.
    Falsy values such as ``0`` are accepted.

    Raises:
        KeyError: if no such field is found.
    """
    for candidate_key in ("answer", "ground_truth", "target", "solution"):
        if candidate_key not in item:
            continue
        candidate = item[candidate_key]
        if candidate in (None, ""):
            continue
        return candidate
    raise KeyError(f"参考数据里没有答案字段: {item.keys()}")


def extract_last_boxed_answer(text: Any) -> str | None:
    r"""Return the stripped argument of the last ``\boxed`` in ``text``, or ``None``.

    Two forms are supported:
    - ``\boxed{...}`` (optionally with whitespace before ``{``), with nested
      braces balanced;
    - the space-delimited form ``\boxed 42``, taking the token up to the next
      whitespace or ``$``.

    Returns ``None`` in these cases:
    - there is no ``\boxed``;
    - nothing follows it;
    - the brace group is unbalanced;
    - ``\boxed`` is directly followed by something other than ``{``.

    The previous version used the first ``{`` anywhere after ``\boxed``, so
    ``"\boxed 5, where {x}"`` wrongly returned ``"x"``.
    """
    text = normalize_text(text)
    start = text.rfind("\\boxed")
    if start == -1:
        return None

    arg_start = start + len("\\boxed")
    pos = arg_start
    while pos < len(text) and text[pos].isspace():
        pos += 1
    if pos >= len(text):
        return None

    if text[pos] != "{":
        # Only accept the "\boxed 42" form when whitespace separates the command
        # from its argument; otherwise this is not a recognizable boxed answer.
        if pos == arg_start:
            return None
        token = re.match(r"[^\s$]+", text[pos:])
        return token.group(0) if token else None

    depth = 0
    for idx in range(pos, len(text)):
        char = text[idx]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return text[pos + 1: idx].strip()
    return None


def normalize_answer_for_fallback(answer: Any) -> str:
    r"""Canonicalize an answer for the string-equality fallback grader.

    Steps:
    - numbers are stringified; for a list, its first element is used
      (``""`` if the list is empty);
    - if the text contains ``\boxed``, the last boxed argument replaces it
      (when it can be extracted);
    - surrounding whitespace and ``$`` signs are trimmed;
    - one pair of braces wrapping the *whole* answer is removed;
    - integers written with a zero fractional part (``"12.000"``) are reduced
      to the integer (``"12"``).

    The previous brace handling (``re.sub(r"^\{|\}$", "", text)``) removed a
    leading ``{`` and a trailing ``}`` independently, mangling answers such as
    ``\frac{1}{2}`` into ``\frac{1}{2``.
    """
    if isinstance(answer, (int, float)):
        answer = str(answer)
    elif isinstance(answer, list):
        answer = answer[0] if answer else ""

    text = normalize_text(answer)
    if "\\boxed" in text:
        boxed = extract_last_boxed_answer(text)
        if boxed is not None:
            text = boxed

    text = text.strip().strip("$").strip()

    if text.startswith("{") and text.endswith("}"):
        # Strip the outer pair only if the first "{" closes at the very last char.
        depth = 0
        wraps_whole = True
        for idx, char in enumerate(text):
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0 and idx != len(text) - 1:
                    wraps_whole = False
                    break
        if wraps_whole and depth == 0:
            text = text[1:-1].strip()

    if re.fullmatch(r"-?\d+\.0+", text):
        text = text.split(".", 1)[0]
    return text


def fallback_is_correct(output: str, answer: Any) -> bool:
    """Grade ``output`` by comparing its last boxed answer with ``answer``.

    Both sides go through ``normalize_answer_for_fallback`` and are compared
    as strings. An output with no extractable boxed answer is wrong.
    """
    predicted = extract_last_boxed_answer(output)
    return predicted is not None and (
        normalize_answer_for_fallback(predicted)
        == normalize_answer_for_fallback(answer)
    )


def load_aime24_reference(reference_path: str | Path) -> tuple[
    list[dict[str, Any]],
    dict[str, int],
    list[tuple[str, int]],
]:
    """Load the reference problems and build the two matching indexes.

    Returns:
        ``(reference_items, exact_lookup, problems_for_fuzzy_match)``:
        - ``reference_items``: the raw records, in file order;
        - ``exact_lookup``: every match key (see ``build_problem_match_keys``)
          mapped to the record index; later records overwrite earlier ones on
          key collisions;
        - ``problems_for_fuzzy_match``: ``(normalized_problem, index)`` pairs
          in file order.
    """
    reference_items = read_jsonl(reference_path)
    normalized_problems = [
        normalize_problem_text(extract_reference_problem(record))
        for record in reference_items
    ]

    exact_lookup: dict[str, int] = {}
    for position, problem in enumerate(normalized_problems):
        exact_lookup.update(dict.fromkeys(build_problem_match_keys(problem), position))

    problems_for_fuzzy_match: list[tuple[str, int]] = [
        (problem, position) for position, problem in enumerate(normalized_problems)
    ]
    return reference_items, exact_lookup, problems_for_fuzzy_match


def build_reference_question_info(
    question_idx: int,
    reference_item: dict[str, Any],
) -> dict[str, Any]:
    """Describe a reference question for diagnostics output.

    ``reference_id`` falls back to ``question_idx`` when the record has no
    ``id``. ``difficulty`` is ``"Unknown"`` for indices outside
    ``QUESTION_TO_DIFFICULTY``. ``problem_snippet`` is the first 200
    characters of the problem.
    """
    problem_text = extract_reference_problem(reference_item)
    info: dict[str, Any] = {"question_idx": question_idx}
    info["reference_id"] = reference_item.get("id", question_idx)
    info["difficulty"] = QUESTION_TO_DIFFICULTY.get(question_idx, "Unknown")
    info["problem"] = problem_text
    info["problem_snippet"] = problem_text[:200]
    return info


def find_top_reference_candidates(
    normalized_problem: str,
    reference_items: list[dict[str, Any]],
    problems_for_fuzzy_match: list[tuple[str, int]],
    top_k: int,
) -> list[dict[str, Any]]:
    """Rank reference problems by similarity to ``normalized_problem``.

    The score is the larger of two ``SequenceMatcher.ratio()`` values: one on
    the texts as given and one with all whitespace removed. This is the same
    scoring ``match_to_aime24`` uses.

    Returns:
        Up to ``top_k`` dicts with keys ``score``, ``question_idx``,
        ``reference_id``, ``difficulty`` and ``problem_snippet``, ordered from
        highest score down; ties keep reference order.

        A non-positive ``top_k`` returns ``[]``. Previously a negative value
        was used as a ``[:-k]`` slice and returned all but the last k
        candidates.
    """
    if top_k <= 0:
        return []

    compact_problem = re.sub(r"\s+", "", normalized_problem)
    scored_candidates: list[tuple[float, int]] = []
    for reference_problem, idx in problems_for_fuzzy_match:
        score = max(
            SequenceMatcher(None, normalized_problem, reference_problem).ratio(),
            SequenceMatcher(
                None,
                compact_problem,
                re.sub(r"\s+", "", reference_problem),
            ).ratio(),
        )
        scored_candidates.append((score, idx))

    # list.sort is stable, and reverse=True preserves the original order of ties.
    scored_candidates.sort(key=lambda pair: pair[0], reverse=True)

    top_candidates = []
    for score, idx in scored_candidates[:top_k]:
        reference_info = build_reference_question_info(idx, reference_items[idx])
        top_candidates.append(
            {
                "score": score,
                "question_idx": reference_info["question_idx"],
                "reference_id": reference_info["reference_id"],
                "difficulty": reference_info["difficulty"],
                "problem_snippet": reference_info["problem_snippet"],
            }
        )
    return top_candidates


def match_to_aime24(
    problem_text: str,
    exact_lookup: dict[str, int],
    problems_for_fuzzy_match: list[tuple[str, int]],
    threshold: float,
) -> tuple[int | None, float, str]:
    """Map a problem text to a reference index.

    Returns ``(index_or_None, score, match_type)``, where ``match_type`` is one of:
    - ``"empty"``: the text normalizes to nothing (score 0.0);
    - ``"exact"``: a match key hits ``exact_lookup`` (score 1.0);
    - ``"fuzzy"``: the best similarity score is at least ``threshold``;
    - ``"unmatched"``: otherwise; the best score seen is still returned.

    Similarity is the max of ``SequenceMatcher.ratio()`` on the spaced and the
    whitespace-free texts. The first reference reaching the top score wins.

    NOTE(review): ``SequenceMatcher`` enables the ``autojunk`` heuristic by
    default for sequences of 200+ items, which may depress ratios for long
    problem statements. Switching to ``autojunk=False`` would change scores,
    so the threshold would need re-tuning.
    """
    normalized_problem = normalize_problem_text(problem_text)
    if not normalized_problem:
        return None, 0.0, "empty"

    exact_hit = next(
        (
            exact_lookup[key]
            for key in build_problem_match_keys(normalized_problem)
            if key in exact_lookup
        ),
        None,
    )
    if exact_hit is not None:
        return exact_hit, 1.0, "exact"

    compact_problem = re.sub(r"\s+", "", normalized_problem)
    best_idx: int | None = None
    best_score = 0.0
    for reference_problem, idx in problems_for_fuzzy_match:
        spaced_ratio = SequenceMatcher(None, normalized_problem, reference_problem).ratio()
        compact_ratio = SequenceMatcher(
            None,
            compact_problem,
            re.sub(r"\s+", "", reference_problem),
        ).ratio()
        score = max(spaced_ratio, compact_ratio)
        if score > best_score:
            best_idx, best_score = idx, score

    accepted = best_idx is not None and best_score >= threshold
    if accepted:
        return best_idx, best_score, "fuzzy"
    return None, best_score, "unmatched"


def is_correct(output: str, answer: Any) -> bool:
    """Grade one model output against the gold ``answer``.

    When ``oat_math_grader.boxed_reward_fn`` is available, a reward above 0.5
    counts as correct. If that grader is missing, or raises for this input,
    the local ``fallback_is_correct`` string comparison is used instead.
    """
    if boxed_reward_fn is not None:
        try:
            _grader_info, reward = boxed_reward_fn(output, answer, fast=False)
            verdict = reward > 0.5
        except Exception:
            # Deliberate best-effort: any grader failure degrades to the fallback.
            verdict = None
        if verdict is not None:
            return verdict
    return fallback_is_correct(output, answer)


def analyze_valid_file(
    valid_file: str | Path,
    reference_path: str | Path,
    match_threshold: float,
    top_candidates: int,
) -> tuple[dict[int, dict[str, Any]], dict[str, Any]]:
    """Match each valid-file record to an AIME24 reference problem and grade it.

    Args:
        valid_file: JSONL with the training run's validation records.
        reference_path: JSONL with the AIME24 reference problems.
        match_threshold: Minimum fuzzy score accepted by ``match_to_aime24``.
        top_candidates: Number of closest reference problems recorded per
            unmatched prompt group.

    Returns:
        ``(question_stats, summary)``.

        ``question_stats`` maps a reference index to its counters:
        ``total``, ``correct``, ``accuracy`` and ``match_types``.

        ``summary`` holds:
        - file-level counts;
        - ``missing_questions``: reference questions never matched;
        - ``unmatched_groups``: unmatched prompts grouped by normalized text,
          largest group first.

    Compared with the previous version:
    - ``dict.setdefault`` evaluated its default eagerly, so the expensive
      ``find_top_reference_candidates`` ran for every unmatched line; it now
      runs once per new group;
    - deterministic match results are cached per prompt, since the same
      prompt may appear on many lines.
    """
    reference_items, exact_lookup, problems_for_fuzzy_match = load_aime24_reference(
        reference_path
    )
    valid_items = read_jsonl(valid_file)

    question_stats: dict[int, dict[str, Any]] = {}
    unmatched_groups: dict[str, dict[str, Any]] = {}
    match_cache: dict[str, tuple[int | None, float, str]] = {}

    matched_item_count = 0
    matched_output_count = 0
    exact_match_count = 0
    fuzzy_match_count = 0

    for line_number, item in enumerate(valid_items, start=1):
        prompt = extract_prompt(item)
        if not prompt:
            continue

        if prompt not in match_cache:
            match_cache[prompt] = match_to_aime24(
                prompt,
                exact_lookup,
                problems_for_fuzzy_match,
                match_threshold,
            )
        question_idx, match_score, match_type = match_cache[prompt]

        if question_idx is None:
            normalized_prompt = normalize_problem_text(prompt)
            group = unmatched_groups.get(normalized_prompt)
            if group is None:
                group = {
                    "count": 0,
                    "line_numbers": [],
                    "best_score": match_score,
                    "raw_problem": prompt,
                    "normalized_problem": normalized_prompt,
                    "top_candidates": find_top_reference_candidates(
                        normalized_prompt,
                        reference_items,
                        problems_for_fuzzy_match,
                        top_candidates,
                    ),
                }
                unmatched_groups[normalized_prompt] = group
            group["count"] += 1
            group["best_score"] = max(group["best_score"], match_score)
            # Keep at most 10 example line numbers per group.
            if len(group["line_numbers"]) < 10:
                group["line_numbers"].append(line_number)
            continue

        outputs = extract_outputs(item)
        if not outputs:
            continue

        matched_item_count += 1
        matched_output_count += len(outputs)
        if match_type == "exact":
            exact_match_count += 1
        else:
            fuzzy_match_count += 1

        reference_item = reference_items[question_idx]
        answer = extract_reference_answer(reference_item)

        stats = question_stats.setdefault(
            question_idx,
            {
                "question_idx": question_idx,
                "reference_id": reference_item.get("id", question_idx),
                "difficulty": QUESTION_TO_DIFFICULTY.get(question_idx, "Unknown"),
                "total": 0,
                "correct": 0,
                "match_types": defaultdict(int),
            },
        )
        stats["match_types"][match_type] += 1

        for output in outputs:
            stats["total"] += 1
            stats["correct"] += int(is_correct(output, answer))

    for stats in question_stats.values():
        stats["accuracy"] = stats["correct"] / stats["total"] if stats["total"] else 0.0
        # Convert to a plain dict so the stats are JSON/CSV friendly.
        stats["match_types"] = dict(stats["match_types"])

    missing_question_indices = sorted(
        set(range(len(reference_items))) - set(question_stats.keys())
    )
    missing_questions = [
        build_reference_question_info(question_idx, reference_items[question_idx])
        for question_idx in missing_question_indices
    ]
    # Largest groups first; ties broken by their first line number.
    unmatched_group_list = sorted(
        unmatched_groups.values(),
        key=lambda x: (
            -x["count"],
            x["line_numbers"][0] if x["line_numbers"] else 10**9,
        ),
    )

    summary = {
        "valid_file": str(valid_file),
        "total_items": len(valid_items),
        "matched_item_count": matched_item_count,
        "matched_output_count": matched_output_count,
        "matched_question_count": len(question_stats),
        "exact_match_count": exact_match_count,
        "fuzzy_match_count": fuzzy_match_count,
        "missing_questions": missing_questions,
        "unmatched_groups": unmatched_group_list,
    }
    return question_stats, summary


def build_difficulty_summary(
    question_stats: dict[int, dict[str, Any]]
) -> list[dict[str, Any]]:
    """Aggregate per-question counts into one row per difficulty bucket plus "Overall".

    Each row has ``difficulty``, ``correct``, ``total``, ``accuracy`` (0.0 when
    ``total`` is 0) and ``matched_questions``.

    The "Overall" totals sum only the bucketed questions, while its
    ``matched_questions`` is ``len(question_stats)``.
    """
    rows: list[dict[str, Any]] = []
    for difficulty in DIFFICULTY_ORDER:
        present = [
            question_stats[question_idx]
            for question_idx in DIFFICULTY_CLASSES[difficulty]
            if question_stats.get(question_idx)
        ]
        bucket_correct = sum(entry["correct"] for entry in present)
        bucket_total = sum(entry["total"] for entry in present)
        rows.append(
            {
                "difficulty": difficulty,
                "correct": bucket_correct,
                "total": bucket_total,
                "accuracy": bucket_correct / bucket_total if bucket_total else 0.0,
                "matched_questions": len(present),
            }
        )

    grand_correct = sum(row["correct"] for row in rows)
    grand_total = sum(row["total"] for row in rows)
    rows.append(
        {
            "difficulty": "Overall",
            "correct": grand_correct,
            "total": grand_total,
            "accuracy": grand_correct / grand_total if grand_total else 0.0,
            "matched_questions": len(question_stats),
        }
    )
    return rows


def write_question_stats_csv(
    output_csv: str | Path,
    question_stats: dict[int, dict[str, Any]],
) -> None:
    """Write one CSV row per matched question.

    Rows are ordered by difficulty bucket, then by question index. Unknown
    difficulties sort after all known buckets. ``accuracy`` is written with
    six decimals and ``match_types`` as a JSON object.
    """
    fieldnames = [
        "question_idx",
        "reference_id",
        "difficulty",
        "correct",
        "total",
        "accuracy",
        "match_types",
    ]
    unknown_rank = len(DIFFICULTY_ORDER)

    def ordering(entry: dict[str, Any]) -> tuple[int, int]:
        difficulty = entry["difficulty"]
        rank = (
            DIFFICULTY_ORDER.index(difficulty)
            if difficulty in DIFFICULTY_ORDER
            else unknown_rank
        )
        return rank, entry["question_idx"]

    ordered_entries = sorted(question_stats.values(), key=ordering)

    with open(output_csv, "w", newline="", encoding="utf-8") as csv_file:
        writer = csv.DictWriter(csv_file, fieldnames=fieldnames)
        writer.writeheader()
        for entry in ordered_entries:
            # The first five columns are copied verbatim; the last two are formatted.
            record = {name: entry[name] for name in fieldnames[:5]}
            record["accuracy"] = f"{entry['accuracy']:.6f}"
            record["match_types"] = json.dumps(entry["match_types"], ensure_ascii=False)
            writer.writerow(record)


def write_json(path: str | Path, payload: Any) -> None:
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)


def main() -> None:
    """CLI entry point: analyze a valid file and report accuracy by difficulty.

    Prints to stdout:
    - file-level match counts;
    - reference questions that were never matched;
    - accuracy per difficulty bucket and per question;
    - optionally, the largest unmatched prompt groups with their closest
      reference candidates.

    Also writes the unmatched-diagnostics JSON and/or the per-question CSV
    when those options are set.
    """
    args = parse_args()

    question_stats, summary = analyze_valid_file(
        valid_file=args.valid_file,
        reference_path=args.reference_path,
        match_threshold=args.match_threshold,
        top_candidates=args.top_candidates,
    )
    difficulty_rows = build_difficulty_summary(question_stats)
    # Every reference question is either matched or listed as missing, so this
    # is the size of the reference file (previously hard-coded as 30).
    total_reference_questions = (
        summary["matched_question_count"] + len(summary["missing_questions"])
    )

    print(f"valid file: {summary['valid_file']}")
    print(f"total items: {summary['total_items']}")
    if BOXED_GRADER_IMPORT_ERROR is not None:
        print(
            f"warning: fallback grader enabled because oat_math_grader import failed: {BOXED_GRADER_IMPORT_ERROR}")
    print(
        "matched AIME items: "
        f"{summary['matched_item_count']} "
        f"(outputs: {summary['matched_output_count']})"
    )
    print(
        "matched questions: "
        f"{summary['matched_question_count']}/{total_reference_questions} | "
        f"exact: {summary['exact_match_count']} | "
        f"fuzzy: {summary['fuzzy_match_count']}"
    )

    if summary["missing_questions"]:
        print(
            f"\nMissing reference questions: {len(summary['missing_questions'])}")
        for item in summary["missing_questions"]:
            print(
                f"q{item['question_idx']:>2} | "
                f"{item['difficulty']:<14} | "
                f"id={item['reference_id']} | "
                f"{item['problem_snippet']}"
            )

    print("\nAccuracy by difficulty")
    for row in difficulty_rows:
        print(
            f"{row['difficulty']:<16} "
            f"{row['correct']:>4}/{row['total']:<4} "
            f"= {row['accuracy']:.4f} "
            f"(matched questions: {row['matched_questions']})"
        )

    print("\nPer-question accuracy")
    # Order by difficulty bucket (unknown buckets last), then by question index.
    for question_idx in sorted(
        question_stats,
        key=lambda idx: (
            DIFFICULTY_ORDER.index(QUESTION_TO_DIFFICULTY.get(idx, "Unknown"))
            if QUESTION_TO_DIFFICULTY.get(idx, "Unknown") in DIFFICULTY_ORDER
            else len(DIFFICULTY_ORDER),
            idx,
        ),
    ):
        row = question_stats[question_idx]
        print(
            f"q{row['question_idx']:>2} | "
            f"{row['difficulty']:<14} | "
            f"{row['correct']:>3}/{row['total']:<3} | "
            f"{row['accuracy']:.4f} | "
            f"id={row['reference_id']}"
        )

    if args.show_unmatched > 0 and summary["unmatched_groups"]:
        print(
            f"\nTop {min(args.show_unmatched, len(summary['unmatched_groups']))} unmatched groups"
        )
        for group in summary["unmatched_groups"][: args.show_unmatched]:
            print(
                f"count={group['count']} | "
                f"lines={group['line_numbers']} | "
                f"best_score={group['best_score']:.4f}"
            )
            print(f"raw: {group['raw_problem'][:300]}")
            print(f"normalized: {group['normalized_problem'][:300]}")
            for idx, candidate in enumerate(group["top_candidates"], start=1):
                print(
                    f"  candidate {idx}: "
                    f"q{candidate['question_idx']:>2} | "
                    f"{candidate['difficulty']:<14} | "
                    f"id={candidate['reference_id']} | "
                    f"score={candidate['score']:.4f} | "
                    f"{candidate['problem_snippet']}"
                )

    if args.dump_unmatched_json:
        write_json(
            args.dump_unmatched_json,
            {
                "missing_questions": summary["missing_questions"],
                "unmatched_groups": summary["unmatched_groups"],
            },
        )
        print(f"\nSaved unmatched diagnostics to: {args.dump_unmatched_json}")

    # An empty --output-csv disables the CSV export.
    if args.output_csv:
        write_question_stats_csv(args.output_csv, question_stats)
        print(f"\nSaved question stats to: {args.output_csv}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
