#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os, re, json, argparse, glob, math
from typing import Optional, List, Dict, Tuple
from collections import Counter

import transformers
from datasets import Dataset
from transformers.pipelines.pt_utils import KeyDataset

# ----------------------- CLI -----------------------
def parse_args():
    """Parse the command-line options of the rejudge script.

    Reads ``sys.argv`` and returns an ``argparse.Namespace`` carrying the
    scan directory, the verifier/pipeline settings, the numeric tolerances
    and the logging switches defined below.
    """
    ap = argparse.ArgumentParser(
        description="Rejudge tiger JSON (merge 8 ranks) with identical logic as the PT version."
    )
    ap.add_argument("--dir", type=str, required=True, help="Directory with rank JSONs (e.g., .../steer_grid_runs)")
    # verifier / pipeline args
    ap.add_argument("--model", type=str, required=True, help="HF verifier model name (same as PT version)")
    ap.add_argument("--device", type=int, default=0)
    ap.add_argument("--batch-size", type=int, default=64)
    ap.add_argument("--max-new-tokens", type=int, default=64)
    ap.add_argument("--temperature", type=float, default=0.0)
    ap.add_argument("--top-p", type=float, default=1.0)
    ap.add_argument("--dtype", type=str, default="bfloat16", choices=["float16", "bfloat16", "float32"])
    ap.add_argument("--trust-remote-code", action="store_true")
    # tolerance / logging
    # argparse %-formats help strings, so a literal percent sign has to be
    # doubled; a bare "%" makes --help fail with a ValueError.
    ap.add_argument("--rel-tol", type=float, default=1e-3, help="Relative tolerance (default: 0.1%%)")
    ap.add_argument("--abs-tol", type=float, default=1e-6, help="Absolute tolerance")
    ap.add_argument("--verbose", action="store_true", help="Verbose logging to stdout")
    ap.add_argument("--save-log", type=str, default=None, help="Path to save detailed line logs (JSONL)")
    # update predicted_boxed like PT
    ap.add_argument("--update_pred_boxed", action="store_true", help="Refresh predicted_boxed per row")
    return ap.parse_args()

# ----------------------- Verifier prompt -----------------------
# Prompt sent to the verifier model. It has three ``str.format`` placeholders:
# ``{question}``, ``{ground_truth}`` and ``{student_answer}``.
# The closing sentence asks for the literal phrases "Final Decision: Yes" /
# "Final Decision: No", which are the phrases parse_verdict() searches for.
# NOTE(review): the text already carries "User:" / "Assistant:" markers and is
# additionally wrapped by the tokenizer chat template in rejudge_group() —
# presumably intentional to mirror the PT version; confirm before changing.
VERIFIER_PROMPT_TEMPLATE = (
    "User: ### Question: {question}\n\n"
    "### Ground Truth Answer: {ground_truth}\n\n"
    "### Student Answer: {student_answer}\n\n"
    "For the above question, please verify if the student's answer is equivalent to the ground truth answer.\n"
    "Do not solve the question by yourself; just check if the student's answer is equivalent to the ground truth answer.\n"
    "If the student's answer is correct, output \"Final Decision: Yes\". If the student's answer is incorrect, output \"Final Decision: No\". Assistant:"
)

# ----------------------- Extraction helpers -----------------------
# \boxed{...} whose content may contain up to two levels of nested braces.
_BOXED_RE = re.compile(r"\\boxed\s*\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}", re.DOTALL)
# Optional currency symbol and minus sign, then EITHER a comma-grouped number
# (at least one ",ddd" group) OR a plain digit run, then an optional fraction.
# Requiring a comma group in the first branch keeps an ungrouped run such as
# "12345" in one piece instead of splitting it into "123" and "45".
_CURRENCY_NUMBER_RE = re.compile(r"(?:[\$€£]\s*)?-?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?")
# Phrases that introduce a final answer; group 1 of each pattern is the answer.
# In the first pattern a period only ends the answer when it is not followed by
# a digit, so decimals such as "3.14" are captured whole.
_FINAL_PHRASES = [
    r"(?:Final\s*Answer(?:\s*is)?|Answer(?:\s*is)?|The\s*answer\s*is|Therefore,\s*the\s*answer\s*is|Thus,\s*the\s*answer\s*is)\s*[:\-]?\s*(.+?)\s*(?:$|\n|\.(?!\d)|\)|</s>|<\|im_end\|>|<\|endoftext\|>)",
    r"(?:correct\s*answer\s*(?:is)?|the\s*correct\s*answer\s*(?:is)?)\s*[:\-]?\s*([A-H])\b",
    r"(?:option|choice)\s*([A-Ha-h])\b",
]

