#!/usr/bin/env python3
"""
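Compare two JSONL files by each record's `prompt` field (falling back to `prompt_text`), ignoring
record order. Lines that are not valid JSON are compared as their raw text.

Prints a summary: the number of records in each file and, for every prompt whose occurrence count
differs, how many extra copies A or B contains. If the multisets match but the order differs, it
says so and prints a mapping from A record indices to B record indices (the k-th occurrence of a
prompt in B is matched to its k-th occurrence in A).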

Usage:
 python scripts/gpt2/tools/compare_jsonl_unordered.py --a a.jsonl --b b.jsonl
"""
from collections import Counter, defaultdict
from pathlib import Path
import json
import argparse
from typing import Dict, List


def stream_prompts(path: Path):
    """Yield ``(line_index, prompt)`` for every non-blank line of a JSONL file.

    ``line_index`` is the 0-based physical line number. Blank lines are skipped but still
    counted.

    The prompt comes from the record's ``prompt`` field. If that is missing or null, the
    ``prompt_text`` field is used instead. If both are missing or null, the prompt is ``''``.
    String prompts are whitespace-stripped; any other value is converted with ``str()``.

    Lines that are not valid JSON, or whose JSON value is not an object, are yielded as
    their stripped raw text, so they still take part in the comparison.
    """
    with path.open('r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            line = line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except (ValueError, RecursionError):
                # ValueError covers json.JSONDecodeError. RecursionError covers absurdly
                # deep nesting. Either way, fall back to comparing the raw line.
                yield i, line
                continue
            if not isinstance(obj, dict):
                # A bare string/number/array has no fields; compare it as raw text
                # instead of crashing on .get().
                yield i, line
                continue
            prompt = obj.get('prompt')
            if prompt is None:
                prompt = obj.get('prompt_text')
            if prompt is None:
                # Avoid str(None) == "None" making unrelated prompt-less records look equal
                # to a literal "None" prompt.
                prompt = ''
            yield i, prompt.strip() if isinstance(prompt, str) else str(prompt)


def _excess(counts: Counter, other: Counter) -> List[tuple]:
    """Return ``(prompt, surplus)`` for each prompt occurring more often in *counts* than *other*.

    Counter subtraction keeps only positive differences and iterates the left operand in
    insertion order. The result is therefore ordered by first appearance in *counts*, so
    previews are stable across runs, unlike iterating a set of strings.
    """
    return list((counts - other).items())


def _occurrence_mapping(a_prompts: List[str], b_prompts: List[str]) -> List[tuple]:
    """Return ``(a_index, b_index)`` pairs matching each B record to an A record.

    The k-th occurrence of a prompt in B is paired with its k-th occurrence in A. Indices
    are 0-based positions in the given lists. The caller must ensure both lists hold the
    same multiset of prompts; otherwise an IndexError would occur.
    """
    positions: Dict[str, List[int]] = defaultdict(list)
    for a_idx, prompt in enumerate(a_prompts):
        positions[prompt].append(a_idx)
    used: Dict[str, int] = defaultdict(int)  # occurrences of each prompt consumed so far
    mapping = []
    for b_idx, prompt in enumerate(b_prompts):
        mapping.append((positions[prompt][used[prompt]], b_idx))
        used[prompt] += 1
    return mapping


def main():
    """Compare the prompt multisets of ``--a`` and ``--b`` and print a report to stdout.

    The report always includes record counts. If the multisets differ, it lists prompts
    with surplus occurrences on each side (up to ``--max-preview`` each). If the multisets
    match but the order differs, it prints an A->B record-index mapping. Record indices
    count non-blank lines only.
    """
    parser = argparse.ArgumentParser(
        description="Compare two JSONL files by prompt, ignoring record order.")
    parser.add_argument('--a', type=Path, required=True, help='first JSONL file')
    parser.add_argument('--b', type=Path, required=True, help='second JSONL file')
    parser.add_argument('--max-preview', type=int, default=10,
                        help='maximum number of entries shown in each preview list')
    args = parser.parse_args()
    if args.max_preview < 0:
        # A negative value would slice from the end of the lists and show arbitrary items.
        parser.error('--max-preview must be non-negative')
    preview = args.max_preview

    # Only record order matters below, so the physical line index is discarded.
    a_prompts = [prompt for _, prompt in stream_prompts(args.a)]
    b_prompts = [prompt for _, prompt in stream_prompts(args.b)]

    ca = Counter(a_prompts)
    cb = Counter(b_prompts)
    only_a = _excess(ca, cb)
    only_b = _excess(cb, ca)

    print(f"Total lines: A={len(a_prompts)} B={len(b_prompts)}")
    if only_a or only_b:
        print(f"Prompts only in A (count): {len(only_a)}; only in B (count): {len(only_b)}")
        print("First few only-in-A:")
        for prompt, cnt in only_a[:preview]:
            print(f"  ({cnt}) {prompt}")
        print("First few only-in-B:")
        for prompt, cnt in only_b[:preview]:
            print(f"  ({cnt}) {prompt}")
        return

    print("Multisets of prompts are identical (same elements & counts).")
    if a_prompts == b_prompts:
        print("Files are identical in order.")
        return

    print("Files contain same prompts but order differs.")
    mapping = _occurrence_mapping(a_prompts, b_prompts)
    shown = mapping[:preview]
    print(f"Produced mapping for {len(mapping)} items. First {len(shown)}:")
    for a_idx, b_idx in shown:
        print(f"  A->{a_idx}  maps to B->{b_idx}")


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()


