import re
import timeout_decorator
import sympy as sp
from sympy.parsing.latex import parse_latex
from sympy.printing import latex
from collections import Counter
import os
import sys
from pathlib import Path
import json
from sympy.parsing.sympy_parser import (
    parse_expr,
    standard_transformations,
    implicit_multiplication_application,
)

# Locate the project root (the parent of this file's directory) and put it on
# sys.path so that the project-local ``utils`` imports below can be resolved.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.normpath(os.path.join(BASE_DIR, ".."))

# Prepend (rather than append) so the project's own modules win over any
# same-named installed packages; skip if it is already present.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from utils.math_parse import strip_string
from utils.math_utils import _num_equals

# LaTeX matrix environment names that route an answer to the matrix parser
# (parse_latex_answer builds a "\begin{<env>}" / "\begin{<env>*}" check from it).
MATRIX_ENVS = ("pmatrix", "bmatrix", "vmatrix", "Vmatrix", "matrix", "smallmatrix")

# 捕获 \command 的“命令保护”机制，避免拆行时吞掉命令的反斜杠
# ---------- 命令保护 / 还原 ----------
_CMD_PROTECT_RE = re.compile(r'\\([a-zA-Z]+)')  # \frac, \sin, \left, \right, ...

def _protect_cmds(s: str) -> str:
    # \foo -> ⟦CMD:foo⟧
    return _CMD_PROTECT_RE.sub(lambda m: f'⟦CMD:{m.group(1)}⟧', s)

def _restore_cmds(s: str) -> str:
    # ⟦CMD:foo⟧ -> \foo
    return re.sub(r'⟦CMD:([a-zA-Z]+)⟧', r'\\\1', s)

# ---------- 分数归一化 ----------
_inline_frac_pat = re.compile(
    r'(?<!\\)(?<!\w)'
    r'(-?(?:\d+(?:\.\d+)?))\s*/\s*(-?(?:\d+(?:\.\d+)?))'
    r'(?!\w)'
)

def _normalize_inline_frac(s: str) -> str:
    # 用函数替换，避免 "\f" 被当作换页符
    def repl(m):
        return "\\frac{%s}{%s}" % (m.group(1), m.group(2))
    return _inline_frac_pat.sub(repl, s)

def normalize_tex(s: str) -> str:
    """Lightly normalise a LaTeX snippet before it is cleaned and parsed.

    Strips outer whitespace, maps ``\\tfrac`` and ``\\dfrac`` to ``\\frac``,
    rewrites inline numeric fractions (``1/2``) as ``\\frac{1}{2}`` and
    collapses every whitespace run to a single space.
    """
    text = s.strip()
    for variant in (r'\tfrac', r'\dfrac'):
        text = text.replace(variant, r'\frac')
    text = _normalize_inline_frac(text)
    return re.sub(r'\s+', ' ', text)

# ---------- fallback for malformed single-backslash row separators ----------
def _normalize_single_slash_rows(body: str) -> str:
    """Repair matrix bodies that use a single backslash as the row separator.

    Only applies when *body* contains no two consecutive backslashes (i.e. no
    proper row separator at all). In that case a single backslash sitting
    right after the end of a cell (a digit, ``)``, ``}`` or ``]``) is replaced
    by the internal ``⟦ROW⟧`` marker understood by _split_rows_safely, in two
    shapes:

    - backslash + number, where the number is kept as the start of the next row;
    - backslash + sign (``+``/``-``), where the sign is kept.

    The number/sign must be followed by something that can start a cell
    (backslash, digit, ``(``, ``{``, word character, and for numbers also a
    sign). A backslash directly followed by a letter, such as a command like
    ``\\frac``, never matches.

    Returns the (possibly) rewritten body.
    """
    if not re.search(r"(?<!\\)\\\\", body):
        # "\" + number: row break; the number starts the next row.
        body = re.sub(
            r"(?<=[0-9)\}\]])\s*\\\s*(\d+)\s*(?=(?:\\|\d|\(|\{|\w|-|\+))",
            lambda m: "⟦ROW⟧" + m.group(1),
            body,
        )
        # "\" + sign (+/-): row break; the sign starts the next row.
        body = re.sub(
            r"(?<=[0-9)\}\]])\s*\\\s*([+\-])\s*(?=(?:\\|\d|\(|\{|\w))",
            lambda m: "⟦ROW⟧" + m.group(1),
            body,
        )
    return body

# ---------- safe row splitting: handles \\\frac, \\, and "\ + sign + command/number/bracket" ----------
def _split_rows_safely(body: str):
    """Split a matrix environment body into row strings.

    Commands are first shielded with _protect_cmds. Then proper ``\\\\``
    separators and several malformed single-backslash separators are
    rewritten to a ``⟦ROW⟧`` marker. The text is split on that marker,
    whitespace-only rows are dropped, each row is stripped, and commands
    are restored with _restore_cmds.

    Returns a list of row strings; cells inside a row are still joined by ``&``.
    """
    tmp = _protect_cmds(body)

    # Collapse runs of pre-existing row markers (e.g. inserted by
    # _normalize_single_slash_rows) into a single marker.
    tmp = re.sub(r'(?:⟦ROW⟧\s*)+', '⟦ROW⟧', tmp)

    # 2) "\\\frac": use only the first two backslashes as the row break and keep
    #    one backslash for the command that follows.
    # NOTE(review): commands are already protected as "⟦CMD:..⟧" at this point,
    # so this rule mostly fires on longer backslash runs; confirm it is needed.
    tmp = re.sub(r'\\{3}(?=(?:⟦CMD:|[A-Za-z\\]))', lambda m: '⟦ROW⟧\\', tmp)

    # 3) Any run of >= 2 backslashes (optionally followed by a "[...]" spacing
    #    argument such as "\\[2pt]") becomes one row break, so "\\\\" cannot
    #    produce overlapping replacements.
    tmp = re.sub(r'\\{2,}(?:\[[^\]]*\])?', lambda m: '⟦ROW⟧', tmp)

    # 4a) Single backslash directly before a (protected) command -> row break
    #     (malformed data).
    tmp = re.sub(r'\\(?=\s*⟦CMD:)', lambda m: '⟦ROW⟧', tmp)

    # 4b) Single backslash + sign (+/-) + start of the next cell -> row break;
    #     the sign is kept at the start of the new row.
    tmp = re.sub(
        r'\\\s*([+\-])\s*(?=(?:⟦CMD:|\\|\d|\(|\{|\w))',
        lambda m: '⟦ROW⟧' + m.group(1),
        tmp
    )

    # 4b') Single backslash + number + start of the next cell -> row break;
    #      the number is kept at the start of the new row.
    tmp = re.sub(
        r'\\\s*(\d+)\s*(?=(?:⟦CMD:|\\|\d|\(|\{|\w|-|\+))',
        lambda m: '⟦ROW⟧' + m.group(1),
        tmp
    )

    # 4c) A bare backslash between the end of one cell (digit, ")", "}", "]")
    #     and the start of another -> row break.
    # NOTE(review): this also fires on an escaped brace right after a digit
    # (e.g. "2\{"); verify such input cannot reach this function.
    tmp = re.sub(
        r'(?<=[0-9)\}\]])\s*\\\s*(?=(?:⟦CMD:|\\|\d|\(|\{|\w))',
        lambda m: '⟦ROW⟧',
        tmp
    )

    # Collapse again so consecutive markers do not yield empty rows.
    tmp = re.sub(r'(?:⟦ROW⟧\s*)+', '⟦ROW⟧', tmp)

    rows = [r.strip() for r in tmp.split('⟦ROW⟧') if r.strip()]
    rows = [_restore_cmds(r) for r in rows]
    return rows

# 在 \begin{...} 左侧紧贴地识别“标量因子”
_SCALAR_LEFT_PAT = r"""
    (?P<scalar>                      # 捕获标量
        [+\-]?\s*(
            \\frac\s*\{[^{}]+\}\s*\{[^{}]+\}   # \frac{...}{...}
            | \d+(?:\.\d+)?                    # 或普通数字
        )
    )
    \s*                               # 可选空白
"""

# 2) 提取“(可选标量) + 矩阵环境”的第一个命中
def _find_matrix_with_optional_scalar(s: str):
    envs = r"(?:pmatrix|bmatrix|vmatrix|Vmatrix|matrix|smallmatrix)\*?"
    # 先优先匹配“标量 + 矩阵”
    pat_scalar = re.compile(
        rf"{_SCALAR_LEFT_PAT}\\begin\{{(?P<env>{envs})\}}(?:\[[^\]]*\])?(?P<body>[\s\S]*?)\\end\{{(?P=env)\}}",
        flags=re.S | re.X
    )
    m = pat_scalar.search(s)
    if m:
        return m.group("env"), m.group("body"), m.group("scalar")
    # 退化为“裸矩阵”
    pat_bare = re.compile(
        rf"\\begin\{{(?P<env>{envs})\}}(?:\[[^\]]*\])?(?P<body>[\s\S]*?)\\end\{{(?P=env)\}}",
        flags=re.S
    )
    m = pat_bare.search(s)
    if m:
        return m.group("env"), m.group("body"), None
    return None