# LaTeX text wrappers (\text{..}, \mathrm{..}, \mbox{..}, \operatorname{..})
# with brace-free content; group 1 is the wrapped content.
_LATEX_WRAP_RE = re.compile(r"\\(?:text|mathrm|mbox|operatorname)\s*\{([^{}]*)\}")

def extract_last_boxed(text: str) -> Optional[str]:
    """Return the stripped content of the last boxed expression in *text*.

    ``None`` is returned when *text* is empty/``None`` or holds no boxed
    expression.
    """
    last = None
    for last in _BOXED_RE.finditer(text or ""):
        pass  # keep only the final match
    return None if last is None else last.group(1).strip()

def extract_final_candidates(text: str) -> List[str]:
    """Collect answers introduced by "final answer"-style phrases.

    Every pattern in ``_FINAL_PHRASES`` is applied case-insensitively to the
    stripped text; results are returned pattern by pattern, each in order of
    appearance.
    """
    stripped = (text or "").strip()
    return [
        (m.group(1) if m.lastindex else m.group(0)).strip()
        for pat in _FINAL_PHRASES
        for m in re.finditer(pat, stripped, flags=re.IGNORECASE)
    ]

def extract_currency_number_candidates(text: str) -> List[str]:
    """Return every number-like token in *text*, in order of appearance.

    Tokens keep a leading currency symbol, sign, thousands separators and
    decimal part exactly as written.
    """
    return [m.group(0).strip() for m in _CURRENCY_NUMBER_RE.finditer(text or "")]

def strip_latex_noise(s: str) -> str:
    """Normalise a LaTeX-flavoured answer string for comparison.

    ``None`` becomes ``""``. Otherwise text wrappers matched by
    ``_LATEX_WRAP_RE`` are unwrapped (at most three passes), a handful of
    LaTeX macros are replaced by plain equivalents, one surrounding pair of
    ``$`` is removed and whitespace runs collapse to a single space.
    """
    if s is None:
        return ""

    text = str(s)
    remaining_passes = 3
    while remaining_passes:
        text, unwrapped = _LATEX_WRAP_RE.subn(r"\1", text)
        remaining_passes -= 1
        if not unwrapped:
            break

    # Applied in this exact order; each pair is (macro, replacement).
    substitutions = (
        ("\\dfrac", "\\frac"),
        ("\\tfrac", "\\frac"),
        ("\\,", " "),
        ("\\!", " "),
        ("\\cdot", "*"),
        ("\\times", "×"),
    )
    for macro, plain in substitutions:
        text = text.replace(macro, plain)

    text = text.strip()
    # Drop one enclosing pair of math delimiters, e.g. "$x$" -> "x".
    if len(text) >= 2 and text.startswith("$") and text.endswith("$"):
        text = text[1:-1].strip()
    return re.sub(r"\s+", " ", text).strip()

def predicted_boxed_from_generated_text(gen: str) -> List[str]:
    r"""Return de-duplicated answer candidates extracted from *gen*.

    Candidates are ordered by priority: every "final answer" phrase match,
    then the last ``\boxed{}`` content, then the last numeric token. Each is
    cleaned with ``strip_latex_noise`` and has ``**`` removed; empty results
    and repeats are dropped while the first occurrence keeps its position.
    """
    raw: List[str] = list(extract_final_candidates(gen))

    boxed = extract_last_boxed(gen)
    if boxed:
        raw.append(boxed)

    numbers = extract_currency_number_candidates(gen)
    if numbers:
        raw.append(numbers[-1])

    # A dict keeps insertion order, which gives ordered de-duplication.
    ordered_unique: Dict[str, None] = {}
    for item in raw:
        cleaned = strip_latex_noise(item).replace("**", "").strip()
        if cleaned:
            ordered_unique.setdefault(cleaned, None)
    return list(ordered_unique)

