#!/usr/bin/env python3
"""
fix_entailment_bank_v1_from_blocks.py

Read Version-1 JSON files directly from a directory ({id}_v1.json),
compare their ("instruction","input") with the ORIGINAL dataset
(JSONL or JSON list). If mismatched, optionally replace the v1["input"]
with the ORIGINAL raw input when edit distance exceeds a threshold and
write ONLY those corrected v1s to --output.

Behavior:
- If (instruction,input) match after normalization -> do NOT write anything.
- Else:
  * Compute edit distances (normalized) for:
        - instruction vs original.instruction
        - input vs original.input
        - combined: (instruction + "\n\n" + input) vs original combined
  * If input_dist > threshold OR combined_dist > threshold:
        -> replace v1["input"] with ORIGINAL raw input, then write corrected v1 as {id}_v1.json to --output.
    Else:
        -> ignore mismatch and do NOT write anything.
- Always write a text report. For every mismatched record, and for every v1 id
  that is missing from the original, it lists:
    * edit distances
    * unified diffs (normalized)
    * ACTION taken per item

Normalization:
- Strip leading/trailing whitespace
- Collapse internal whitespace runs to a single space
- Optional case-insensitive via --case-insensitive

Usage:
  python fix_entailment_bank_v1_from_blocks.py \
    --json-blocks json_blocks/ \
    --original original.jsonl \
    --output corrected_v1/ \
    --unmatched-report unmatched.txt \
    --distance-threshold 2
"""

import argparse
import json
import re
import difflib
from pathlib import Path
from typing import Dict, Tuple, Optional, List, Any

# Matches any run of whitespace (spaces, tabs, newlines, ...); used to collapse runs to one space.
WS_RE = re.compile(r"\s+")


def norm_text(s: Optional[str], *, case_insensitive: bool) -> str:
    """Canonicalize *s* so that cosmetic differences do not count as mismatches.

    ``None`` maps to the empty string. Otherwise the text is stripped at both
    ends, every internal whitespace run is replaced by a single space, and the
    result is lower-cased when *case_insensitive* is true.
    """
    if s is None:
        return ""
    collapsed = WS_RE.sub(" ", s.strip())
    return collapsed.lower() if case_insensitive else collapsed


def load_original_map(
    path: Path,
    case_insensitive: bool
) -> Tuple[Dict[int, Tuple[str, str]], Dict[int, Tuple[str, str]]]:
    """
    Load original JSON/JSONL and return two maps keyed by record id:
        norm_map[id] -> (instruction_norm, input_norm)
        raw_map[id]  -> (instruction_raw,  input_raw)

    Key selection:
      * If an object has an integer "id" field (booleans excluded), use it.
      * Otherwise use its position: the 0-based physical line number for
        .jsonl (blank and malformed lines still consume a number, so one bad
        line never shifts the keys of the records after it), or the 0-based
        list index for .json.

    Rows that are not JSON objects are ignored; malformed .jsonl lines are
    skipped with a printed warning. If two rows resolve to the same key, the
    later row wins. Missing "instruction"/"input" fields default to "".

    Raises:
        ValueError: if the extension is neither .jsonl nor .json, or if a
            .json file's top-level value is not a list.
    """
    suffix = path.suffix.lower()
    # (positional index, parsed object) pairs; the index is the fallback key.
    indexed: List[Tuple[int, Any]] = []

    if suffix == ".jsonl":
        with path.open("r", encoding="utf-8") as f:
            for line_no, line in enumerate(f):
                line = line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except ValueError as err:  # json.JSONDecodeError subclasses ValueError
                    print(f"Warning: skipping malformed JSON on line {line_no + 1} of {path}: {err}")
                    continue
                indexed.append((line_no, obj))
    elif suffix == ".json":
        with path.open("r", encoding="utf-8") as f:
            data = json.load(f)
        if not isinstance(data, list):
            raise ValueError(f"Expected a list in {path}, got {type(data)}")
        indexed = list(enumerate(data))
    else:
        raise ValueError(f"Unsupported file extension for --original: {path.suffix}")

    norm_map: Dict[int, Tuple[str, str]] = {}
    raw_map: Dict[int, Tuple[str, str]] = {}

    for idx, obj in indexed:
        if not isinstance(obj, dict):
            continue
        obj_id = obj.get("id")
        # bool is a subclass of int; never let `"id": true` collide with key 1.
        has_int_id = isinstance(obj_id, int) and not isinstance(obj_id, bool)
        key = obj_id if has_int_id else idx

        ins_raw = obj.get("instruction", "")
        inp_raw = obj.get("input", "")

        norm_map[key] = (
            norm_text(ins_raw, case_insensitive=case_insensitive),
            norm_text(inp_raw, case_insensitive=case_insensitive),
        )
        raw_map[key] = (ins_raw, inp_raw)

    return norm_map, raw_map


