#!/usr/bin/env python3
"""
Compare two JSONL files for equality: byte-for-byte first, then line-by-line with each line JSON-normalized (sorted keys, surrounding whitespace ignored) when it parses as JSON.

Usage:
  python scripts/gpt2/tools/compare_jsonl.py --a /path/to/a.jsonl --b /path/to/b.jsonl

Prints a short report and exits with code 0 if the files are identical (byte-for-byte or after normalization), 1 otherwise.
"""
import argparse
import json
import hashlib
from itertools import zip_longest
from pathlib import Path
from typing import Optional


def md5_of_file(path: Path, chunk_size: int = 1 << 20) -> str:
    """Return the hex MD5 digest of the file at *path*.

    The file is opened in binary mode and fed to the hash in pieces of
    *chunk_size* bytes, so arbitrarily large files use bounded memory.
    """
    digest = hashlib.md5()
    with path.open('rb') as fh:
        # iter() with a b'' sentinel stops at EOF, when read() returns empty bytes.
        for block in iter(lambda: fh.read(chunk_size), b''):
            digest.update(block)
    return digest.hexdigest()


def normalize_line(line: str) -> str:
    """Return a canonical form of one JSONL line for comparison.

    Surrounding whitespace (including the trailing newline) is stripped, and
    a blank line normalizes to ''. If the stripped text parses as JSON, it is
    re-serialized with sorted keys and non-ASCII characters left unescaped, so
    key order and insignificant whitespace do not cause spurious differences.
    Text that is not valid JSON is returned stripped but otherwise unchanged.

    Always returns a str (never None).
    """
    s = line.strip()
    if not s:
        return ''
    try:
        obj = json.loads(s)
        # canonical JSON string
        return json.dumps(obj, sort_keys=True, ensure_ascii=False)
    except (ValueError, RecursionError):
        # ValueError covers json.JSONDecodeError (malformed JSON) and integer
        # digit-limit errors; RecursionError covers pathologically deep nesting.
        # In either case fall back to comparing the raw stripped text.
        return s


def compare(a_path: Path, b_path: Path, max_diffs: int = 10) -> int:
    """Compare two JSONL files and print a short report.

    The MD5 digests are checked first; byte-identical files short-circuit
    with 0. Otherwise both files are streamed in lockstep and each pair of
    lines is compared after ``normalize_line`` (surrounding whitespace
    ignored, JSON lines re-serialized with sorted keys). A line present in
    only one file always counts as a difference.

    Args:
        a_path: First file (read as UTF-8 text).
        b_path: Second file (read as UTF-8 text).
        max_diffs: Maximum number of differing lines kept in memory and
            printed. Every differing line is still counted.

    Returns:
        0 if the files match (byte-identical or equal after normalization),
        1 if they differ or either file does not exist.
    """
    print(f"Comparing:\n  A: {a_path}\n  B: {b_path}\n")

    if not a_path.exists() or not b_path.exists():
        print("One or both files do not exist.")
        return 1

    a_md5 = md5_of_file(a_path)
    b_md5 = md5_of_file(b_path)
    print(f"MD5 A: {a_md5}\nMD5 B: {b_md5}")
    if a_md5 == b_md5:
        print("Files are byte-identical.")
        return 0

    # Fall back to a single streaming pass with JSON normalization. Only the
    # first `max_diffs` differences are stored; all of them are counted, and
    # line counts are gathered in the same pass (no second read of the files).
    diffs = []
    total = 0
    a_lines = 0
    b_lines = 0
    with a_path.open('r', encoding='utf-8') as fa, b_path.open('r', encoding='utf-8') as fb:
        for i, (la, lb) in enumerate(zip_longest(fa, fb, fillvalue=None), start=1):
            if la is not None:
                a_lines += 1
            if lb is not None:
                b_lines += 1
            # Normalize whichever side is present so diffs display consistently.
            na = None if la is None else normalize_line(la)
            nb = None if lb is None else normalize_line(lb)
            # A missing line is always a difference, even when the other side is
            # blank (which would otherwise normalize to '').
            if la is None or lb is None or na != nb:
                total += 1
                if len(diffs) < max_diffs:
                    diffs.append((i, na, nb))

    print(f"Lines: A={a_lines}, B={b_lines}")
    # The count is exact; the "(approx)" label is kept for output compatibility.
    print(f"Total differing lines (approx): {total}")

    if total > 0:
        print('\nFirst diffs (up to {})'.format(max_diffs))
        for idx, va, vb in diffs:
            print(f"--- Line {idx} ---")
            print("A:", va)
            print("B:", vb)
        return 1

    print("No differences found (after JSON-normalized line compare).")
    return 0


def main():
    """Parse the command line, run `compare`, and exit with its return code."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--a', type=Path, required=True)
    parser.add_argument('--b', type=Path, required=True)
    parser.add_argument('--max-diffs', type=int, default=10)
    ns = parser.parse_args()
    # compare() returns 0 on match and 1 otherwise; SystemExit turns that into
    # the process exit status.
    raise SystemExit(compare(ns.a, ns.b, ns.max_diffs))


if __name__ == '__main__':
    # Run only when executed as a script, not when imported as a module.
    main()