def parse_verdict(s: str) -> Optional[bool]:
    """Read the verifier's verdict from its generated text.

    Returns ``True`` for "Final Decision: Yes", ``False`` for
    "Final Decision: No" (case-insensitive, flexible spacing) and ``None``
    when neither phrase is present. When the text contains several decisions
    the LAST one wins, so an earlier mention of "Final Decision: Yes" cannot
    override a concluding "Final Decision: No".
    """
    decisions = re.findall(
        r"\bFinal\s*Decision\s*:\s*(Yes|No)\b", (s or "").strip(), flags=re.IGNORECASE
    )
    if not decisions:
        return None
    return decisions[-1].lower() == "yes"

def dtype_of(name: str):
    """Validate a dtype name and return it unchanged.

    Only ``"float16"``, ``"bfloat16"`` and ``"float32"`` are accepted; any
    other name raises ``KeyError``.
    """
    supported = {dtype: dtype for dtype in ("float16", "bfloat16", "float32")}
    return supported[name]

# ----------------------- Numeric equivalence -----------------------
# A numeric token: optional sign, digits with optional ",ddd" groups, optional
# decimal fraction and optional e/E exponent.
_NUM_TOKEN_RE = re.compile(r"[+-]?\d+(?:,\d{3})*(?:\.\d+)?(?:[eE][+-]?\d+)?")
# Translation table mapping Unicode superscript digits and the superscript
# plus/minus signs to their ASCII counterparts.
_SUPER_MAP = str.maketrans({"⁰":"0","¹":"1","²":"2","³":"3","⁴":"4","⁵":"5","⁶":"6","⁷":"7","⁸":"8","⁹":"9","⁺":"+","⁻":"-"})

def _normalize_scientific_unicode(t: str) -> str:
    """Rewrite Unicode superscripts in *t* as ASCII.

    A ``10`` immediately followed by superscript characters becomes
    ``10^{exp}`` with the exponent translated through ``_SUPER_MAP``; any
    superscript characters left elsewhere are then translated in place.
    """
    with_powers = re.sub(
        r"10([⁰¹²³⁴⁵⁶⁷⁸⁹⁺⁻]+)",
        lambda hit: "10^{" + hit.group(1).translate(_SUPER_MAP) + "}",
        t,
    )
    return with_powers.translate(_SUPER_MAP)

def _to_float_canonical(s: str) -> Optional[float]:
    """Extract the most plausible numeric value from an answer string.

    Processing steps:
      * the Unicode minus sign (U+2212) is folded to ASCII "-" so the sign is
        not lost by the ASCII-only numeric patterns;
      * a value fully wrapped in parentheses is read as an accounting-style
        negative, unless the number inside already carries an explicit sign
        (so "(-5)" stays -5 rather than flipping to +5);
      * a leading currency symbol, trailing "%" and thousands commas are
        removed (the percentage is NOT rescaled);
      * "a x 10^{b}" / "a x 10^b" (also with the multiplication sign or
        Unicode superscripts) is rewritten to "aeb";
      * among the remaining numeric tokens, those that are an exponent after
        "^" (optionally inside "{") are ignored when any other token exists,
        and the winner is the token with a decimal point/exponent first, then
        the longest, then the earliest.

    Returns the value as ``float``, or ``None`` when *s* is empty or holds no
    number.
    """
    if not s:
        return None
    t = str(s).strip().replace("\u2212", "-")
    neg = t.startswith("(") and t.endswith(")")
    if neg:
        t = t[1:-1].strip()
    t = re.sub(r"^[\$\€\£]\s*", "", t)
    t = t.rstrip("%").strip()
    t = t.replace(",", "")
    t = _normalize_scientific_unicode(t)
    t = t.replace("×", "x")
    # a x 10^{b}  /  a x 10^b  ->  a e b
    t = re.sub(r"([+-]?\d+(?:\.\d+)?)\s*x\s*10\^\{([+-]?\d+)\}", lambda m: f"{float(m.group(1))}e{int(m.group(2))}", t)
    t = re.sub(r"([+-]?\d+(?:\.\d+)?)\s*x\s*10\^([+-]?\d+)",    lambda m: f"{float(m.group(1))}e{int(m.group(2))}", t)

    matches = list(_NUM_TOKEN_RE.finditer(t))
    if not matches:
        return None

    def is_unit_exponent(m):
        # Walk left over whitespace and an opening brace so that both "s^2"
        # and "s^{-2}" are recognised as exponents rather than values.
        j = m.start() - 1
        while j >= 0 and (t[j].isspace() or t[j] == "{"):
            j -= 1
        return j >= 0 and t[j] == "^"

    cand = [m for m in matches if not is_unit_exponent(m)] or matches

    def score(m):
        tok = m.group(0)
        return (("e" in tok.lower()) or ("." in tok), len(tok))

    # max() keeps the earliest token among equally scored ones.
    token = max(cand, key=score).group(0)
    try:
        v = float(token)
    except ValueError:
        return None
    if neg and token[0] not in "+-":
        return -v
    return v

