#!/usr/bin/env python3
"""
fix_gsm8k_v1_question_from_blocks.py

From a directory of v1 JSON files ({id}_v1.json), check that the "question"
field matches the original GSM8K dataset (JSONL or JSON).

Behavior:
- If v1.question matches the original (under normalization) -> do NOT write anything.
- Else:
  * Compute edit distance between v1.question and original.question.
  * If distance > threshold:
        -> replace v1["question"] with ORIGINAL raw question, then write corrected v1 as {id}_v1.json into --output.
    Else:
        -> ignore mismatch and do NOT write anything.
- Always write a text report of all mismatches or missing-original cases with:
    * edit distance
    * unified diff (normalized)
    * ACTION taken per item
"""

import argparse
import json
import re
import difflib
from pathlib import Path
from typing import Dict, Tuple, Optional, List, Any

# Matches one or more consecutive whitespace characters; norm_text() uses it
# to collapse each such run into a single space.
WS_RE = re.compile(r"\s+")


def norm_text(s: Optional[str], *, case_insensitive: bool) -> str:
    """
    Return a canonical form of *s* for comparison purposes.

    ``None`` becomes the empty string. Otherwise leading/trailing whitespace
    is removed and every internal run of whitespace is collapsed to a single
    space. When *case_insensitive* is true the result is also lower-cased.
    """
    if s is None:
        return ""
    collapsed = WS_RE.sub(" ", s.strip())
    return collapsed.lower() if case_insensitive else collapsed