def _try_parse_matrix(expr_str: str):
    """Parse the first LaTeX matrix environment in *expr_str* into a sympy Matrix.

    The environment (and an optional scalar directly before it) is located,
    single-backslash row breaks are repaired, and the body is split into rows
    and then into cells on unescaped ``&``. Every cell is normalised, cleaned
    and parsed on its own, with the untouched cell text as a second chance.
    A scalar prefix multiplies the matrix when it parses; otherwise it is
    ignored.

    Returns the Matrix, or ``None`` when no matrix is found, a cell cannot be
    parsed, or the rows do not form a valid Matrix.
    """
    located = _find_matrix_with_optional_scalar(expr_str.strip())
    if not located:
        return None
    _env, raw_body, scalar_tex = located

    # Repair malformed single-backslash row breaks, then split into rows.
    row_texts = _split_rows_safely(_normalize_single_slash_rows(raw_body))

    grid = []
    for row_text in row_texts:
        # Columns are separated by unescaped "&" (a column vector has none).
        cell_texts = [piece.strip() for piece in re.split(r"(?<!\\)&", row_text)]
        parsed_cells = []
        for cell_text in cell_texts:
            cleaned = clean_expr_str(normalize_tex(cell_text))
            try:
                value = my_parse_latex(cleaned)
            except Exception:
                # Second chance: the cell exactly as written.
                try:
                    value = my_parse_latex(cell_text)
                except Exception:
                    return None
            parsed_cells.append(value)
        grid.append(parsed_cells)

    try:
        matrix = sp.Matrix(grid)
    except Exception:
        return None

    if scalar_tex:
        # Left-multiply by the scalar; an unparseable scalar is not fatal and
        # simply leaves the bare matrix.
        try:
            factor = my_parse_latex(clean_expr_str(normalize_tex(scalar_tex)))
            matrix = factor * matrix
        except Exception:
            pass
    return matrix

# Compare two numeric answers.
def compare_numerical_ans(ans_p, ans_l):
    """Return True if two numeric answers agree to within 1e-3 (absolute).

    Thousands separators (``,``) and ``$`` signs are ignored. A trailing
    ``%`` on *either* side divides that value by 100, so ``"50%"`` equals
    both ``"0.5"`` and ``"50%"``. Missing (``None``) or non-numeric inputs
    compare unequal instead of raising.

    Args:
        ans_p: predicted answer (a string, or any value whose ``str`` is numeric).
        ans_l: reference answer, same conventions.
    """
    if ans_p is None or ans_l is None:
        return False

    def to_float(value) -> float:
        text = str(value).replace(",", "").replace("$", "").strip()
        if text.endswith("%"):
            return float(text.rstrip("%")) / 100
        return float(text)

    try:
        pred, label = to_float(ans_p), to_float(ans_l)
    except ValueError:
        return False
    return abs(pred - label) < 1e-3

def my_parse_latex(expr_str):
    """Parse LaTeX with sympy's ``parse_latex`` plus a few fix-ups.

    Before parsing, ``dfrac`` is rewritten to ``frac`` and a literal ``π`` to
    ``\\pi``. After parsing, if the rewritten text mentions ``\\pi``, any
    ``pi`` symbol is replaced by ``sympy.pi``. The symbol ``i`` is always
    replaced by the imaginary unit, so it cannot serve as an ordinary
    variable. Parser exceptions propagate to the caller.
    """
    latex_src = expr_str.replace("dfrac", "frac").replace("π", "\\pi")
    parsed = parse_latex(latex_src)
    if "\\pi" in latex_src:
        parsed = parsed.subs({sp.Symbol("pi"): sp.pi})
    return parsed.subs({sp.Symbol("i"): sp.I})

# Whether a string is a number.
def is_number(element: str) -> bool:
    """Return True if *element*, with all spaces removed, parses as a float."""
    compact = element.replace(" ", "")
    try:
        float(compact)
    except ValueError:
        return False
    return True

# A percentage such as "12.5%", "12.5 %" or "12.5\%" (number captured in group 1).
_PERCENT_RE = re.compile(r'(\d+(?:\.\d+)?)\s*\\?%')

# Convert percentages to decimal fractions.
def percentage_to_fraction(text):
    """Replace every percentage in *text* with its decimal value.

    ``"50%"``, ``"50 %"`` and ``"50\\%"`` all become ``"0.5"``; the decimal is
    rendered with ``str(float)``.
    """
    # Substitute match-by-match. A global str.replace of each matched text
    # would also rewrite the same digits inside *other* numbers
    # ("5% and 15%" -> "0.05 and 10.05").
    def to_decimal(match):
        return str(float(match.group(1)) / 100.0)

    return _PERCENT_RE.sub(to_decimal, text)

# Style/text macros that do not affect an answer's value: clean_expr_str
# unwraps "\macro{content}" to "content" for each of these names.
STYLE_MACROS = (
    "text", "textbf", "textit", "textrm", "rm", "mathbb", "textnormal", "mbox",
    "mathrm", "mathbf", "mathit", "mathsf", "operatorname", "mathnormal", "overline"
)

# Clean up an expression string before LaTeX parsing.
def clean_expr_str(expr_str):
    """Normalise an answer string so sympy's ``parse_latex`` can cope with it.

    Steps, applied in this order:

    - unwrap STYLE_MACROS whose argument contains no nested braces
      (up to two passes);
    - plain text substitutions:
      - ``**`` becomes ``^``;
      - ``!=``, ``>=``, ``<=`` and ``≠`` become LaTeX relations;
      - ``\\pm``, ``$``, ``%``, ``\\!`` and double quotes are dropped;
      - ``dfrac``/``tfrac`` become ``frac``;
    - ``*`` between operands becomes ``\\times``;
    - brace-less ``\\frac``, ``\\sqrt`` and ``\\log_`` forms get braces, and
      ``sqrt(...)`` becomes ``\\sqrt{...}``;
    - ``\\left``/``\\right`` are removed;
    - ``\\{``/``\\}`` become ``{``/``}``;
    - ``\\cdot`` is inserted between a digit and ``\\sqrt``;
    - degrees (``^\\circ`` after a digit) become ``\\times (\\pi/180)``;
    - ``12_3`` becomes ``{12}_{3}``;
    - the Unicode minus ``−`` becomes ``-``.

    Returns ``""`` for ``None``, otherwise the cleaned string.
    """
    if expr_str is None:
        return ""

    expr_str = expr_str.strip()

    # Unwrap style macros; a second pass handles one level of nesting such as
    # \text{\textbf{x}}. Stop early once nothing changes.
    for _ in range(2):
        before = expr_str
        for m in STYLE_MACROS:
            expr_str = re.sub(rf"\\{m}\s*\{{([^{{}}]*)\}}", r"\1", expr_str)
        # Rare brace-less bold single letter: "\textbf A".
        expr_str = re.sub(r"\\textbf\s+([A-Za-z])\b", r"\1", expr_str)
        if expr_str == before:
            break

    expr_str = (
        expr_str.replace(" . ", ".")
        .replace(". ", ".")
        .replace("**", "^")
        .replace("\\pm", "")
        .replace("\\ne ", "\\neq ")
        .replace("!=", "\\neq")
        .replace(">=", "\\ge")
        .replace("<=", "\\le")
        .replace("≠", "\\neq")
        .replace("dfrac", "frac")
        .replace("tfrac", "frac")
        .replace("\\$", "")
        .replace("$", "")
        .replace("\\%", "")
        .replace("%", "")
        .replace("\\!", "")
        .replace("//", "/")
        .replace('"', "")
        # TODO: decide whether "," should be stripped here as well.
    )

    # "*" between operands -> "\times".
    expr_str = re.sub(r"(?<=[0-9A-Za-z\}\)])\s*\*\s*(?=[0-9A-Za-z\\\(])", r"\\times ", expr_str)
    expr_str = re.sub(r"\^\s?\((.*?)\)", r"^{\1}", expr_str)
    expr_str = re.sub(r"\\frac\s?(\d)\s?(\d+)", r"\\frac{\1}{\2}", expr_str)
    expr_str = re.sub(r"\\log_\s?(\d)\s?(\d+)", r"\\log_{\1}{\2}", expr_str)
    # NOTE(review): the lazy group can run across "}{", so "\frac{1}{2}3"
    # becomes "\frac{1}{2}{3}"; confirm this input cannot occur.
    expr_str = re.sub(r"\\frac\s?{(.*?)}\s?(\d)", r"\\frac{\1}{\2}", expr_str)
    expr_str = re.sub(r"\\frac\s?(\d)\s?{(.*?)}", r"\\frac{\1}{\2}", expr_str)
    expr_str = re.sub(r"\\sqrt\s?(\d)", r"\\sqrt{\1}", expr_str)
    # NOTE(review): these also match the "sqrt(" inside "\sqrt(2)", leaving a
    # doubled backslash ("\\sqrt{2}").
    expr_str = re.sub(r"sqrt\s?\((\d+)\)", r"\\sqrt{\1}", expr_str)
    expr_str = re.sub(r"sqrt\s?\((.*?)\)", r"\\sqrt{\1}", expr_str)
    # Generic-token form of a brace-less \frac: "\frac a b".
    expr_str = re.sub(r"\\frac\s+([^\s{}]+)\s+([^\s{}]+)", r"\\frac{\1}{\2}", expr_str)
    expr_str = expr_str.replace(" sqrt", "\\sqrt")
    # Fallback for brace-less "\sqrt x".
    expr_str = re.sub(r"\\sqrt\s+([^\s{}])", r"\\sqrt{\1}", expr_str)
    # Drop sizing delimiters (note: plain substring removal, so "\leftarrow"
    # is affected too).
    expr_str = (
        expr_str.replace("\\left", "").replace("\\right.", "").replace("\\right", "")
    )
    expr_str = re.sub(r'\\frac(?!\s*\{)\s*([^\s{}]+)\s*([^\s{}]+)', r'\\frac{\1}{\2}', expr_str)

    # Set notation: escaped braces become plain braces.
    expr_str = expr_str.replace(r'\{', '{').replace(r'\}', '}')

    # A digit directly followed by \sqrt needs an explicit multiplication.
    expr_str = re.sub(r'(?<=\d)\\sqrt', r'\\cdot \\sqrt', expr_str)

    # Degree sign after a number -> conversion to radians.
    expr_str = re.sub(r'(?<=\d)\s*\^\s*\\circ\b', r'\\times (\\pi/180)', expr_str)

    # Normalise numeric subscripts (e.g. base notation "101_2").
    expr_str = re.sub(r'(\d+)_([A-Za-z0-9]+)', r'{\1}_{\2}', expr_str)

    expr_str = expr_str.replace("−", "-")
    return expr_str