def has_direction_in_answer(answer: str) -> bool:
    """Tell whether *answer* mentions a direction keyword.

    The check is a case-insensitive substring search, so a keyword embedded
    in a longer word also counts. ``None`` or an empty string yields
    ``False``.
    """
    lowered = (answer or "").lower()
    direction_markers = (
        "+x", "positive x", "−x", "-x", "negative x",
        "north", "south", "east", "west",
        "left", "right", "upward", "downward", "forward", "backward",
    )
    for marker in direction_markers:
        if marker in lowered:
            return True
    return False

def check_direction_match(student: str, ground: str) -> Optional[bool]:
    """Veto a student answer whose direction contradicts the ground truth.

    Returns ``False`` when both answers mention a direction and the ground
    truth contains a negative-direction (or positive-direction) word while
    the student answer contains none of that polarity. In every other case
    the result is ``None`` (no opinion); the function never returns ``True``.
    """
    if not has_direction_in_answer(ground) or not has_direction_in_answer(student):
        return None

    ground_l = ground.lower()
    student_l = student.lower()

    def mentions(text, words):
        return any(word in text for word in words)

    negative_words = ("-x", "−x", "negative x", "negative", "south", "west", "left", "downward")
    positive_words = ("+x", "positive x", "positive", "north", "east", "right", "upward")

    # Negative polarity is checked before positive, as a pair of vetoes.
    for polarity in (negative_words, positive_words):
        if mentions(ground_l, polarity) and not mentions(student_l, polarity):
            return False
    return None

def cheap_equivalence(student: str, ground: str, rel_tol: float = 1e-3, abs_tol: float = 1e-6) -> Optional[bool]:
    """Decide numeric equivalence without calling the LLM verifier.

    Returns ``False`` when the direction check vetoes the answer, ``None``
    when either side has no parseable number, and otherwise whether the
    student value matches the ground truth: within *abs_tol* of zero when the
    ground truth magnitude is at most *abs_tol*, else within *rel_tol*
    relative error.
    """
    if check_direction_match(student, ground) is False:
        return False

    truth = _to_float_canonical(ground)
    guess = _to_float_canonical(student)
    if truth is None or guess is None:
        return None

    truth_magnitude = abs(truth)
    if truth_magnitude <= abs_tol:
        # A relative error is meaningless around zero; compare absolutely.
        return abs(guess) <= abs_tol
    relative_error = abs(guess - truth) / truth_magnitude
    return relative_error <= rel_tol

# ----------------------- Logger -----------------------
class Logger:
    """In-memory log collector with optional echo to stdout and file dump.

    ``entries`` holds, in call order, formatted strings from :meth:`log` and
    dicts from :meth:`log_sample`. Nothing is written to disk until
    :meth:`save` is called.
    """

    def __init__(self, verbose: bool = False, log_file: Optional[str] = None):
        self.verbose = verbose      # echo entries to stdout as they arrive
        self.log_file = log_file    # destination for save(); falsy disables it
        self.entries: List = []

    def log(self, msg: str, level: str = "INFO"):
        """Record ``"[LEVEL] msg"`` and print it when verbose."""
        entry = f"[{level}] {msg}"
        if self.verbose:
            print(entry)
        self.entries.append(entry)

    def log_sample(self, idx: int, info: Dict):
        """Record a per-sample dict (``idx`` merged with *info*).

        When verbose, a one-line summary using ``info["summary"]`` (empty if
        absent) is printed.
        """
        self.entries.append({"idx": idx, **info})
        if self.verbose:
            print(f"[SAMPLE {idx}] {info.get('summary', '')}")

    def save(self):
        """Write all entries to ``log_file``, one per line.

        Dict entries are written as JSON, string entries verbatim. The file
        is always UTF-8 so that non-ASCII messages cannot fail on platforms
        whose default encoding is narrower. No-op when ``log_file`` is unset.
        """
        if not self.log_file:
            return
        with open(self.log_file, "w", encoding="utf-8") as f:
            for e in self.entries:
                line = json.dumps(e, ensure_ascii=False) if isinstance(e, dict) else str(e)
                f.write(line + "\n")