def make_unified_diff(a: str, b: str, from_label: str, to_label: str, context: int) -> str:
    """Return a unified diff turning *a* into *b* as a single newline-joined string.

    Falsy inputs (including ``None``) are treated as empty text. Diff lines
    carry no line terminators of their own; the result is ``""`` when the
    two texts have identical lines.
    """
    lines_from = (a or "").splitlines()
    lines_to = (b or "").splitlines()
    return "\n".join(
        difflib.unified_diff(
            lines_from,
            lines_to,
            fromfile=from_label,
            tofile=to_label,
            n=context,
            lineterm="",
        )
    )


def levenshtein(a: str, b: str) -> int:
    """Return the Levenshtein edit distance between *a* and *b*.

    Insertions, deletions and substitutions each cost 1. Time is
    O(len(a) * len(b)); a single rolling row keeps memory at
    O(min(len(a), len(b))).
    """
    if a == b:
        return 0
    # The distance is symmetric, so keep the shorter string on the row axis.
    if len(a) < len(b):
        a, b = b, a
    if not b:
        return len(a)

    row = list(range(len(b) + 1))
    for i, ch_a in enumerate(a, start=1):
        # `diag` holds the previous row's value at column j-1.
        diag, row[0] = row[0], i
        for j, ch_b in enumerate(b, start=1):
            above = row[j]
            row[j] = min(
                above + 1,                # deletion
                row[j - 1] + 1,           # insertion
                diag + (ch_a != ch_b),    # substitution (0 when equal)
            )
            diag = above
    return row[-1]


def write_v1_json(output_dir: Path, pid: int, v1_obj: dict) -> None:
    """Serialize *v1_obj* to ``output_dir / "{pid}_v1.json"`` as 2-space-indented UTF-8 JSON.

    Non-ASCII characters are written verbatim (no ``\\u`` escapes); an
    existing file with the same name is overwritten.
    """
    target = output_dir / f"{pid}_v1.json"
    with target.open("w", encoding="utf-8") as handle:
        json.dump(v1_obj, handle, ensure_ascii=False, indent=2)


def _format_report_entry(
    header: str,
    distances: Tuple[int, int, int],
    action: str,
    diff_ins: str,
    diff_inp: str,
) -> str:
    """Render one report section: header, edit distances, action and both diffs."""
    d_ins, d_inp, d_comb = distances
    return "\n".join([
        header,
        f"Edit distance (instruction/input/combined): {d_ins} / {d_inp} / {d_comb}",
        f"ACTION: {action}",
        "Instruction diff:",
        diff_ins or "(no difference shown)",
        "",
        "Input diff:",
        diff_inp or "(no difference shown)",
        "",
    ])


def _diffs_and_distances(
    ins1: str,
    inp1: str,
    orig: Optional[Tuple[str, str]],
    context: int,
) -> Tuple[str, str, Tuple[int, int, int]]:
    """Diff and edit-distance normalized v1 (instruction, input) against *orig*.

    When *orig* is None (id absent from the original) every comparison is
    made against the empty string, including the combined text.
    """
    if orig is None:
        ins0, inp0, comb0 = "", "", ""
    else:
        ins0, inp0 = orig
        comb0 = f"{ins0}\n\n{inp0}"
    diff_ins = make_unified_diff(ins1, ins0, "v1.instruction (norm)", "original.instruction (norm)", context)
    diff_inp = make_unified_diff(inp1, inp0, "v1.input (norm)", "original.input (norm)", context)
    distances = (
        levenshtein(ins1, ins0),
        levenshtein(inp1, inp0),
        levenshtein(f"{ins1}\n\n{inp1}", comb0),
    )
    return diff_ins, diff_inp, distances


def _collect_v1_files(blocks_dir: Path) -> List[Tuple[int, Path]]:
    """Return (pid, path) for every '<int>_v1.json' file in *blocks_dir*, ordered numerically by pid."""
    found: List[Tuple[int, Path]] = []
    for path in blocks_dir.glob("*_v1.json"):
        # Expect filenames like "123_v1.json" -> pid = 123
        try:
            pid = int(path.stem.split("_")[0])
        except ValueError:
            print(f"Warning: skipping {path.name}: cannot parse an integer id from the file name")
            continue
        found.append((pid, path))
    # Numeric order ("2" before "10") keeps the report readable; a plain path sort is lexicographic.
    found.sort(key=lambda item: (item[0], item[1].name))
    return found


def _load_v1(path: Path) -> Tuple[Optional[dict], Optional[str]]:
    """Load one v1 file; return (obj, None) on success or (None, reason) if it is unusable."""
    try:
        with path.open("r", encoding="utf-8") as f:
            obj = json.load(f)
    except (OSError, ValueError) as err:  # JSONDecodeError / UnicodeDecodeError are ValueErrors
        return None, f"could not be read as JSON: {err}"
    if not isinstance(obj, dict):
        return None, f"top-level value is {type(obj).__name__}, expected an object"
    return obj, None