def load_original_question_map(
    path: Path,
    case_insensitive: bool
) -> Tuple[Dict[int, str], Dict[int, str]]:
    """
    Load the original dataset and index its questions by problem id.

    *path* must end in ``.jsonl`` (one JSON value per non-blank line) or
    ``.json`` (a single top-level list). A record's key is its integer
    ``"id"`` field when it has one, otherwise its 0-based position among the
    records of the file.

    Returns two maps:
        norm_q_map[id] -> question after norm_text()
        raw_q_map[id]  -> question exactly as stored in the file

    Malformed JSONL lines and records that are not objects are skipped, but
    they still occupy their position so that the positional ids of the
    records after them are not shifted. Records whose "question" is neither a
    string nor null are skipped as well. Both cases emit a warning on stderr.

    Raises:
        ValueError: unsupported file extension, or a ``.json`` file whose
            top-level value is not a list.
    """
    import sys  # local import: only needed for the stderr warnings below

    objs: List[Any] = []
    suffix = path.suffix.lower()

    if suffix == ".jsonl":
        with path.open("r", encoding="utf-8") as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    # Blank lines are separators, not records.
                    continue
                try:
                    objs.append(json.loads(line))
                except ValueError as exc:
                    # Keep a placeholder for the broken record. Dropping it
                    # would shift the positional id of every later record and
                    # pair v1 files with the wrong original question.
                    print(
                        f"WARNING: {path}:{line_no}: malformed JSON line skipped ({exc})",
                        file=sys.stderr,
                    )
                    objs.append(None)
    elif suffix == ".json":
        with path.open("r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, list):
            raise ValueError(f"Expected a list in {path}, got {type(data)}")
        objs = data
    else:
        raise ValueError(f"Unsupported file extension for --original: {path.suffix}")

    norm_q_map: Dict[int, str] = {}
    raw_q_map: Dict[int, str] = {}

    for idx, obj in enumerate(objs):
        if not isinstance(obj, dict):
            # Placeholders and non-object records consume an index but
            # contribute no question.
            continue
        obj_id = obj.get("id")
        key = obj_id if isinstance(obj_id, int) else idx

        q_raw = obj.get("question", "")
        if q_raw is not None and not isinstance(q_raw, str):
            # norm_text() only understands str/None; a non-text question is
            # unusable as a reference, so treat the record as absent.
            print(
                f"WARNING: {path}: record {key} has a non-string question; skipped",
                file=sys.stderr,
            )
            continue

        norm_q_map[key] = norm_text(q_raw, case_insensitive=case_insensitive)
        raw_q_map[key] = q_raw

    return norm_q_map, raw_q_map


def make_unified_diff(a: str, b: str, from_label: str, to_label: str, context: int) -> str:
    """
    Render a unified diff from text *a* to text *b* as a single string.

    Both texts are split into lines (a falsy value counts as no lines).
    *from_label* / *to_label* appear in the ``---`` / ``+++`` header and
    *context* is the number of unchanged context lines per hunk. Returns the
    empty string when the two texts have identical lines.
    """
    before = a.splitlines() if a else []
    after = b.splitlines() if b else []
    hunks = difflib.unified_diff(
        before,
        after,
        fromfile=from_label,
        tofile=to_label,
        n=context,
        lineterm="",
    )
    return "\n".join(hunks)


def levenshtein(a: str, b: str) -> int:
    """
    Return the Levenshtein edit distance between *a* and *b*.

    The distance is the minimum number of single-character insertions,
    deletions and substitutions needed to turn *a* into *b*. Only two rows of
    the dynamic-programming table are kept at any time.
    """
    if a == b:
        return 0
    if not a:
        return len(b)
    if not b:
        return len(a)

    # above[j] = distance between the processed prefix of a and b[:j]
    above = list(range(len(b) + 1))
    for i, ch_a in enumerate(a, start=1):
        row = [i]  # turning a[:i] into the empty prefix takes i deletions
        for j, ch_b in enumerate(b, start=1):
            deletion = above[j] + 1
            insertion = row[j - 1] + 1
            substitution = above[j - 1] + (0 if ch_a == ch_b else 1)
            row.append(min(deletion, insertion, substitution))
        above = row
    return above[-1]


def write_v1_json(output_dir: Path, pid: int, v1_obj: dict) -> None:
    """
    Serialize *v1_obj* to ``<output_dir>/<pid>_v1.json``.

    The file is written as UTF-8 with a 2-space indent and non-ASCII
    characters left unescaped. An existing file of the same name is
    overwritten.
    """
    target = output_dir / f"{pid}_v1.json"
    with open(target, "w", encoding="utf-8") as handle:
        json.dump(v1_obj, handle, indent=2, ensure_ascii=False)


def main():
    """
    Command-line entry point.

    For every ``{id}_v1.json`` under --json-blocks, compare its "question"
    with the original dataset's question for the same id (after
    normalization). When they differ by more than --distance-threshold edits,
    write a copy of the v1 object with the original raw question into
    --output. Every mismatch, missing original and skipped file is recorded in
    the text report given by --unmatched-report.

    Files whose name does not start with an integer id are ignored. Files that
    cannot be read, are not valid JSON, or whose top-level value is not a JSON
    object are skipped and listed in the report instead of aborting the run.
    A v1 "question" that is not a string is always treated as a mismatch and
    replaced.
    """
    ap = argparse.ArgumentParser(
        description="Fix GSM8K v1 'question' mismatches by replacing with the original when distance exceeds a threshold; report diffs."
    )
    ap.add_argument("--json-blocks", type=str, required=True, help="Directory containing {id}_v1.json files")
    ap.add_argument("--original", type=str, required=True, help="Path to original GSM8K dataset (JSONL or JSON)")
    ap.add_argument("--output", type=str, required=True, help="Directory to write corrected v1 JSON files")
    ap.add_argument("--unmatched-report", type=str, required=True, help="Where to write diffs/distances/actions (TXT)")
    ap.add_argument("--case-insensitive", action="store_true")
    ap.add_argument("--diff-context", type=int, default=3)
    ap.add_argument("--distance-threshold", type=int, default=2)
    args = ap.parse_args()

    blocks_dir = Path(args.json_blocks)
    original_path = Path(args.original)
    output_dir = Path(args.output)
    report_path = Path(args.unmatched_report)

    # A mistyped input directory would otherwise glob to nothing and produce
    # a misleading "everything matched" report.
    if not blocks_dir.is_dir():
        ap.error(f"--json-blocks is not a directory: {blocks_dir}")

    output_dir.mkdir(parents=True, exist_ok=True)
    report_path.parent.mkdir(parents=True, exist_ok=True)

    # Load original questions
    orig_q_norm_map, orig_q_raw_map = load_original_question_map(original_path, case_insensitive=args.case_insensitive)

    total = 0
    exact_matches = 0
    corrected_written = 0
    skipped = 0
    report_entries: List[str] = []

    for file in sorted(blocks_dir.glob("*_v1.json")):
        try:
            pid = int(file.stem.split("_")[0])
        except ValueError:
            # Not named {id}_v1.json; not one of the files this tool handles.
            continue

        try:
            with file.open("r", encoding="utf-8") as f:
                v1 = json.load(f)
        except (OSError, ValueError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            skipped += 1
            report_entries.append(
                "\n".join([
                    f"=== ID {pid} — unreadable v1 file ===",
                    f"File: {file.name}",
                    f"Error: {exc}",
                    "ACTION: skipped (could not load JSON) — not written",
                    ""
                ])
            )
            continue

        if not isinstance(v1, dict):
            # Valid JSON but not an object: there is no "question" field to check.
            skipped += 1
            report_entries.append(
                "\n".join([
                    f"=== ID {pid} — invalid v1 file ===",
                    f"File: {file.name}",
                    f"Error: top-level JSON value is {type(v1).__name__}, expected an object",
                    "ACTION: skipped (not a JSON object) — not written",
                    ""
                ])
            )
            continue

        total += 1
        q1_raw = v1.get("question", "")
        # norm_text() only accepts str/None. A question of any other type is
        # compared via its text form and can never count as a match.
        q1_is_text = q1_raw is None or isinstance(q1_raw, str)
        q1_norm = norm_text(q1_raw if q1_is_text else str(q1_raw), case_insensitive=args.case_insensitive)

        orig_q_norm = orig_q_norm_map.get(pid)
        orig_q_raw = orig_q_raw_map.get(pid)

        if orig_q_norm is None or orig_q_raw is None:
            header = f"=== ID {pid} — missing in original ==="
            diff_q = make_unified_diff(q1_norm, "", "v1.question (norm)", "original.question (norm)", args.diff_context)
            d_q = levenshtein(q1_norm, "")
            report_entries.append(
                "\n".join([
                    header,
                    f"Edit distance (question): {d_q}",
                    "ACTION: kept v1 as-is (no original to replace from) — not written",
                    "Question diff:",
                    diff_q or "(no difference shown)",
                    ""
                ])
            )
            continue

        if q1_is_text and q1_norm == orig_q_norm:
            exact_matches += 1
            continue

        # mismatch
        diff_q = make_unified_diff(q1_norm, orig_q_norm, "v1.question (norm)", "original.question (norm)", args.diff_context)
        d_q = levenshtein(q1_norm, orig_q_norm)

        if d_q > args.distance_threshold or not q1_is_text:
            # Shallow copy so every other field of the v1 object is preserved.
            corrected = dict(v1)
            corrected["question"] = orig_q_raw
            write_v1_json(output_dir, pid, corrected)
            corrected_written += 1
            action = "replaced v1.question with original — written"
        else:
            action = "ignored mismatch — not written"

        report_entries.append(
            "\n".join([
                f"=== ID {pid} — question mismatch ===",
                f"Edit distance (question): {d_q}",
                f"ACTION: {action}",
                "Question diff:",
                diff_q or "(no difference shown)",
                ""
            ])
        )

    # Write report
    with report_path.open("w", encoding="utf-8") as fr:
        if report_entries:
            fr.write("\n".join(report_entries).strip() + "\n")
        else:
            fr.write("All processed v1 questions matched the original (under current normalization).\n")

    print(f"Done. Checked {total} v1 files; wrote {corrected_written} corrected v1 file(s) to: {output_dir}")
    print(f"Exact question matches (no write): {exact_matches}")
    if skipped:
        print(f"Skipped unreadable/invalid v1 files (see report): {skipped}")
    print(f"Report written to: {report_path}")


# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