# 区间解析
_SEQ_RE = re.compile(r"^\s*([\(\[])\s*(.+)\s*([\)\]])\s*$", re.S)
def _split_top_commas(s: str) -> list[str]:
    out, buf, depth_brace = [], [], 0
    depth_brack = depth_paren = 0
    for ch in s:
        if ch == '{': depth_brace += 1
        elif ch == '}': depth_brace -= 1
        elif ch == '[': depth_brack += 1
        elif ch == ']': depth_brack -= 1
        elif ch == '(': depth_paren += 1
        elif ch == ')': depth_paren -= 1

        if ch == ',' and depth_brace == depth_brack == depth_paren == 0:
            out.append(''.join(buf).strip())
            buf = []
        else:
            buf.append(ch)
    if buf:
        out.append(''.join(buf).strip())
    return out

def _build_interval(L, a, b, R):
    """Build ``sympy.Interval(a, b)`` with openness taken from the brackets.

    A ``(`` on the left or ``)`` on the right makes that end open. Endpoints
    are never swapped. If sympy rejects the endpoints, ``sympy.EmptySet`` is
    returned instead of raising.
    """
    try:
        return sp.Interval(a, b, left_open=(L == '('), right_open=(R == ')'))
    except Exception:
        # Still a successful parse, but semantically the empty set.
        return sp.EmptySet

def _try_parse_seq_or_interval(s):
    """Parse a ``(...)`` / ``[...]`` answer as an interval, a union or a tuple.

    - Exactly two comma-separated parts give a ``sympy.Interval`` whose
      openness follows the brackets. If the endpoints do not form a non-empty
      interval (e.g. ``(3, -1)``, almost always a coordinate pair), the pair
      is returned as a ``sympy.Tuple``. Returning ``EmptySet`` there would
      make every such pair compare equal to every other one.
    - Text containing ``\\cup`` gives a ``sympy.Union`` of the parsed pieces.
    - Two or more parts give a ``sympy.Tuple`` of the parsed parts.

    Returns None when *s* is not bracketed or no form parses.
    """
    m = _SEQ_RE.match(s)
    if not m:
        return None
    L, body, R = m.groups()
    parts = _split_top_commas(body)

    # Interval: any bracket mix, exactly two parts.
    if len(parts) == 2:
        a = parse_latex_answer(parts[0])
        b = parse_latex_answer(parts[1])
        if a is not None and b is not None:
            interval = _build_interval(L, a, b, R)
            if interval == sp.EmptySet:
                return sp.Tuple(a, b)
            return interval

    # Unions such as "(0,1)\cup(2,3)". sp.Set is only the abstract base class,
    # so a real set union must be built with sp.Union.
    if "\\cup" in s:
        elems = [parse_latex_answer(p) for p in s.split("\\cup")]
        if all(e is not None for e in elems):
            try:
                return sp.Union(*elems)
            except Exception:
                pass  # some piece is not a set (e.g. a Tuple): try a tuple below

    # Otherwise a plain tuple.
    if len(parts) >= 2:
        elems = [parse_latex_answer(p) for p in parts]
        if all(e is not None for e in elems):
            return sp.Tuple(*elems)
    return None

# Set parsing: "{a, b, ...}".
_SET_RE = re.compile(r'^\s*\{(.+)\}\s*$', re.S)
def _try_parse_set(s):
    """Parse ``{a, b, ...}`` into a sympy FiniteSet.

    Returns None if *s* is not brace-delimited or any element fails to parse.
    """
    match = _SET_RE.match(s)
    if match is None:
        return None
    members = [parse_latex_answer(piece) for piece in _split_top_commas(match.group(1))]
    if any(member is None for member in members):
        return None
    return sp.FiniteSet(*members)

# Tuple/list comparison. Intervals are handled elsewhere; this comparison is
# ORDER-SENSITIVE (element i is compared with element i).
def list_compare(list1, list2):
    """Return True if both sequences have equal length and pairwise-equal items.

    Each item is converted to ``str``, cleaned with clean_expr_str and
    re-parsed with parse_latex_answer before being compared with
    is_expr_equal. Comparison stops at the first mismatching pair.
    """
    if len(list1) != len(list2):
        return False
    for left, right in zip(list1, list2):
        lhs = parse_latex_answer(clean_expr_str(str(left)))
        rhs = parse_latex_answer(clean_expr_str(str(right)))
        if not is_expr_equal(lhs, rhs):
            return False
    return True

# Whether an answer is wrapped in brackets; tells "(1,2,4)" apart from "1,2,4".
def _is_bracketed(s: str) -> bool:
    """Return True if *s* (None treated as "") matches _SEQ_RE."""
    text = s if s else ""
    return _SEQ_RE.match(text) is not None

def _parse_latex_first(*candidates):
    """Return my_parse_latex of the first candidate that parses, else None.

    A result that is the bare symbol ``begin`` means parse_latex silently
    mis-read an unhandled ``\\begin{...}`` environment, so it counts as a
    failure and the next candidate is tried.
    """
    for candidate in candidates:
        try:
            expr = my_parse_latex(candidate)
            if getattr(expr, "is_Symbol", False) and str(expr) == "begin":
                raise ValueError("sympy-begin-sentinel")
            return expr
        except Exception:
            pass
    return None

# Parse a LaTeX answer into a sympy object.
def parse_latex_answer(sample):
    """Parse an answer (LaTeX string or Python number) into a sympy object.

    Strategies, in order:

    1. A Python int/float is converted with ``sympy.nsimplify``; ``0`` and
       ``0.0`` are included.
    2. None, empty or whitespace-only input gives None.
    3. A plain decimal string is converted with ``nsimplify``; a leading
       ``.5`` is read as ``0.5``.
    4. A bracketed interval, union or tuple goes through
       _try_parse_seq_or_interval.
    5. A braced set goes through _try_parse_set.
    6. A bare comma-separated list without thousands separators gives a
       ``sympy.Tuple``.
    7. A matrix environment goes through _try_parse_matrix on the *uncleaned*
       text.
    8. Otherwise my_parse_latex is tried on the cleaned text, then on the raw
       text.

    If every strategy fails, prints a ``[parse failed]`` line and returns None.
    """
    # Numbers first: 0 / 0.0 are valid answers even though they are falsy.
    # (bool False keeps its previous "falsy -> None" treatment.)
    if isinstance(sample, (int, float)) and sample is not False:
        return sp.nsimplify(sample)

    # Empty input -> None.
    if not sample or (isinstance(sample, str) and not sample.strip()):
        return None

    # Plain decimal numbers (after turning ".5" into "0.5").
    if isinstance(sample, str):
        s_num = re.sub(r'(?<!\d)\.(\d+)\b', r'0.\1', sample.strip())
        if re.fullmatch(r'[+-]?(?:\d+|\d*\.\d+)', s_num):
            return sp.nsimplify(s_num)

    # Intervals / unions / tuples in (...) or [...].
    parsed = _try_parse_seq_or_interval(sample)
    if parsed is not None:
        return parsed

    # Sets in {...}.
    parsed = _try_parse_set(sample)
    if parsed is not None:
        return parsed

    # Bare comma-separated sequence without brackets, e.g. "1, 2, 3".
    if isinstance(sample, str) and ',' in sample:
        # Skip thousands separators such as "1,234".
        if not re.search(r'\b\d{1,3}(,\d{3})+(\.\d+)?\b', sample):
            parts = _split_top_commas(sample)
            if len(parts) >= 2:
                # Parse every part as LaTeX on its own (not via sympify).
                elems = []
                for part in parts:
                    part = part.strip()
                    try:
                        elems.append(my_parse_latex(clean_expr_str(part)))
                    except Exception as e:
                        print(f"[parse failed] part: {part}, error: {e}")
                        break
                else:
                    if elems:
                        return sp.Tuple(*elems)

    raw = sample
    sample = clean_expr_str(sample)

    # Matrices are parsed from the raw text, because cleaning (e.g. removing
    # \left/\right or rewriting escaped braces) can damage the environment body.
    if re.search(r"\\begin\{(?:" + "|".join(MATRIX_ENVS) + r")\*?\}", raw):
        mat = _try_parse_matrix(raw)
        if mat is not None:
            return mat

    # General case, and the fallback for matrices the matrix parser rejected:
    # cleaned text first, then the raw text.
    expr = _parse_latex_first(sample, raw)
    if expr is None:
        print("[parse failed]", sample)
    return expr