# ----------------------- JSON grouping -----------------------
# "<base>.rank<N>.json": group "base" is the mode name, group "r" the rank.
_RANK_FILE_RE = re.compile(r"""^(?P<base>.+?)\.rank(?P<r>\d+)\.json$""")

def find_groups(dir_path: str) -> Dict[str, List[str]]:
    """Group the rank JSON files found directly inside *dir_path*.

    Only regular files named ``<base>.rank<N>.json`` are considered; names
    starting with ``grid_all_modes`` are ignored. The result maps each base
    name to its file paths sorted by numeric rank, and the bases themselves
    are inserted in sorted order so processing order is deterministic.

    The directory is glob-escaped, so a path containing ``[``, ``]``, ``*``
    or ``?`` is matched literally instead of being read as a pattern.
    """
    pattern = os.path.join(glob.escape(dir_path), "*.json")
    ranked: Dict[str, List[Tuple[int, str]]] = {}
    for fp in glob.glob(pattern):
        name = os.path.basename(fp)
        if name.startswith("grid_all_modes") or not os.path.isfile(fp):
            continue
        m = _RANK_FILE_RE.match(name)
        if m is None:
            continue
        ranked.setdefault(m.group("base"), []).append((int(m.group("r")), fp))
    return {base: [fp for _, fp in sorted(ranked[base])] for base in sorted(ranked)}

# ----------------------- Field helpers -----------------------
def get_question(row: Dict) -> str:
    """Return the question of *row*.

    The keys ``question``, ``prompt`` and ``q`` are tried in that order and
    the first truthy value wins; ``""`` is returned when none is set.
    """
    for key in ("question", "prompt", "q"):
        value = row.get(key)
        if value:
            return value
    return ""

def get_ground(row: Dict) -> str:
    """Return the ground-truth answer of *row* as a string.

    The keys ``gold``, ``ground_truth``, ``ground_truth_answer`` and
    ``answer`` are tried in that order. A key is skipped only when its value
    is ``None`` or an empty string/list/tuple/dict, so a numeric answer of
    ``0`` is kept rather than treated as missing. Non-string values are
    converted with ``str()`` because the downstream heuristics call string
    methods on the result. ``""`` is returned when no key is usable.
    """
    for key in ("gold", "ground_truth", "ground_truth_answer", "answer"):
        value = row.get(key)
        if value is None:
            continue
        if isinstance(value, (str, list, tuple, dict)) and not value:
            continue
        return value if isinstance(value, str) else str(value)
    return ""

def get_generated_text(row: Dict) -> str:
    """Return the model output stored in *row* as a string.

    The keys ``generated_text``, ``pred``, ``student_answer`` and
    ``response`` are tried in that order. A key is skipped only when its
    value is ``None`` or an empty string/list/tuple/dict, so a numeric
    prediction of ``0`` is kept. Non-string values are converted with
    ``str()`` because the extraction helpers run regexes over the result.
    ``""`` is returned when no key is usable.
    """
    for key in ("generated_text", "pred", "student_answer", "response"):
        value = row.get(key)
        if value is None:
            continue
        if isinstance(value, (str, list, tuple, dict)) and not value:
            continue
        return value if isinstance(value, str) else str(value)
    return ""

# ----------------------- Main per-group rejudge -----------------------
def _as_text(value) -> str:
    """Coerce a JSON field to ``str`` (``""`` for ``None``) for the string heuristics."""
    if value is None:
        return ""
    return value if isinstance(value, str) else str(value)


def _verifier_output_text(out) -> str:
    """Return the generated text carried by one pipeline result.

    Accepts the per-input list of result dicts, a single result dict, or any
    other object (stringified). A list-valued ``generated_text`` is treated
    as a message list and the content of its last entry is returned.
    """
    if isinstance(out, (list, tuple)):
        # One result per returned sequence; num_return_sequences is 1.
        out = out[0] if out else None
    if out is None:
        return ""
    if not isinstance(out, dict):
        return str(out)
    generated = out.get("generated_text")
    if isinstance(generated, list):
        if not generated:
            return ""
        last = generated[-1]
        return str(last.get("content", "") or "") if isinstance(last, dict) else str(last)
    return "" if generated is None else str(generated)


