import argparse, os, re, json, math, glob
from typing import List, Tuple, Optional, Dict
import fractions as _fra

# ======================
# Balanced-brace helpers (same as code1)
# ======================
def _balanced_from(s: str, open_brace_idx: int) -> Tuple[Optional[str], Optional[int]]:
    """Extract content from balanced braces starting at open_brace_idx."""
    if open_brace_idx < 0 or open_brace_idx >= len(s) or s[open_brace_idx] != "{":
        return None, None
    depth, i, start = 1, open_brace_idx + 1, open_brace_idx + 1
    while i < len(s):
        ch = s[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return s[start:i], i + 1
        i += 1
    return None, None

def extract_boxed_all(s: str) -> List[str]:
    """Extract the contents of every well-formed \\boxed{...} in ``s``.

    Matches are reported in order of their starting position; a box nested
    inside another box is reported as its own entry after the enclosing one.
    Each content string is whitespace-stripped.

    A \\boxed{ whose brace is never closed is skipped rather than ending the
    scan: a later box (which necessarily lies inside the unclosed span) can
    still be well-formed and must not be lost.
    """
    out: List[str] = []
    for m in re.finditer(r"\\boxed\s*\{", s):
        # m.end() - 1 is the index of the "{" that the pattern just consumed.
        content, _ = _balanced_from(s, m.end() - 1)
        if content is None:
            continue
        out.append(content.strip())
    return out

def extract_last_boxed(s: str) -> Optional[str]:
    """Return the content of the final \\boxed{...} found in ``s``, or None if there is none."""
    boxed = extract_boxed_all(s)
    if not boxed:
        return None
    return boxed[-1]

def _strip_boxed_balanced(s: str) -> str:
    """Remove all \\boxed{...} wrappers while preserving inner content.

    Nested wrappers are unwrapped too (``\\boxed{\\boxed{1}}`` -> ``1``).
    If a \\boxed{ is never closed, it and everything after it are left
    untouched.
    """
    parts: List[str] = []
    last = 0  # index in ``s`` up to which output has already been produced
    for m in re.finditer(r"\\boxed\s*\{", s):
        if m.start() < last:
            # This match sits inside a box that was already unwrapped (its
            # content is handled by the recursive call below); processing it
            # again would duplicate text and leave a stray "}".
            continue
        content, nxt = _balanced_from(s, m.end() - 1)
        if content is None:
            # Unbalanced: stop here; the tail append below keeps the rest verbatim.
            break
        parts.append(s[last:m.start()])
        parts.append(_strip_boxed_balanced(content))
        last = nxt
    parts.append(s[last:])
    return "".join(parts)

# ======================
# Text/number normalization (same as code1)
# ======================
# Regex patterns (not literal strings) for LaTeX spacing commands; each one is
# deleted from the text by _strip_latex_wrappers via re.sub.
_SPACING_CMDS = [r"\\,", r"\\!", r"\\;", r"\\:", r"\\quad", r"\\qquad"]

def _ascii_minus(s: str) -> str:
    """Convert Unicode minus signs to ASCII hyphen-minus."""
    s = "" if s is None else str(s)
    return s.translate({ord(u): "-" for u in "\u2212\u2012\u2013\u2014"})

def _normalize_fracs(s: str) -> str:
    """Normalize \\frac, \\dfrac, \\tfrac to (numerator)/(denominator) format."""
    s = re.sub(r"\\(?:d|t)?frac\s*\{\s*([^{}]+?)\s*\}\s*\{\s*([^{}]+?)\s*\}", r"(\1)/(\2)", s)
    s = re.sub(r"\\(?:d|t)?frac\s+([^\s{}]+)\s+([^\s{}]+)", r"(\1)/(\2)", s)
    s = re.sub(r"\\(?:d|t)?frac\s*([0-9]+)\s*([0-9]+)", r"(\1)/(\2)", s)
    return s

# Matches any \begin{...} environment opener, whatever the environment name.
_ENV_ANY_OPEN  = re.compile(r"\\begin\{[^}]+\}")
# Matches any \end{...} environment closer.
_ENV_ANY_CLOSE = re.compile(r"\\end\{[^}]+\}")
# One or more consecutive separators, each being a run of two or more
# backslashes, \cr, ";" or "&". _strip_latex_wrappers replaces each run with ",".
_ROW_SEP = re.compile(r"(?:\\\\+|\\cr|;|&)+")  # LaTeX row/col separators

def _strip_latex_wrappers(s: str) -> str:
    """Strip LaTeX commands and environments to get raw content.

    Steps, in order: unwrap \\boxed{...}; delete spacing commands; delete
    \\left / \\right sizing prefixes; rewrite fractions as (a)/(b); rewrite
    degree superscripts as " deg"; delete \\begin{...} / \\end{...}; turn
    row/column separators into commas; delete math delimiters and dollar
    signs; unwrap \\text{}, \\mathrm{}, \\operatorname{}; replace \\infty and
    \\pi by plain words; collapse whitespace.
    """
    s = _strip_boxed_balanced(s)

    # remove common spacing commands
    for pat in _SPACING_CMDS:
        s = re.sub(pat, "", s)

    # remove \left \right; the lookahead keeps longer command names that merely
    # start with these letters (\leftarrow, \rightarrow, ...) intact
    s = re.sub(r"\\(?:left|right)(?![A-Za-z])", "", s)

    # normalize \frac*, \dfrac*, \tfrac* variants
    s = _normalize_fracs(s)

    # degree -> " deg"; accepts both ^\circ and the braced form ^{\circ}
    s = re.sub(r"\^\s*(?:\{\s*\\?(?:circ|degree)s?\s*\}|\\?(?:circ|degree)s?)", " deg", s)

    # remove ANY \begin{...} / \end{...}
    s = _ENV_ANY_OPEN.sub("", s)
    s = _ENV_ANY_CLOSE.sub("", s)

    # convert LaTeX row/col separators to commas, then squash
    s = _ROW_SEP.sub(",", s)
    s = re.sub(r",\s*,+", ",", s)  # collapse repeated commas from \\\\

    # strip \( \) \[ \] and dollar markers; the escaped form \$ must go first,
    # otherwise removing the bare "$" leaves its backslash behind
    s = re.sub(r"\\[()\[\]]", "", s)
    s = s.replace(r"\$", "").replace("$", "")

    # unwrap \text{}, \mathrm{}, \operatorname{}
    s = re.sub(r"\\(?:text|mathrm|operatorname)\{([^}]*)\}", r"\1", s)

    # common macros
    s = s.replace(r"\infty", "infty").replace(r"\pi", "pi")

    # collapse whitespace
    s = re.sub(r"\s+", " ", s).strip()
    return s

def _after_hashes_line(s: str) -> Optional[str]:
    """Extract text after #### marker (common ground truth format)."""
    i = s.rfind("####")
    return None if i == -1 else s[i+4:].strip()

def extract_pred_answer(t: str) -> str:
    """Pull the answer out of model output ``t``.

    Tried in order: the last \\boxed{...} content (unless it is empty or the
    placeholder "..."), the first line of the form "final answer: X"
    (case-insensitive), the last non-blank line, and finally ``t`` stripped.
    """
    boxed = extract_last_boxed(t)
    if boxed:
        boxed = boxed.strip()
        if boxed != "...":
            return boxed

    final_line = re.search(r"(?im)^\s*final\s*answer\s*:\s*(.+?)\s*$", t)
    if final_line is not None:
        return final_line.group(1).strip()

    for line in reversed(t.splitlines()):
        stripped = line.strip()
        if stripped:
            return stripped
    return t.strip()

def extract_gold_answer(t: str) -> str:
    """Pull the reference answer out of ground-truth text ``t``.

    Text after the last "####" marker wins; without a marker the same
    extraction as for model output is applied.
    """
    after_marker = _after_hashes_line(t)
    if after_marker is None:
        return extract_pred_answer(t)
    return after_marker

# ======================
# Exact numeric parsing (scalars & vectors) — same as code1
# ======================
# Unsigned decimal literal: digits with an optional fractional part, or a
# leading-dot decimal such as ".5". No exponent form.
_NUM_TOKEN = r"(?:\d+(?:\.\d+)?|\.\d+)"
# Whole string is one optionally signed number, optionally followed by "/" and
# an UNSIGNED number (surrounding whitespace allowed).
_NUM_ATOM  = re.compile(rf"""^\s*[-+]?{_NUM_TOKEN}(?:\s*/\s*{_NUM_TOKEN})?\s*$""")
# Whole string is "num / den" where both parts may carry a sign; the parts are
# exposed as the named groups "num" and "den".
_FRAC_TOKEN = re.compile(rf"""^\s*(?P<num>[-+]?{_NUM_TOKEN})\s*/\s*(?P<den>[-+]?{_NUM_TOKEN})\s*$""")
# Unanchored variant of _NUM_ATOM, used with .search() to find the first
# number (or a/b pair) embedded in surrounding text.
_NUM_FALLBACK = re.compile(rf"[-+]?{_NUM_TOKEN}(?:\s*/\s*{_NUM_TOKEN})?")

def _strip_enclosing_parens(s: str) -> str:
    """Drop one pair of outer parentheses, but only if they enclose the whole string.

    "(5)" -> "5" and "((1)/(2))" -> "(1)/(2)", whereas "(1)/(2)" is returned
    unchanged because its first "(" closes before the end of the string.
    """
    if len(s) < 2 or s[0] != "(" or s[-1] != ")":
        return s
    depth = 0
    for i, ch in enumerate(s):
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                # This is where the first "(" closes; strip only when that is
                # the final character.
                return s[1:-1] if i == len(s) - 1 else s
    return s  # unbalanced: leave untouched


def _scalar_text(s: str) -> str:
    """Normalize a raw answer for scalar parsing.

    Strips LaTeX wrappers, converts Unicode minus signs, removes commas
    (treated as thousands separators), collapses the "(a)/(b)" form that
    _normalize_fracs emits for \\frac{a}{b} into "a/b" when a and b are plain
    numbers, and finally removes one pair of enclosing parentheses.
    """
    s = _ascii_minus(_strip_latex_wrappers(s)).strip()
    s = s.replace(",", "")  # thousands sep (scalar only)
    signed = rf"[-+]?{_NUM_TOKEN}"
    # The lookbehind/lookahead refuse to collapse when a digit, "." or another
    # parenthesis is adjacent, so a mixed number such as "3(1)/(2)" is not
    # glued into "31/2".
    s = re.sub(
        rf"(?<![\d.)])\(\s*({signed})\s*\)\s*/\s*\(\s*({signed})\s*\)(?![\d.(])",
        r"\1/\2",
        s,
    )
    return _strip_enclosing_parens(s)


def _ratio(num: str, den: str) -> Optional[_fra.Fraction]:
    """Return ``num / den`` as an exact Fraction, or None if unparsable or ``den`` is zero."""
    try:
        return _fra.Fraction(num) / _fra.Fraction(den)
    except (ValueError, ZeroDivisionError):
        return None


def _parse_rational_exact(s: str) -> Optional[_fra.Fraction]:
    """Parse a fraction string to exact Fraction.

    Succeeds only when the whole normalized text is "num/den" (either part
    may be signed or decimal); returns None otherwise or when den is zero.
    """
    m = _FRAC_TOKEN.match(_scalar_text(s))
    if not m:
        return None
    return _ratio(m.group("num"), m.group("den"))


def _parse_number_exact(s: str) -> Optional[_fra.Fraction]:
    """Parse a number string to exact Fraction.

    Accepts a whole-string number or "num/den"; failing that, falls back to
    the first number (or a/b pair) found anywhere in the text. Returns None
    when no number is found or the denominator is zero.
    """
    s = _scalar_text(s)

    # A whole-string fraction is handled first so that signed denominators
    # ("1/-2"), which _NUM_ATOM rejects, are not truncated by the fallback.
    whole = _FRAC_TOKEN.match(s)
    if whole:
        return _ratio(whole.group("num"), whole.group("den"))

    if not _NUM_ATOM.match(s):
        found = _NUM_FALLBACK.search(s)
        if not found:
            return None
        s = found.group(0)
        if not _NUM_ATOM.match(s):
            return None

    if "/" in s:
        num, den = (part.strip() for part in s.split("/", 1))
        return _ratio(num, den)
    try:
        return _fra.Fraction(s)
    except ValueError:
        return None

# Literal substrings (a double backslash, \cr, \begin, \end, "&", ";") whose
# presence in a raw answer makes _looks_like_sequence report a sequence/matrix.
_SEQ_ENV_HINTS = ("\\\\", "\\cr", "\\begin", "\\end", "&", ";")

def _looks_like_sequence(s: str) -> bool:
    """Tell whether ``s`` should be treated as a vector/sequence/matrix.

    True when the raw text contains any of the _SEQ_ENV_HINTS substrings, or
    when the normalized text is a bracketed list with at least one comma,
    e.g. "(a, b, c)" or "[a, b]".
    """
    for hint in _SEQ_ENV_HINTS:
        if hint in s:
            return True

    normalized = _ascii_minus(_strip_latex_wrappers(s)).strip()
    bracketed_list = re.match(r"^\s*[\(\[\{]\s*[^,]+,\s*.+[\)\]\}]\s*$", normalized)
    return bracketed_list is not None

def _parse_vector_exact(s: str) -> Optional[List[_fra.Fraction]]:
    """Parse ``s`` as a list of exact Fractions.

    After normalization one outer bracket pair is removed; the remainder is
    split on commas if it contains any, otherwise on whitespace. Returns None
    when there are no components or any component fails to parse as a number.
    """
    text = _ascii_minus(_strip_latex_wrappers(s)).strip()
    # strip one outer wrapper like ( ... ) or [ ... ] or { ... }
    text = re.sub(r"^[\(\[\{]\s*(.*)\s*[\)\]\}]$", r"\1", text)

    # commas here include those produced from LaTeX row/col separators
    pieces = text.split(",") if "," in text else re.split(r"\s+", text.strip())
    tokens = [piece.strip() for piece in pieces if piece.strip()]
    if not tokens:
        return None

    values = [_parse_number_exact(token) for token in tokens]
    if any(value is None for value in values):
        return None
    return values

def _vectors_equal_exact(a: str, b: str) -> bool:
    """Tell whether ``a`` and ``b`` parse to the same vector of exact numbers.

    Returns False outright unless at least one side looks like a sequence, so
    plain scalars are left to the scalar comparison paths.
    """
    if not _looks_like_sequence(a) and not _looks_like_sequence(b):
        return False

    left = _parse_vector_exact(a)
    right = _parse_vector_exact(b)
    if left is None or right is None:
        return False
    # List equality checks both length and element-wise equality.
    return left == right

# ======================
# Choice & canon helpers (same as code1)
# ======================
# Whole string is a single letter A-E (case-insensitive), optionally wrapped in
# parentheses and surrounded by whitespace; the letter is group 1.
_CHOICE_RE = re.compile(r"(?i)^\s*\(?\s*([A-E])\s*\)?\s*$")

def _choice_letter(s: str) -> Optional[str]:
    """Return the upper-cased choice letter (A-E) if ``s`` is nothing but that letter, else None."""
    match = _CHOICE_RE.match(_strip_latex_wrappers(s))
    if match is None:
        return None
    return match.group(1).upper()

def canon(s: str) -> str:
    """Reduce ``s`` to a canonical lower-case form for plain-text comparison.

    LaTeX wrappers are stripped and Unicode minus signs converted, then a
    leading "ans:" / "answer=" style prefix and one pair of outer parentheses
    are removed and runs of whitespace are collapsed to single spaces.
    """
    text = _ascii_minus(_strip_latex_wrappers(s)).strip()
    rewrites = (
        (r"(?i)^\s*(?:ans(?:wer)?)\s*[:=]\s*", ""),
        (r"^\((.*)\)$", r"\1"),
        (r"\s+", " "),
    )
    for pattern, replacement in rewrites:
        text = re.sub(pattern, replacement, text)
    return text.lower()

# ======================
# Compare (same as code1)
# ======================
def compare_answers(pred_text: str, gold_text: str, atol: float = 1e-6) -> bool:
    """Decide whether the predicted answer matches the gold answer.

    The answers are first extracted from the two texts, then these checks run
    in order; the first one that applies decides the result:

    1. Exact vector/matrix equality (a match returns True; a mismatch falls
       through to the later checks).
    2. Both sides are whole-string fractions: exact equality.
    3. Both sides yield a number: exact equality, else absolute difference
       within ``atol``.
    4. Both sides are a multiple-choice letter: letter equality.
    5. Canonical text equality.
    """
    pred_raw = extract_pred_answer(pred_text)
    gold_raw = extract_gold_answer(gold_text)

    # 1) vectors / matrices
    if _vectors_equal_exact(pred_raw, gold_raw):
        return True

    # 2) exact rationals
    pred_frac = _parse_rational_exact(pred_raw)
    gold_frac = _parse_rational_exact(gold_raw)
    if pred_frac is not None and gold_frac is not None:
        return pred_frac == gold_frac

    # 3) numbers: exact first, tolerant float comparison second
    pred_num = _parse_number_exact(pred_raw)
    gold_num = _parse_number_exact(gold_raw)
    if pred_num is not None and gold_num is not None:
        return pred_num == gold_num or math.isclose(
            float(pred_num), float(gold_num), rel_tol=0.0, abs_tol=atol
        )

    # 4) multiple-choice letters
    pred_choice = _choice_letter(pred_raw)
    gold_choice = _choice_letter(gold_raw)
    if pred_choice and gold_choice:
        return pred_choice == gold_choice

    # 5) plain text
    return canon(pred_raw) == canon(gold_raw)

# ======================
# Merge & Rejudge
# ======================
# File names of the form "<base>.rank<digits>.json"; "base" is matched lazily
# and "r" captures the digits, which find_groups converts to int for sorting.
_RANK_FILE_RE = re.compile(r"""^(?P<base>.+?)\.rank(?P<r>\d+)\.json$""")

def find_groups(dir_path: str) -> Dict[str, List[str]]:
    """Group rank files in ``dir_path`` by base name.

    Only regular files named ``<base>.rank<N>.json`` are considered; files
    whose name starts with ``grid_all_modes`` are ignored.

    Returns:
        A dict mapping each base (e.g. 'a0.1__linear__L27.soft_argmax') to
        its file paths, sorted by rank number and then by file name so the
        order is deterministic.
    """
    # Escape the directory so that glob metacharacters in its name ("[", "*",
    # "?") are matched literally instead of being treated as a pattern.
    pattern = os.path.join(glob.escape(dir_path), "*.json")

    ranked: Dict[str, List[Tuple[int, str, str]]] = {}
    for fp in glob.glob(pattern):
        name = os.path.basename(fp)
        if name.startswith("grid_all_modes") or not os.path.isfile(fp):
            continue
        m = _RANK_FILE_RE.match(name)
        if m is None:  # skip non-rank files
            continue
        # Keep the rank with the path so sorting needs no second regex match.
        ranked.setdefault(m.group("base"), []).append((int(m.group("r")), name, fp))

    return {base: [fp for _, _, fp in sorted(entries)] for base, entries in ranked.items()}

def _load_rank_rows(fp: str) -> List:
    """Return the ``rows`` list stored in rank file ``fp``, or ``[]`` if it cannot be used."""
    try:
        with open(fp, "r", encoding="utf-8") as f:
            obj = json.load(f)
    except Exception as e:
        # Best effort: one unreadable rank file must not abort the whole group.
        print(f"[ERROR] Failed to load {fp}: {e}")
        return []
    if not isinstance(obj, dict):
        print(f"[ERROR] Unexpected top-level JSON type in {fp}: {type(obj).__name__}")
        return []
    rows = obj.get("rows", [])
    return rows if isinstance(rows, list) else []


def rejudge_group(base: str, files: List[str], atol: float, update_pred_boxed: bool) -> Optional[str]:
    """Merge rows from all rank files, rejudge, and write <base>.fixeval.json

    Each dict row is re-scored with compare_answers on its "pred" and "gold"
    fields and its "ok" field is overwritten. Rows that are not dicts are
    written back unchanged and are not counted.

    Args:
        base: Group name; its basename becomes the output file stem.
        files: Rank files of the group; the output is written next to the
            first one.
        atol: Absolute tolerance forwarded to compare_answers.
        update_pred_boxed: When true, also set row["predicted_boxed"] to a
            list holding the last boxed answer of "pred" (empty if none).

    Returns:
        Path of the written file, or None when ``files`` is empty or the
        file could not be written.
    """
    if not files:
        return None

    meta = {
        "mode_base": base,
        "n_files": len(files),
        "files": [os.path.basename(x) for x in files],
    }

    merged_rows = []
    for fp in files:
        merged_rows.extend(_load_rank_rows(fp))

    total = 0
    changed = 0  # rows whose previous "ok" verdict was flipped by the rejudge
    correct = 0

    for row in merged_rows:
        if not isinstance(row, dict):
            continue
        # A JSON null or non-string value would crash the regex-based
        # extraction, so coerce both fields to text first.
        gold = row.get("gold")
        gold = "" if gold is None else str(gold)
        pred_text = row.get("pred")
        pred_text = "" if pred_text is None else str(pred_text)

        ok = compare_answers(pred_text, gold, atol=atol)
        old_ok = row.get("ok", None)
        if old_ok is not None and bool(old_ok) != bool(ok):
            changed += 1
        row["ok"] = bool(ok)
        if update_pred_boxed:
            last_box = extract_last_boxed(pred_text)
            row["predicted_boxed"] = [last_box] if last_box is not None else []
        total += 1
        correct += int(ok)

    accuracy = correct / max(1, total)  # max() avoids dividing by zero for an empty group
    out_obj = {
        **meta,
        "rows": merged_rows,
        "total": total,
        "correct": correct,
        "changed": changed,
        "accuracy": accuracy,
        "atol": atol,
        "update_pred_boxed": bool(update_pred_boxed),
    }

    out_dir = os.path.dirname(files[0])
    out_full = os.path.join(out_dir, os.path.basename(base) + ".fixeval.json")
    try:
        with open(out_full, "w", encoding="utf-8") as f:
            json.dump(out_obj, f, ensure_ascii=False, indent=2)
        print(f"✓ Saved fixeval: {out_full}  (rows={total}, correct={correct}, acc={accuracy:.4f})")
        return out_full
    except Exception as e:
        print(f"[ERROR] Failed to save {out_full}: {e}")
        return None

def main():
    """Command-line entry point: find rank-file groups under --dir and rejudge each one.

    Prints a banner, the list of groups found, and a per-group section; an
    invalid directory or an empty directory is reported and ends the run.
    """
    parser = argparse.ArgumentParser(
        description="Merge 8 rank JSONs per mode, then rejudge with the same logic as code1 (.pt fixeval).",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--dir",
        required=True,
        help="Directory containing steer rank JSONs (e.g., .../steer_grid_runs)",
    )
    parser.add_argument(
        "--atol",
        type=float,
        default=1e-6,
        help="Absolute tolerance for numeric comparison",
    )
    parser.add_argument(
        "--update_pred_boxed",
        action="store_true",
        help="Also refresh row['predicted_boxed'] from the generated text",
    )
    options = parser.parse_args()

    scan_dir = os.path.abspath(options.dir)
    if not os.path.isdir(scan_dir):
        print(f"[ERROR] Not a directory: {scan_dir}")
        return

    heavy_rule = "=" * 60
    light_rule = "-" * 60

    banner = (
        "\n" + heavy_rule,
        "FIXEVAL (JSON MERGED): Math Evaluation Repair Tool",
        heavy_rule,
        f"Scan dir: {scan_dir}",
        f"ATOL:     {options.atol}",
        f"Update predicted_boxed: {options.update_pred_boxed}",
        heavy_rule + "\n",
    )
    for line in banner:
        print(line)

    groups = find_groups(scan_dir)
    if not groups:
        print("[ERROR] No rank JSON groups found.")
        return

    # Summary first, so the full list is visible before any group is processed.
    print(f"Found {len(groups)} mode group(s):")
    for group_base, group_files in groups.items():
        print(f"  - {os.path.basename(group_base)}: {len(group_files)} file(s)")

    print("")
    for group_base, group_files in groups.items():
        print(light_rule)
        print(f"[GROUP] {os.path.basename(group_base)}  ({len(group_files)} file(s))")
        rejudge_group(
            group_base,
            group_files,
            atol=options.atol,
            update_pred_boxed=options.update_pred_boxed,
        )

    for line in ("\n" + heavy_rule, "COMPLETED", heavy_rule + "\n"):
        print(line)

# Run the command-line tool only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    main()