# Trig functions whose purely numeric, pi-free arguments are read as degrees.
_TRIG_FUNCS = {sp.sin, sp.cos, sp.tan, sp.cot, sp.sec, sp.csc}
def _degify_trig(expr: sp.Expr) -> sp.Expr:
    """Reinterpret pi-free numeric trig arguments as degrees.

    ``sin(32)`` becomes ``sin(pi*32/180)``. Only single-argument calls of the
    functions in _TRIG_FUNCS are rewritten, and only when the argument has no
    free symbols and does not contain ``pi``. Everything else is unchanged.
    """
    def _is_degree_call(node) -> bool:
        if getattr(node, "func", None) not in _TRIG_FUNCS:
            return False
        args = getattr(node, "args", ())
        if len(args) != 1:
            return False
        arg = args[0]
        return not arg.free_symbols and not arg.has(sp.pi)

    def _to_radians(node):
        return node.func(sp.pi * node.args[0] / 180)

    return expr.replace(_is_degree_call, _to_radians)

def _expr_ops(e):
    """Estimate expression complexity with ``sympy.count_ops``.

    Any failure maps to a huge value (10**9), so callers treat the expression
    as too complex to process.
    """
    huge = 10**9
    try:
        ops = int(sp.count_ops(e))
    except Exception:
        ops = huge
    return ops

def _is_safe_combinatorics(e, max_ops: int = 40, max_n: int = 20,
                           max_factorial: int = 10) -> bool:
    """Decide whether factorial/binomial/gamma terms in *e* are small enough.

    All of the following must hold:

    - the overall complexity satisfies ``_expr_ops(e) <= max_ops``;
    - every ``factorial(n)`` has a concrete integer ``n`` with
      ``|n| <= max_factorial``;
    - every ``binomial(n, k)`` has concrete integers with ``|n|, |k| <= max_n``;
    - every ``gamma(z)`` has a concrete integer ``z`` with
      ``|z| <= max_factorial``.

    ``max_factorial`` defaults to 10, the bound previously hard-coded for
    factorial and gamma, so existing callers see unchanged behaviour. Any
    exception raised during the check is treated as "unsafe" (False).
    """
    def small_int(x, bound) -> bool:
        # Only concrete integers within the bound are accepted.
        return bool(x.is_integer and x.is_number) and abs(int(x)) <= bound

    try:
        # Overall complexity first.
        if _expr_ops(e) > max_ops:
            return False

        for atom in e.atoms(sp.factorial, sp.binomial, sp.gamma):
            if isinstance(atom, sp.binomial):
                if len(atom.args) != 2 or not all(small_int(x, max_n) for x in atom.args):
                    return False
            else:
                # factorial and gamma: a single small integer argument only.
                if len(atom.args) != 1 or not small_int(atom.args[0], max_factorial):
                    return False

        # Every factorial/binomial/gamma present is small and simple.
        return True
    except Exception:
        # If the check itself fails, treat the expression as unsafe.
        return False

# 判断两个表达式是否相等
# def my_equals(ans_p, ans_l, tol=1e-8):
#     """
#     面向 MATH 的轻量答案判断版本：
#       - 小规模表达式：允许适度 simplify / trigsimp / Poly 检查
#       - 含 factorial / binomial / gamma / 超大表达式：直接放弃，判 False，但允许轻量的化简
#     """
#     # ---------- 0. 预处理 ----------
#     if ans_p is None or ans_l is None:
#         return False

#     ap, al = ans_p, ans_l
#     try:
#         eq = ap.equals(al)
#         if eq is True:
#             return True
#     except Exception:
#         pass

#     # 表达式转化
#     try:
#         ap = sp.sympify(ans_p)
#         al = sp.sympify(ans_l)
#     except Exception:
#         # 连 SymPy 表达式都构不出来，直接放弃等价判断
#         return False

#     # 完全结构相等判断等
#     if ap == al:
#         return True

#     diff = ap - al
#     # 差值已经是 0（或带 is_zero 属性）
#     try:
#         if diff == 0 or getattr(diff, "is_zero", False):
#             return True
#     except Exception:
#         pass

#     # 整体复杂度太大：直接放弃，避免任何重度操作
#     if _expr_ops(diff) > 120:
#         return False

#     allow_heavy_simplify = True
#     # 含 factorial/gamma/binomial，接受轻度化简
#     has_comb = bool(diff.atoms(sp.factorial, sp.binomial, sp.gamma))
#     if has_comb:
#         # 如果组合式不安全（太大、参数不规整），直接放弃
#         if not _is_safe_combinatorics(diff, max_ops=40, max_n=20):
#             return False
#         allow_heavy_simplify = False # 出现组合式后续就不做simplify

#     # ---------- 轻量代数化简 ----------
#     try:
#         if _expr_ops(diff) <= 80:
#             d = sp.expand(diff)
#             d = sp.together(d)
#             d = sp.cancel(d)
#             if allow_heavy_simplify:
#                 d = sp.simplify(d, ratio=2)
#             if d == 0 or getattr(d, "is_zero", False):
#                 return True
#     except Exception:
#         pass

#     # ---------- 轻量三角化简 ----------
#     try:
#         if diff.has(sp.sin, sp.cos, sp.tan, sp.asin, sp.acos, sp.atan):
#             if _expr_ops(diff) <= 60:
#                 d = sp.simplify(sp.trigsimp(ap - al))
#                 if d == 0 or getattr(d, "is_zero", False):
#                     return True
#     except Exception:
#         pass

#     # --------------- 轻量度制检验（还是需要的） -------------------
#     ap_deg = _degify_trig(ap)
#     al_deg = _degify_trig(al)
#     if (ap_deg != ap) or (al_deg != al):
#         # 先符号化简
#         if _expr_ops(diff) <= 60:
#             d = sp.simplify(sp.trigsimp(ap_deg - al_deg))
#             if d == 0 or getattr(d, "is_zero", False):
#                 return True
#             # 若无符号，直接数值比较
#             if not (ap_deg - al_deg).free_symbols:
#                 if abs(complex(sp.N(ap_deg - al_deg))) < tol:
#                     return True

#     # ---------- 多项式零检验（仅限中小表达式） ----------
#     try:
#         syms = sorted(diff.free_symbols, key=lambda s: s.name)
#         if syms and _expr_ops(diff) <= 40:
#             P = sp.Poly(sp.expand(diff), *syms, domain="EX")
#             if P.is_zero:
#                 return True
#     except Exception:
#         pass

#     # ---------- 4. 简单数值兜底 ----------
#     syms = list((ap - al).free_symbols)
#     if not syms:
#         return abs(complex(sp.N(ap - al))) < tol

#     return False

# Decide whether two sympy expressions are mathematically equal.
def my_equals(ans_p, ans_l, tol=1e-12):
    """Best-effort equality test for two sympy expressions.

    Strategies are tried in order, and the first one that proves equality
    returns True:

    1. ``ans_p.equals(ans_l)``; only an explicit True counts.
    2. expand -> together -> cancel -> simplify of the difference.
    3. ``simplify(trigsimp(...))`` of the difference.
    4. Pi-free numeric trig arguments are read as degrees (_degify_trig),
       followed by a symbolic check and, when no free symbols remain, a
       numeric check.
    5. A polynomial zero test over the free symbols.
    6. A numeric check. Without free symbols this is
       ``|ans_p - ans_l| < tol``. Otherwise the difference is evaluated at
       six sample points, giving each symbol a *distinct* value. It must be
       within *tol* at every point that evaluates to a finite number, and at
       least one point must evaluate.

    Errors inside a strategy only make that strategy inconclusive. The
    exception is ``timeout_decorator.TimeoutError``, which is re-raised so an
    enclosing ``timeout_decorator.timeout`` (as on compare_ans) still
    interrupts the comparison.
    """
    import cmath  # only needed by the numeric fallback

    ap, al = ans_p, ans_l

    def is_zero(d) -> bool:
        return bool(d == 0 or getattr(d, "is_zero", False))

    def proves_equal(check) -> bool:
        # TimeoutError subclasses AssertionError, so a bare
        # `except Exception` would swallow it and defeat the time limit.
        try:
            return bool(check())
        except timeout_decorator.TimeoutError:
            raise
        except Exception:
            return False

    # 1) equals(): trust only an explicit True, never conclude False here.
    def by_equals():
        return ap.equals(al) is True

    # 2) Algebra: expand -> common denominator -> cancel -> simplify.
    def by_algebra():
        d = sp.expand(ap - al)
        d = sp.together(d)
        d = sp.cancel(d)
        return is_zero(sp.simplify(d))

    # 3) Trigonometric identities (nearly free for non-trig input).
    def by_trig():
        return is_zero(sp.simplify(sp.trigsimp(ap - al)))

    # 4) Degree fallback.
    def by_degrees():
        ap_deg, al_deg = _degify_trig(ap), _degify_trig(al)
        if ap_deg == ap and al_deg == al:
            return False  # nothing was reinterpreted as degrees
        if is_zero(sp.simplify(sp.trigsimp(ap_deg - al_deg))):
            return True
        delta = ap_deg - al_deg
        return not delta.free_symbols and abs(complex(sp.N(delta))) < tol

    # 5) Exact polynomial zero test.
    def by_poly():
        syms = sorted((ap - al).free_symbols, key=lambda s: s.name)
        return bool(syms) and sp.Poly(sp.expand(ap - al), *syms, domain='EX').is_zero

    # 6) Numeric fallback.
    def by_numbers():
        diff = ap - al
        syms = sorted(diff.free_symbols, key=str)
        if not syms:
            return abs(complex(sp.N(diff))) < tol
        checked = 0
        for val in (1, 2, 3, 4, 5, 6):
            # Distinct value per symbol: with one shared value, differences
            # like (x + 2*y) - (2*x + y) vanish at every sample point.
            point = {s: val + 7 * k for k, s in enumerate(syms)}
            v = diff.subs(point)
            if not v.is_number:
                continue  # still symbolic here: skip this point
            try:
                z = complex(sp.N(v))
            except (TypeError, ValueError):
                continue
            if not cmath.isfinite(z):
                continue  # pole / undefined at this point: skip it
            if abs(z) > tol:
                return False
            checked += 1
        # Never accept on zero evidence (e.g. undefined functions never
        # evaluate to numbers).
        return checked > 0

    strategies = (by_equals, by_algebra, by_trig, by_degrees, by_poly, by_numbers)
    return any(proves_equal(step) for step in strategies)