def main() -> None:
    """Command-line entry point: compare v1 files with the original, write fixes and a report."""
    ap = argparse.ArgumentParser(
        description="Fix EntailmentBank v1 input mismatches by replacing with the original when distance exceeds a threshold; report diffs."
    )
    ap.add_argument("--json-blocks", type=str, required=True, help="Directory containing {id}_v1.json files")
    ap.add_argument("--original", type=str, required=True, help="Path to original dataset (JSONL or JSON)")
    ap.add_argument("--output", type=str, required=True, help="Directory to write ONLY corrected v1 JSON files as {id}_v1.json")
    ap.add_argument("--unmatched-report", type=str, required=True, help="Where to write diffs/distances/actions (TXT)")
    ap.add_argument("--case-insensitive", action="store_true", help="Compare case-insensitively")
    ap.add_argument("--diff-context", type=int, default=3, help="Unified diff context lines")
    ap.add_argument("--distance-threshold", type=int, default=2, help="Threshold for input/combined edit distance to trigger replacement")
    args = ap.parse_args()

    blocks_dir = Path(args.json_blocks)
    original_path = Path(args.original)
    output_dir = Path(args.output)
    report_path = Path(args.unmatched_report)

    # Fail loudly on bad inputs: a missing directory would otherwise glob to nothing
    # and produce a misleading "all matched" report.
    if not blocks_dir.is_dir():
        ap.error(f"--json-blocks is not a directory: {blocks_dir}")
    if not original_path.is_file():
        ap.error(f"--original is not a file: {original_path}")

    # Prepare output locations
    output_dir.mkdir(parents=True, exist_ok=True)
    report_path.parent.mkdir(parents=True, exist_ok=True)

    # Load original maps
    orig_norm_map, orig_raw_map = load_original_map(original_path, case_insensitive=args.case_insensitive)
    if not orig_norm_map:
        print(f"Warning: no usable rows in original file: {original_path}")

    matches = 0            # exact (instruction,input) matches (no write)
    total = 0              # v1 files successfully loaded as JSON objects
    corrected_written = 0  # only corrected items are written
    unreadable = 0         # v1 files that failed to load (reported, not counted in total)
    report_entries: List[str] = []

    for pid, file in _collect_v1_files(blocks_dir):
        v1, problem = _load_v1(file)
        if v1 is None:
            unreadable += 1
            report_entries.append("\n".join([
                f"=== ID {pid} — unreadable v1 file ({file.name}) ===",
                f"ERROR: {problem}",
                "ACTION: skipped — not written",
                "",
            ]))
            continue

        total += 1

        ins1_norm = norm_text(v1.get("instruction", ""), case_insensitive=args.case_insensitive)
        inp1_norm = norm_text(v1.get("input", ""), case_insensitive=args.case_insensitive)

        orig_norm = orig_norm_map.get(pid)
        orig_raw = orig_raw_map.get(pid)

        if orig_norm is None or orig_raw is None:
            diff_ins, diff_inp, distances = _diffs_and_distances(ins1_norm, inp1_norm, None, args.diff_context)
            report_entries.append(_format_report_entry(
                f"=== ID {pid} — missing in original ===",
                distances,
                "kept v1 as-is (no original to replace from) — not written",
                diff_ins,
                diff_inp,
            ))
            continue

        if (ins1_norm, inp1_norm) == orig_norm:
            # Perfect match — do not write
            matches += 1
            continue

        diff_ins, diff_inp, distances = _diffs_and_distances(ins1_norm, inp1_norm, orig_norm, args.diff_context)
        _, d_inp, d_comb = distances

        # If input or combined edit distance > threshold => replace input with original's raw input.
        # Only "input" is replaced, even when the combined distance was driven by the instruction.
        if d_inp > args.distance_threshold or d_comb > args.distance_threshold:
            _, orig_inp_raw = orig_raw
            corrected = dict(v1)
            corrected["input"] = orig_inp_raw
            write_v1_json(output_dir, pid, corrected)
            corrected_written += 1
            action = "replaced v1.input with original input — written"
        else:
            action = "ignored mismatch — not written"

        report_entries.append(_format_report_entry(
            f"=== ID {pid} — text mismatch ===",
            distances,
            action,
            diff_ins,
            diff_inp,
        ))

    # Write report
    with report_path.open("w", encoding="utf-8") as fr:
        if report_entries:
            fr.write("\n".join(report_entries).strip() + "\n")
        elif total == 0:
            fr.write(f"No v1 records were found in {blocks_dir}.\n")
        else:
            fr.write("All processed v1 records matched the original (under current normalization).\n")

    print(f"Done. Checked {total} v1 files; wrote {corrected_written} corrected v1 file(s) to: {output_dir}")
    print(f"Exact matches (no write): {matches}")
    if unreadable:
        print(f"Unreadable v1 files (skipped, see report): {unreadable}")
    print(f"Report written to: {report_path}")


if __name__ == "__main__":
    main()