def rejudge_group(base: str, files: List[str], args, pipe, tok, logger: Logger) -> Optional[str]:
    """Merge the rank files of one mode, rejudge every row and save the result.

    Steps:
      1. Load each file and concatenate its ``"rows"`` list; a file that
         cannot be loaded is logged and skipped.
      2. For every row extract answer candidates and run
         ``cheap_equivalence``: a matching candidate marks the row correct,
         all candidates confidently wrong marks it incorrect, anything else
         is queued for the LLM verifier using the first candidate (or
         ``"No Answer"``).
      3. Run the queued prompts through *pipe* and parse each verdict; an
         unparseable verdict counts as incorrect.
      4. Write ``is_correct`` (and ``ok`` when the row already has it) back
         to the rows and save ``<base>.fixeval.json`` next to the first
         input file.

    Args:
        base: Group name shared by the rank files.
        files: Paths of the rank JSON files, already in rank order.
        args: Parsed CLI namespace (tolerances, generation settings,
            ``update_pred_boxed``).
        pipe: Text-generation pipeline used as verifier
            (presumably a ``transformers`` pipeline; built by the caller).
        tok: Tokenizer providing ``apply_chat_template`` and the
            eos/pad token ids.
        logger: Collector for progress and error messages.

    Returns:
        Path of the written JSON file, or ``None`` when *files* is empty.
    """
    if not files:
        return None
    meta = {"mode_base": base, "n_files": len(files), "files": [os.path.basename(x) for x in files]}

    # merge rows
    merged_rows: List[Dict] = []
    for fp in files:
        try:
            with open(fp, "r", encoding="utf-8") as f:
                obj = json.load(f)
            rows = obj.get("rows", [])
            if isinstance(rows, list):
                merged_rows.extend(rows)
        except Exception as e:  # best effort: one bad rank file must not abort the group
            logger.log(f"Failed to load {fp}: {e}", "ERROR")

    total = len(merged_rows)
    if total == 0:
        logger.log(f"[{base}] No rows after merging.", "WARN")

    # pass 1: cheap equivalence / candidate selection
    prompts: List[str] = []
    regen_indices: List[int] = []
    verdicts: List[Optional[bool]] = [None] * total
    stats = Counter()

    for idx, row in enumerate(merged_rows):
        q = get_question(row)
        # JSON may store numbers; the heuristics below need strings.
        gt = _as_text(get_ground(row))
        gen = _as_text(get_generated_text(row))

        # candidates with improved priority
        cands = predicted_boxed_from_generated_text(gen)
        selected = cands[0] if cands else "No Answer"

        if args.update_pred_boxed:
            row["predicted_boxed"] = [selected]

        if len(cands) > 1:
            stats["multi_candidate_used"] += 1

        # Judge each candidate once: stop at the first confirmed match and
        # remember whether every candidate seen was confidently wrong.
        matched: Optional[str] = None
        all_wrong = bool(cands)
        for cand in cands:
            heur = cheap_equivalence(cand, gt, rel_tol=args.rel_tol, abs_tol=args.abs_tol)
            if heur is True:
                matched = cand
                break
            if heur is None:
                all_wrong = False

        if matched is not None:
            verdicts[idx] = True
            stats["cheap_correct"] += 1
            row["predicted_boxed"] = [matched] if args.update_pred_boxed else row.get("predicted_boxed", [matched])
            continue

        if all_wrong:
            verdicts[idx] = False
            stats["cheap_incorrect"] += 1
            continue

        # unknown -> LLM verification
        stats["cheap_unknown"] += 1
        user_text = VERIFIER_PROMPT_TEMPLATE.format(question=q, ground_truth=gt, student_answer=selected)
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": user_text},
        ]
        prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        prompts.append(prompt)
        regen_indices.append(idx)

    # pass 2: LLM verify unknowns
    if prompts:
        logger.log(f"[{base}] Verifying {len(prompts)} items via LLM", "INFO")
        sampling = args.temperature > 0.0
        # Explicit None checks: a token id of 0 is valid but falsy.
        pad_id = tok.eos_token_id if tok.eos_token_id is not None else tok.pad_token_id
        ds = Dataset.from_dict({"prompt": prompts})
        out_iter = pipe(
            KeyDataset(ds, "prompt"),
            batch_size=args.batch_size,
            max_new_tokens=args.max_new_tokens,
            do_sample=sampling,
            temperature=max(1e-5, args.temperature) if sampling else None,
            top_p=args.top_p if sampling else None,
            pad_token_id=pad_id,
            num_return_sequences=1,
            return_full_text=False,
        )
        # Results arrive in prompt order, so pair them with the row indices.
        for idx, out in zip(regen_indices, out_iter):
            verdicts[idx] = parse_verdict(_verifier_output_text(out))
            stats["llm_verified"] += 1

    # pass 3: write back correctness & compute accuracy
    correct = 0
    for row, verdict in zip(merged_rows, verdicts):
        new_ok = verdict is True  # an unparsed verdict (None) counts as incorrect
        row["is_correct"] = new_ok
        if "ok" in row:
            row["ok"] = new_ok
        if new_ok:
            correct += 1

    acc = correct / max(1, total)
    out_obj = {
        **meta,
        "rows": merged_rows,
        "total": total,
        "correct": correct,
        "accuracy": acc,
        "rel_tol": args.rel_tol,
        "abs_tol": args.abs_tol,
    }

    # save
    out_full = os.path.join(os.path.dirname(files[0]), os.path.basename(base) + ".fixeval.json")
    with open(out_full, "w", encoding="utf-8") as f:
        json.dump(out_obj, f, ensure_ascii=False, indent=2)

    logger.log(f"[{base}] Decision stats: {dict(stats)}")
    logger.log(f"[{base}] Saved: {out_full} | rows={total} correct={correct} acc={acc:.4f}")
    return out_full