# Decide whether a predicted answer equals a reference answer.
def is_expr_equal(ans_p, ans_l, is_strict=False):
    """Compare two parsed answers (sympy objects, lists or tuples).

    Cases are checked in this order:

    - If either side is None, the result is False.
    - Two Intervals: endpoints are compared with _num_equals, and the
      openness must match.
    - Two FiniteSets: structural equality.
    - Two other Sets (e.g. unions): the symmetric difference must be empty
      after simplify, with element-wise membership as a fallback.
    - Two sequences (sympy Tuple, list or tuple): same length and an ordered
      list_compare.
    - An equation ``lhs = c`` against a number: my_equals on ``c``.
    - Differing free symbols give False; structural equality gives True.
    - Relations of the same type: lhs and rhs are each compared with
      my_equals.
    - Otherwise: ``my_equals(ans_p, ans_l)``.

    ``timeout_decorator.TimeoutError`` is never swallowed, so an enclosing
    time limit still applies.
    """
    def is_equ_num_equal(equation, number):
        # True if `equation` is Eq(lhs, c) with numeric c equal to `number`.
        if not (
            isinstance(equation, sp.Eq)
            and equation.rhs.is_number
            and getattr(number, "is_number", False)
        ):
            return False
        try:
            return bool(my_equals(equation.rhs, number))
        except timeout_decorator.TimeoutError:
            raise
        except Exception:
            return equation.rhs == number

    if ans_p is None or ans_l is None:
        return False

    # Intervals first: endpoints numerically, openness exactly.
    if isinstance(ans_p, sp.Interval) and isinstance(ans_l, sp.Interval):
        return (_num_equals(ans_p.start, ans_l.start)
                and _num_equals(ans_p.end, ans_l.end)
                and ans_p.left_open == ans_l.left_open
                and ans_p.right_open == ans_l.right_open)

    # Finite sets: structural equality.
    if isinstance(ans_p, sp.FiniteSet) and isinstance(ans_l, sp.FiniteSet):
        return ans_p == ans_l

    # Other set pairs, e.g. unions of intervals.
    if isinstance(ans_p, sp.Set) and isinstance(ans_l, sp.Set):
        try:
            # Normalise first to avoid irrelevant representation differences;
            # equal sets have an empty symmetric difference.
            ap = sp.simplify(ans_p)
            al = sp.simplify(ans_l)
            return ap.symmetric_difference(al) == sp.EmptySet
        except timeout_decorator.TimeoutError:
            raise
        except Exception:
            pass
        # Fallback: mutual element-wise membership.
        try:
            return (all(p in ans_l for p in ans_p)
                    and all(q in ans_p for q in ans_l))
        except timeout_decorator.TimeoutError:
            raise
        except Exception:
            return False

    # Tuples / lists: same length, compared element-wise in order.
    if isinstance(ans_p, (sp.Tuple, list, tuple)) and isinstance(ans_l, (sp.Tuple, list, tuple)):
        if len(ans_p) != len(ans_l):
            return False
        return list_compare(list(ans_p), list(ans_l))

    # NOTE(review): the parentheses make the existing precedence explicit.
    # `is_strict` only disables the "reference is an equation, prediction is a
    # number" direction; a predicted equation such as "x = 5" still matches a
    # numeric reference "5" in strict mode. Confirm this asymmetry is intended.
    if (not is_strict and is_equ_num_equal(ans_l, ans_p)) or is_equ_num_equal(ans_p, ans_l):
        return True

    if ans_p.free_symbols != ans_l.free_symbols:
        return False

    if ans_p == ans_l:
        return True

    if isinstance(ans_l, sp.core.relational.Relational):
        try:
            if (type(ans_l) is type(ans_p)
                    and my_equals(ans_p.lhs, ans_l.lhs)
                    and my_equals(ans_p.rhs, ans_l.rhs)):
                return True
        except timeout_decorator.TimeoutError:
            raise
        except Exception as e:
            print(ans_p, ans_l, e)

    try:
        return bool(my_equals(ans_p, ans_l))
    except timeout_decorator.TimeoutError:
        raise
    except Exception:
        return False

# Extract the last number in a sentence.
def extract_answer_number(sentence: str) -> str:
    """Return the last number in *sentence* as a string, or ``""`` if none.

    Commas are removed first, so thousands separators do not split numbers.
    Numbers are matched by ``-?\\d+\\.?\\d*``, so the sign is kept and a
    trailing period is included (``"7."``).
    """
    numbers = re.findall(r"-?\d+\.?\d*", sentence.replace(",", ""))
    return numbers[-1] if numbers else ""

# Extract the only number in a sentence.
def extract_only_number(sentence: str) -> str:
    """Return the single number in *sentence* as a string.

    Returns ``""`` when there is no number or more than one. Commas are
    removed first, and numbers are matched by ``-?\\d+\\.?\\d*``.
    """
    numbers = re.findall(r"-?\d+\.?\d*", sentence.replace(",", ""))
    return numbers[0] if len(numbers) == 1 else ""

def _as_list_from_chain(s: str):
    if not isinstance(s, str): return None
    t = (s.replace(r'\leq', '<=').replace(r'\le', '<=').replace('≤', '<=')
           .replace(r'\geq', '>=').replace(r'\ge', '>=').replace('≥', '>=')
           .replace(r'\lt', '<').replace(r'\gt', '>'))
    if any(ch in t for ch in '<>'):
        parts = re.split(r'\s*(?:<=|>=|<|>)\s*', t.strip())
        if len(parts) >= 2:
            return [p.strip() for p in parts]
    return None

# 处理时间类型
_TIME_RE = re.compile(
    r'(?i)^\s*'                       # 前导空白
    r'(\d{1,2})\s*:\s*(\d{2})'        # HH:MM
    r'(?:\s*:\s*(\d{2}))?'            # 可选 :SS
    r'\s*(am|pm|a\.?m\.?|p\.?m\.?)?'  # AM/PM 的多种写法: am, pm, a.m, a.m., p.m, p.m.
    r'\s*$'
)
def _norm_time(s: str) -> str | None:
    if not isinstance(s, str): return None
    m = _TIME_RE.match(s.strip())
    if not m: return None
    h, mi, se, ap = m.groups()
    h, mi, se = int(h), int(mi), int(se or 0)
    if ap:
        ap = ap.lower().replace('.', '')
        if ap == 'pm' and h != 12: h += 12
        if ap == 'am' and h == 12: h = 0
    return f"{h:02d}:{mi:02d}:{se:02d}"

def _parse_time_raw(s: str):
    if not isinstance(s, str):
        return None
    m = _TIME_RE.match(s.strip())
    if not m:
        return None
    h, mi, se, ap = m.groups()
    h, mi, se = int(h), int(mi), int(se or 0)
    ap = (ap or "").lower().replace(".", "")
    ap = ap if ap in ("am", "pm") else None
    return h, mi, se, ap

# 纯数字提取
_NUM_TOKEN_RE = re.compile(r'[+-]?(?:\d+(?:\.\d+)?|\.\d+)')
def _looks_like_plain_number_or_with_unit(s: str) -> bool:
    """仅当 s 看起来就是一个“单一数字（可带单位词）”，且不含任何 LaTeX 宏/运算符/括号时返回 True。"""
    if not isinstance(s, str):
        return False
    t = s.strip()
    # 含有明显的数学/LaTeX 结构就直接否掉
    if re.search(r'[\\{}^_*=/]|[\(\)\[\]<>]|[,;]|(?:\bfrac\b|\bbegin\b|\bend\b)', t):
        return False
    # 去掉千分位逗号（如果你有这类数据）
    t = t.replace(',', '')
    nums = _NUM_TOKEN_RE.findall(t)
    # 必须恰好一个数字；其它内容允许是空白或字母（单位词）
    if len(nums) != 1:
        return False
    # 不能有除了字母/空白/货币符号这类“外观字符”以外的符号
    if re.search(r'[^0-9.\-\+\sA-Za-z¥€£$%]', t):
        return False
    return True

def _extract_single_number(s: str) -> str | None:
    m = _NUM_TOKEN_RE.findall(s.replace(',', ''))
    return m[0] if len(m) == 1 else None

# 判断一个字符串（去首尾标点后）是否是单个词。
STRIP_PUNCT_RE = re.compile(r'^[^\w\[\]()]+|[^\w\[\]()]+$')
WHITESPACE_RE = re.compile(r"\s")
ALPHA_WORD_RE = re.compile(r'^[a-zA-Z]+$')
def _strip_punct(s: str) -> str:
    return STRIP_PUNCT_RE.sub('', (s or "").strip())

def _is_single_word(s: str) -> bool:
    t = _strip_punct(s)
    return bool(t) and not WHITESPACE_RE.search(t) and ALPHA_WORD_RE.match(t)

# def _is_numsub_sym(e):
#     return isinstance(e, sp.Symbol) and str(e).startswith("__NUMSUB__")

# Compare two answers
def _unordered_items_equal(pred_items, label_items) -> bool:
    """Return True if two lists of ``(parsed, text)`` pairs match as multisets.

    Each label item can be consumed by at most one prediction item, so
    ``[1, 1, 2]`` does not match ``[1, 2, 2]``. Items are compared on their
    parsed values (identity first, then ``==``, like the ``in`` operator), which
    also works for unhashable parse results. When either side failed to parse
    (``parsed is None``) the items are compared on their whitespace-free text,
    so two *different* unparseable items are never considered equal.
    """
    if len(pred_items) != len(label_items):
        return False
    remaining = list(label_items)
    for p_val, p_txt in pred_items:
        for idx, (l_val, l_txt) in enumerate(remaining):
            if p_val is None or l_val is None:
                same = p_txt.replace(" ", "") == l_txt.replace(" ", "")
            else:
                same = p_val is l_val or p_val == l_val
            if same:
                del remaining[idx]  # consume the match so multiplicities count
                break
        else:
            return False
    return True


@timeout_decorator.timeout(10)
def compare_ans(ans_p_str, ans_l_str, is_strict=False):
    """Return True if the predicted answer is judged equivalent to the reference answer.

    Checks run from cheap to expensive; the first one that reaches a verdict wins:

    1. normalization (``percentage_to_fraction``, ``clean_expr_str``, ``$`` removed);
    2. exact text match, and case-insensitive match for single alphabetic words;
    3. chain-style answers recognised by ``_as_list_from_chain`` (compared in order);
    4. whitespace-insensitive text match;
    5. a single plain number (optionally with unit words) compared as floats;
    6. clock times (minutes/seconds must match; AM/PM is compared leniently
       when only one side has it);
    7. unbracketed comma-separated lists, compared element-wise as multisets;
    8. full symbolic comparison via ``parse_latex_answer`` + ``is_expr_equal``.

    Args:
        ans_p_str: predicted answer string; None is treated as "".
        ans_l_str: reference (label) answer string; None is treated as "".
        is_strict: forwarded to ``is_expr_equal``.

    Returns:
        bool. The whole call is limited to 10 seconds by ``timeout_decorator``.
    """
    ans_p_str = percentage_to_fraction(ans_p_str or "")
    ans_l_str = percentage_to_fraction(ans_l_str or "")

    ans_p_str = clean_expr_str(ans_p_str).replace("$", "")
    ans_l_str = clean_expr_str(ans_l_str).replace("$", "")

    # Plain text comparison: rarely matches, but cheaply catches canonical
    # forms and pure-string answers.
    if clean_expr_str(ans_p_str) == clean_expr_str(ans_l_str):
        return True
    # Single alphabetic words (e.g. "yes", "A") are compared case-insensitively.
    if _is_single_word(ans_p_str) and _is_single_word(ans_l_str):
        return ans_p_str.lower() == ans_l_str.lower()

    # Ordered chain answers: the order of the elements matters here.
    lp = _as_list_from_chain(ans_l_str)
    pp = _as_list_from_chain(ans_p_str)
    if lp and pp:
        return pp == lp
    if lp and ',' in ans_p_str:
        return [x.strip() for x in ans_p_str.split(',')] == lp

    if ans_p_str.replace(" ", "") == ans_l_str.replace(" ", ""):
        return True

    # A single plain number on both sides (units allowed): compare numerically.
    if _looks_like_plain_number_or_with_unit(ans_p_str) and _looks_like_plain_number_or_with_unit(ans_l_str):
        np1 = _extract_single_number(ans_p_str)
        nl1 = _extract_single_number(ans_l_str)
        try:
            if float(np1) == float(nl1):
                return True
        except (TypeError, ValueError):
            pass

    # Clock times. Missing AM/PM on one side is matched by hour modulo 12.
    # (Locals are deliberately not named `sp`: that would shadow the sympy module.)
    pr, lr = _parse_time_raw(ans_p_str), _parse_time_raw(ans_l_str)
    if pr and lr:
        hour_p, min_p, sec_p, mer_p = pr
        hour_l, min_l, sec_l, mer_l = lr
        if min_p == min_l and sec_p == sec_l:
            if mer_p and mer_l:
                # Both sides carry AM/PM: strict 24-hour comparison.
                return _norm_time(ans_p_str) == _norm_time(ans_l_str)
            if mer_p or mer_l:
                # Only one side carries AM/PM: lenient, hour modulo 12.
                return (hour_p % 12) == (hour_l % 12)
            # Neither side carries AM/PM: hours must be equal.
            return hour_p == hour_l

    # Unbracketed comma lists of equal length are compared as unordered
    # multisets. Lists of different lengths are rejected; a list vs. a single
    # value falls through to the general parser below.
    if ("," in ans_p_str and "," in ans_l_str
            and not _is_bracketed(ans_p_str)
            and not _is_bracketed(ans_l_str)):
        p_items = _split_top_commas(ans_p_str)
        l_items = _split_top_commas(ans_l_str)
        if len(p_items) != len(l_items):
            return False
        try:
            p_texts = [clean_expr_str(str(x)) for x in p_items]
            l_texts = [clean_expr_str(str(x)) for x in l_items]
            p_vals = [parse_latex_answer(t) for t in p_texts]
            l_vals = [parse_latex_answer(t) for t in l_texts]
            # Counter is unusable here (parse results may be unhashable or hash
            # inconsistently), so match pairwise while consuming matches.
            return _unordered_items_equal(list(zip(p_vals, p_texts)),
                                          list(zip(l_vals, l_texts)))
        except Exception:
            pass  # element-wise parsing failed: fall back to whole-expression parsing
    else:
        # A bracketed tuple/interval never equals an unbracketed list.
        if _is_bracketed(ans_p_str) and not _is_bracketed(ans_l_str) and "," in ans_p_str and "," in ans_l_str:
            return False
        if not _is_bracketed(ans_p_str) and _is_bracketed(ans_l_str) and "," in ans_p_str and "," in ans_l_str:
            return False

    ans_p = parse_latex_answer(ans_p_str)
    if ans_p is None:
        return False
    ans_l = parse_latex_answer(ans_l_str)
    if ans_l is None:
        return False

    return bool(is_expr_equal(ans_p, ans_l, is_strict=is_strict))

# Majority vote
def vote(answers):
    """Return the most frequent item in *answers*.

    Ties go to the item encountered first. Items must be hashable. An empty
    *answers* raises IndexError.
    """
    ranked = Counter(answers).most_common(1)
    best_answer, _votes = ranked[0]
    return best_answer

# Does the string contain any digit?
def contains_number(s):
    """Return True if at least one character of *s* satisfies ``str.isdigit``."""
    for ch in s:
        if ch.isdigit():
            return True
    return False

# Rough comparison of a free-form generation against an answer
def rough_compare_ans(generation, answer):
    """Loosely check whether *answer* appears as a standalone number in *generation*.

    Only the last line of *generation* that contains a digit is examined. Each
    whitespace-separated token containing a digit is compared with *answer*
    via ``compare_numerical_ans``. Tokens are skipped when they are adjacent to
    an arithmetic-operator token, or when an ``=`` occurs from that token to
    the end of the line.

    Returns:
        1 if a matching token is found, otherwise 0.
    """
    operators = ("+", "-", "*", "/", "^")
    last_numeric_line = next(
        (ln for ln in reversed(generation.split("\n")) if contains_number(ln)), ""
    )
    words = last_numeric_line.split()
    for i, w in enumerate(words):
        # Operands of an arithmetic expression are not a final answer.
        if i > 0 and words[i - 1] in operators:
            continue
        if i < len(words) - 1 and words[i + 1] in operators:
            continue
        if not contains_number(w):
            continue
        # Inspect the rest of the *line* (this token onward), not the characters
        # of the token itself: an "=" there means this number is not the result.
        if "=" in " ".join(words[i:]):
            continue
        if compare_numerical_ans(w.replace("$", ""), answer):
            return 1
    return 0

# Load (pred, true) test pairs, e.g. for the matrix tests
def load_cases_from_jsonl(jsonl_file):
    """Read ``(pred, true)`` pairs from a JSONL file.

    Blank lines are skipped. Every other line must be a JSON object with
    "pred" and "true" keys; a missing key raises KeyError.
    """
    with open(jsonl_file, "r", encoding="utf-8") as fh:
        # Lazy generator: each line is decoded just before its keys are read.
        records = (json.loads(raw) for raw in fh if raw.strip())
        return [(rec["pred"], rec["true"]) for rec in records]