# ----------------------- Main -----------------------
def main():
    """Entry point: rejudge every rank-file group found in ``--dir``.

    Validates the directory, prints the run banner, discovers the groups and
    only then loads the verifier tokenizer and pipeline, so a directory
    without rank files is reported without paying for a model load. Each
    group is rejudged, an accuracy summary is printed from the files that
    were written, and the collected log is saved even if a group fails.
    """
    args = parse_args()
    logger = Logger(verbose=args.verbose, log_file=args.save_log)

    dir_path = os.path.abspath(args.dir)
    if not os.path.isdir(dir_path):
        raise SystemExit(f"[ERROR] Not a directory: {dir_path}")

    rule = "=" * 70
    print("\n" + rule)
    print(" TIGER JSON REJUDGE (MERGED)")
    print(rule)
    print(f" Scan dir      : {dir_path}")
    print(f" Verifier model: {args.model}")
    print(f" Device        : {args.device}")
    print(f" rel_tol       : {args.rel_tol}")
    print(f" abs_tol       : {args.abs_tol}")
    print(f" update_boxed  : {args.update_pred_boxed}")
    print(rule + "\n")

    groups = find_groups(dir_path)
    if not groups:
        print("[ERROR] No rank JSON groups found (pattern: *.rankX.json; grid_all_* ignored).")
        return

    print(f"Found {len(groups)} mode group(s):")
    for base, lst in groups.items():
        print(f"  - {os.path.basename(base)}: {len(lst)} file(s)")
    print("")

    # tokenizer & pipeline (mirror PT); loaded only once there is work to do
    tok = transformers.AutoTokenizer.from_pretrained(
        args.model, trust_remote_code=args.trust_remote_code
    )
    tok.padding_side = "left"
    if tok.pad_token_id is None:
        tok.pad_token_id = tok.eos_token_id

    pipe = transformers.pipeline(
        "text-generation",
        model=args.model,
        tokenizer=tok,
        device=args.device,
        trust_remote_code=args.trust_remote_code,
        model_kwargs={"torch_dtype": dtype_of(args.dtype)},
    )

    produced = []
    try:
        for base, files in groups.items():
            print("-" * 60)
            print(f"[GROUP] {os.path.basename(base)}")
            outp = rejudge_group(base, files, args, pipe, tok, logger)
            if outp:
                produced.append(outp)

        print("")
        print(rule)
        print(" ACCURACY SUMMARY")
        print(rule)
        for fp in produced:
            try:
                with open(fp, "r", encoding="utf-8") as f:
                    obj = json.load(f)
                acc = obj.get("accuracy", 0.0)
                total = obj.get("total", 0)
                correct = obj.get("correct", 0)
                print(f"{os.path.basename(fp):45s}  acc={acc:.4%}  ({correct}/{total})")
            except Exception as e:
                print(f"{os.path.basename(fp):45s}  (failed to read: {e})")
        print(rule)
        print("")
    finally:
        # Keep whatever was logged, including the messages leading up to a failure.
        logger.save()

    print("\n" + rule)
    print(" COMPLETED")
    print(rule + "\n")

if __name__ == "__main__":
    main()