def main():
    """Run the built-in regression cases for ``compare_ans`` and report the results.

    Each case is ``(pred, label, expected)``. Both strings are passed through
    ``strip_string`` before ``compare_ans``. Passing cases print "yes".
    Mismatches and errors print the cleaned inputs together with the expected
    and actual results. A pass count is printed at the end.
    """
    test_cases = [
        # (pred_answer, true_answer, expected compare_ans result)
        (r"(\sqrt{2},\frac{3\pi}{4},-6)", r"\left( \sqrt{2}, \frac{7 \pi}{4}, -6 \right)", False),
        (r"(\frac{32}{69},\frac{49}{138},\frac{25}{138})", r"\left( \frac{11}{15}, \frac{11}{60}, \frac{1}{12} \right)", False),
        (r"(2,12)", r"(4,24)", False),
        (r"(-2,-1)", r"(4,5)", False),
        (r"\frac{\pi}{2}", r"\frac \pi 4 + 2 - \sqrt2", False),
        (r"(-1,2)", r"(2,5)", False),
        (r"(\frac{9}{8},\frac{29}{8})", r"(\tfrac{9}{8}, \tfrac{15}{8})", False),
        (r"(9,3)", r"(9, 3)", True),
        (r"[-1,1]", r"\left[ -\frac{1}{2}, \frac{1}{2} \right]", False),
        (r"[\frac{3}{2},3)", r"\left[ \frac{3}{2}, 2 \right)", False),
        (r"(\frac{-1+\sqrt{21}}{2},5)", r"(-\infty,4)", False),
        (r"f(2),f(1),f(4)", r"f(2) < f(1) < f(4)", True),
        (r"(5,-1,5,-1)", r"(4,1,4,0)", False),
        (r"1.85", r"1.85\text{USD}", True),
        (r"1:03 PM", r"1:03", True),
        (r"[-80,82]", r"-80\leq g(x)\leq 82", False),
        (r"5:00 am", r"05:00", True),
        (r"4:51:06", r"4:51:06 p.m.", True),
        (r"\text{A}", r"A", True),
        (r"3-i,3+i", r"3-i,-2+i", False),
        (r"\frac{1}{}", r"\frac{1}{(a - b)b}", False),
        (r"2-\frac{3\pi}{2}", r"2", False),
        (r"0,1,2,3,4", r"0, 1, 2, 4", False),
        (r"1,-1", r"-1, 1", True),
        (r"2,\frac{1}{2},3,\frac{1}{3}", r"2, 3, \frac{1}{2}, \frac{1}{3}", True),
        (r"\frac{7}{25},-1", r"\frac{7}{25}", False),
        (r"1,2,3,7", r"1, 6, 3 + \sqrt{2}, 3 - \sqrt{2}", False),
        (r"1,1,-\frac{1}{2},-2", r"1, 1, -2, -\frac{1}{2}", True),
        (r"2i,-2i", r"0,\tfrac32, 2i, -2i", False),
        (r"1+i,-1+i,-1-i,1-i", r"1 + i, 1 - i, -1 + i, -1 - i", True),
        (r"-\sqrt{2}", r"-\sqrt{2}, -\sqrt{2} + i, -\sqrt{2} - i ", False),
        (r"-3,\frac{3}{2}", r"-3", False),
        (r"\frac{-2}{3}", r"-\frac{2}{3}", True),
        (r"-\frac{}{3}", r"-\frac{a}{3}", False),
        (r"\frac{81\pi}{32}", r"\frac{81}{32}\pi", True),
        (r"\frac{-10}{9} ", r"-\dfrac{10}{9}", True),
        (r"2x", r"2\sec x", False),
        (r"12650", r"12,\!650", True),
        (r"-\frac{\sqrt{2}}{2}", r"-\frac{1}{\sqrt{2}}", True),
        (r"-5,6", r"(-5, 6)", False),
        (r"\sin32", r"\cos58", True),
        (r"\sqrt{2}+1", r"1+\sqrt{2}", True),
        (r"\sqrt{3},-\sqrt{3},1,-1", r"-\sqrt{3},-1,1,\sqrt{3}", True),
        (r"24  ", r"24_{10}", True),
        (r"-1, 7, -\frac{3}{2}", r"-\frac{3}{2}, -1, 7", True),
        (r"(-1,0)\cup(0,1] ", r"[-1,0) \cup (0,1]", False),
        (r"8,-8,1,-1", r"8,1,-1,-8", True),
        (r"1,-1,11,-11", r"-11, -1, 1, 11", True),
        (r"1+i,-1+i,-1-i,1-i", r"1 + i, 1 - i, -1 + i, -1 - i", True),
        (r"(5,\pi,-8)", r"(5,\pi,-8) ", True),
        (r" 0,-1,3,4", r"0,3,-1,4", True),
        (r" -n(2n+1)", r"-2n^2 - n", True),
        (r"(2,3,4)", r"(3,2,4)", False),
        # Time handling is left as-is: if only one side is a time the time
        # parser may not apply, but the overall verdict is still the desired one.
        (r"5:00 am",r"05:00", True),
        (r"05\!:\!00", "05:00", True),
        (r"4\!:\!51\!:\!06 \text{ p.m.}", "4:51:06", True),
        (r"-(ab + ac + bc)", "-ab - ac - bc", True),
        (r"\mathbb{R}", r" (-\infty,\infty)", True),
        (r"f(X)=X+3", r" X + 3", True),
        (r"190-100\sqrt{3}", r"190-100 \cdot \sqrt{3}", True),
        (r"\sqrt{(\sqrt{36505} + 193}/2)", r"\\frac{\sqrt{2}}{2}", True),
        (r"\begin{pmatrix}2&0&7\\3&5&-1\\-8&-2&4\end{pmatrix}", r"\begin{pmatrix} 2 & 0 & 7 \\ 3 & 5 & -1\\-8 & -2 & 4 \end{pmatrix}", True),
        (r"\begin{pmatrix}2\-1\-5\end{pmatrix}", r"\begin{pmatrix} 2 \\ -1 \\ -5 \end{pmatrix}", True),
        (r"\begin{pmatrix}0&1&0\1&0&0\0&0&2\end{pmatrix}", r"\begin{pmatrix} 0 & 1 & 0 \\ 1 & 0 & 0 \\ 0 & 0 & 2 \end{pmatrix}", True),
        (r" \begin{pmatrix}\frac{1}{3}\-\frac{1}{6}\\frac{1}{6}\end{pmatrix}", r"\begin{pmatrix}1/3\\-1/6\\1/6\end{pmatrix}", True),
        (r"\begin{pmatrix}-\frac{1}{2}\\frac{5}{2}\1\end{pmatrix}", r"\begin{pmatrix}-1/2\\5/2\\1\end{pmatrix}", True),
        (r"\begin{pmatrix}\frac{5}{2}\0\\frac{5}{2}\end{pmatrix}", r"\begin{pmatrix}5/2\\0\\5/2\end{pmatrix}", True),
        (r"\begin{pmatrix}-2-\frac{11\sqrt{10}}{5}\3-\frac{33\sqrt{10}}{5}\end{pmatrix}", r"\begin{pmatrix}1/5\\-18/5\end{pmatrix}", False),
        (r"\frac{1}{5}\begin{pmatrix}3&1\\3&1\end{pmatrix}", r"\begin{pmatrix} \frac{3}{5} & \frac{1}{5} \\ \frac{3}{5} & \frac{1}{5}\end{pmatrix}", True),
        (r"\frac{1}{13}\begin{pmatrix}4&-6\\-6&9\end{pmatrix}", r"\begin{pmatrix} 4/13 & -6/13 \\ -6/13 & 9/13 \end{pmatrix}", True),
        (r"\frac{1}{50}\begin{pmatrix}1&7\\7&49\end{pmatrix}", r"\begin{pmatrix} 1/50 & 7/50 \\ 7/50 & 49/50 \end{pmatrix}", True),
        (r" \begin{pmatrix}\frac{3}{2}\\-\frac{1}{2}\\2\end{pmatrix}", r" \begin{pmatrix}\frac{3}{2}\\-\frac{1}{2}\\2\end{pmatrix}", True),
        (r"(t+7)(t-7)", r"(t-7)(t+7)", True),
    ]

    failures = 0
    for idx, (pred, label, expected) in enumerate(test_cases, start=1):
        pred_clean = label_clean = None
        try:
            pred_clean = strip_string(pred)
            label_clean = strip_string(label)
            result = compare_ans(pred_clean, label_clean)
        except Exception as exc:  # e.g. compare_ans's 10 s timeout; keep running the suite
            result = f"error: {exc!r}"
        if result == expected:
            print("yes")
        else:
            failures += 1
            print(f"no  [#{idx}] pred={pred_clean!r} label={label_clean!r} "
                  f"expected={expected} got={result}")
    print(f"{len(test_cases) - failures}/{len(test_cases)} cases passed")


#################################################################
# 数据集答案解析测试
def scan_jsonl_for_parse_fail(
    jsonl_path: str,
    out_failed_path: str = "parse_failed.jsonl",
    id_keys: list[str] = ("id", "uid", "problem_id", "question_id", "qid"),
) -> dict:
    """
    扫描 JSONL 数据集，解析每条的 answer；
    收集所有“解析失败”的样本到 out_failed_path（JSONL）。
    返回统计信息。
    """
    p = Path(jsonl_path)
    if not p.exists():
        raise FileNotFoundError(f"JSONL not found: {p}")

    total = 0
    ok = 0
    failed = 0
    missing = 0

    fout = open(out_failed_path, "w", encoding="utf-8")

    def pick_id(obj, default):
        for k in id_keys:
            if k in obj:
                return obj[k]
        return default

    with p.open("r", encoding="utf-8") as f:
        for idx, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            total += 1
            try:
                obj = json.loads(line)
            except Exception:
                # 非法 JSON 行也算失败，记录下来
                failed += 1
                fout.write(json.dumps(
                    {"idx": idx, "id": None, "answer": None, "reason": "invalid_json", "raw": line},
                    ensure_ascii=False
                ) + "\n")
                continue

            if "answer" not in obj:
                missing += 1
                continue

            ans_raw = str(obj["answer"])
            # 先做百分号换算，再做你已有的清洗
            ans_txt = strip_string(clean_expr_str(ans_raw))

            parsed = None
            try:
                parsed = parse_latex_answer(ans_txt)  # ← 用你当前的解析函数名
            except Exception as e:
                parsed = None

            if parsed is None:
                failed += 1
                fout.write(json.dumps(
                    {
                        "idx": idx,
                        "id": pick_id(obj, idx),
                        "answer": ans_raw,
                        "answer_clean": ans_txt,
                        "reason": "parse_failed"
                    },
                    ensure_ascii=False
                ) + "\n")
            else:
                ok += 1

    fout.close()
    stats = {"total": total, "ok": ok, "failed": failed, "missing_answer_key": missing,
             "out_file": str(Path(out_failed_path).resolve())}
    print(f"[scan] total={total} ok={ok} failed={failed} missing={missing}")
    print(f"[scan] failed saved to: {stats['out_file']}")
    return stats

def scan_many_jsonl(jsonl_paths, out_dir="parse_failed",
                    id_keys=("id","uid","problem_id","question_id","qid")) -> dict:
    """Run ``scan_jsonl_for_parse_fail`` over several files and aggregate the counts.

    Missing input files are reported and skipped. Each scanned file ``X.jsonl``
    writes its failures to ``<out_dir>/X.parse_failed.jsonl``; *out_dir* is
    created if needed.

    Returns:
        dict with per-file stats under "files" (each with an added "in_file")
        plus summed total / ok / failed / missing_answer_key.
    """
    out_root = Path(out_dir)
    out_root.mkdir(parents=True, exist_ok=True)

    totals = {"files": [], "total": 0, "ok": 0, "failed": 0, "missing_answer_key": 0}
    for raw_path in jsonl_paths:
        src = Path(raw_path).expanduser()
        if not src.exists():
            print(f"[scan] skip (not found): {src}")
            continue
        target = out_root / f"{src.stem}.parse_failed.jsonl"
        per_file = scan_jsonl_for_parse_fail(str(src), str(target), list(id_keys))
        per_file["in_file"] = str(src.resolve())
        totals["files"].append(per_file)
        for key in ("total", "ok", "failed", "missing_answer_key"):
            totals[key] += per_file[key]

    print("\n[scan] ===== Summary over all files =====")
    print(f"[scan] total={totals['total']} ok={totals['ok']} "
          f"failed={totals['failed']} missing={totals['missing_answer_key']}")
    print(f"[scan] outputs in: {out_root.resolve()}")
    return totals
#######################################################################

#################### 多选题筛选 ###################
def _as_single_letter(s: str) -> str | None:
    """Return the upper-cased letter if *s* cleans down to exactly one ASCII letter, else None.

    Cleaning steps: ``clean_expr_str``, then surrounding whitespace and ``$``
    are removed, then ``_strip_punct``. Only A-Z / a-z are accepted. Other
    alphabetic characters (e.g. "π" or CJK characters) satisfy ``str.isalpha``
    but are not multiple-choice option letters, so they are rejected.
    """
    if not isinstance(s, str):
        return None
    t = clean_expr_str(s)
    t = (t or "").strip().strip("$").strip()
    t = _strip_punct(t)
    if len(t) == 1 and t.isascii() and t.isalpha():
        return t.upper()
    return None


def filter_mcq_in_jsonl(src_path: str, out_path: str) -> dict:
    """Copy the samples whose answer is a single letter from *src_path* to *out_path*.

    Blank lines are ignored. Invalid JSON lines and records without "answer"
    are counted in "total" but not copied. Each copied record gets an extra
    "mcq_choice" field holding the upper-cased letter. The parent directory of
    *out_path* is created if needed.

    Returns:
        dict with keys in_file, out_file, total, with_answer, mcq_count and
        letter_hist (letter -> count).

    Raises:
        FileNotFoundError: if *src_path* does not exist.
    """
    src = Path(src_path)
    if not src.exists():
        raise FileNotFoundError(f"JSONL not found: {src}")

    out = Path(out_path)
    out.parent.mkdir(parents=True, exist_ok=True)

    n_total = n_answered = n_mcq = 0
    letters = Counter()

    with src.open("r", encoding="utf-8") as fin, out.open("w", encoding="utf-8") as fout:
        for raw in fin:
            raw = raw.strip()
            if not raw:
                continue
            n_total += 1
            try:
                record = json.loads(raw)
            except Exception:
                continue
            if "answer" not in record:
                continue
            n_answered += 1
            choice = _as_single_letter(str(record["answer"]))
            if not choice:
                continue
            n_mcq += 1
            letters[choice] += 1
            # Keep the original sample and tag it with the detected choice.
            fout.write(json.dumps({**record, "mcq_choice": choice}, ensure_ascii=False) + "\n")

    return {
        "in_file": str(src),
        "out_file": str(out),
        "total": n_total,
        "with_answer": n_answered,
        "mcq_count": n_mcq,
        "letter_hist": dict(letters),
    }


def dump_mcq_many_jsonl(jsonl_paths,
                        out_dir: str = "mcq_outputs",
                        merged_out: str | None = None) -> dict:
    """
    批量处理多个 JSONL：
      - 为每个文件导出一个 *_mcq.jsonl 到 out_dir
      - 如果 merged_out 不为空，则把所有 MCQ 追加写入该合并文件
    返回 summary 统计信息。
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    results = []
    grand_total = 0
    grand_with_answer = 0
    grand_mcq = 0
    grand_hist = Counter()

    merged_writer = None
    merged_path = None
    if merged_out:
        merged_path = Path(merged_out)
        merged_path.parent.mkdir(parents=True, exist_ok=True)
        merged_writer = merged_path.open("w", encoding="utf-8")

    for p in jsonl_paths:
        p = Path(p).expanduser()
        if not p.exists():
            print(f"[mcq] skip (not found): {p}")
            continue

        out_file = out_dir / (p.stem + "_mcq.jsonl")
        # 先筛出并保存单文件
        stats = filter_mcq_in_jsonl(str(p), str(out_file))
        results.append(stats)

        # 如需合并，再把单文件的 MCQ 读入写到合并文件
        if merged_writer:
            with out_file.open("r", encoding="utf-8") as f:
                for line in f:
                    merged_writer.write(line)

        grand_total += stats["total"]
        grand_with_answer += stats["with_answer"]
        grand_mcq += stats["mcq_count"]
        grand_hist.update(stats["letter_hist"])

        print(f"[mcq] {stats['in_file']} -> {stats['out_file']}"
              f" | total={stats['total']} with_answer={stats['with_answer']}"
              f" mcq={stats['mcq_count']} hist={stats['letter_hist']}")

    if merged_writer:
        merged_writer.close()

    summary = {
        "files": results,
        "total": grand_total,
        "with_answer": grand_with_answer,
        "mcq_total": grand_mcq,
        "letter_hist": dict(grand_hist),
        "merged_out": str(merged_path) if merged_path else None,
        "out_dir": str(out_dir),
    }

    print("\n[mcq] ===== summary =====")
    print(f"[mcq] total={summary['total']} with_answer={summary['with_answer']} "
          f"mcq_total={summary['mcq_total']}")
    print(f"[mcq] histogram={summary['letter_hist']}")
    if merged_path:
        print(f"[mcq] merged file: {merged_path}")
    print(f"[mcq] per-file outputs in: {out_dir}")
    return summary

if __name__ == "__main__":
    import argparse

    def _expand_file_args(items):
        # Accept "a.jsonl b.jsonl", "a.jsonl,b.jsonl" and mixtures of the two.
        return [x for item in items for x in re.split(r"[,\s]+", item) if x]

    parser = argparse.ArgumentParser()
    # Dataset parse-failure scan options.
    parser.add_argument("--scan-jsonl", nargs="*", help="要扫描的 JSONL 文件（可多个或逗号分隔）")
    parser.add_argument("--out-dir", type=str, default="parse_failed",
                        help="解析失败样本输出目录（默认 parse_failed/）")
    parser.add_argument("--id-keys", type=str,
                        default="id,uid,problem_id,question_id,qid",
                        help="样本ID字段优先级，逗号分隔")
    # Multiple-choice (single-letter answer) extraction options.
    parser.add_argument("--dump-mcq", nargs="*", help="筛出 MCQ（答案为单个字母）的样本并保存。参数支持多个或逗号分隔")
    parser.add_argument("--mcq-outdir", type=str, default="mcq_outputs", help="每个文件的 MCQ 输出目录（默认 mcq_outputs）")
    parser.add_argument(
        "--mcq-merged",
        type=str,
        default=None,
        help="（可选）合并输出文件路径，比如 merged_mcq.jsonl",
    )
    args = parser.parse_args()

    # Precedence: --scan-jsonl, then --dump-mcq, otherwise the built-in regression cases.
    if args.scan_jsonl:
        scan_many_jsonl(
            _expand_file_args(args.scan_jsonl),
            out_dir=args.out_dir,
            id_keys=[k.strip() for k in args.id_keys.split(",") if k.strip()],
        )
    elif args.dump_mcq:
        dump_mcq_many_jsonl(
            _expand_file_args(args.dump_mcq),
            out_dir=args.mcq_outdir,
            merged_out=args.mcq_merged,
        )
        sys.exit(0)
    else:
        main()
