import os
import json
import argparse
from tqdm import tqdm

def _load_jsonl(path):
    """Read a JSONL file and return its records as a list of dicts.

    ``~`` in *path* is expanded.  Blank / whitespace-only lines (e.g. an extra
    newline at end of file) are skipped instead of crashing ``json.loads``.
    The file handle is closed deterministically via ``with``.
    """
    with open(os.path.expanduser(path), "r", encoding="utf-8") as fin:
        return [json.loads(line) for line in fin if line.strip()]


def _find_changed(pairs):
    """Collect records whose prediction differs between two runs.

    Args:
        pairs: iterable of ``(original_record, highattn_record)`` tuples.
            Each original record must carry ``question_id``, ``text``,
            ``label``, ``pred`` and ``probs``; each high-attention record must
            carry ``question_id``, ``pred`` and ``probs`` and may carry
            ``entropy``.

    Returns:
        A list of dicts, one per pair whose ``pred`` values differ after
        stripping surrounding whitespace and ignoring case.

    Raises:
        ValueError: if a pair's ``question_id`` values do not match, meaning
            the two files are not aligned line-by-line.
    """
    changed = []
    for a, b in pairs:
        # Explicit check (not ``assert``) so it still runs under ``python -O``.
        if a["question_id"] != b["question_id"]:
            raise ValueError(f"ID mismatch: {a['question_id']} vs {b['question_id']}")
        pred_a = a["pred"].strip()
        pred_b = b["pred"].strip()
        # Case-insensitive comparison: "Yes" and "yes" count as the same answer.
        if pred_a.lower() != pred_b.lower():
            changed.append({
                "question_id":      a["question_id"],
                "text":             a["text"],
                "label":            a["label"],
                "pred_original":    pred_a,
                "pred_highattn":    pred_b,
                "probs_original":   a["probs"],
                "probs_highattn":   b["probs"],
                "entropy":          b.get("entropy"),
            })
    return changed


def main():
    """CLI entry point: write the samples whose prediction changed to a JSONL file."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--orig-file",
        type=str,
        required=True,
        help="Path to the original predictions JSONL (e.g., coco-random-llava-7b.jsonl)"
    )
    parser.add_argument(
        "--attn-file",
        type=str,
        required=True,
        help="Path to the high-attention predictions JSONL with entropy (e.g., coco-random-llava-7b-HighAttn.jsonl)"
    )
    parser.add_argument(
        "--output-file",
        type=str,
        default="changed_preds.jsonl",
        help="Output JSONL file for samples whose pred changed"
    )
    args = parser.parse_args()

    # load both files
    orig = _load_jsonl(args.orig_file)
    high = _load_jsonl(args.attn_file)

    # Explicit check (not ``assert``) so misaligned inputs are rejected even under -O.
    if len(orig) != len(high):
        raise ValueError(
            f"Input files must have the same number of lines ({len(orig)} vs {len(high)})"
        )

    changed = _find_changed(
        tqdm(zip(orig, high), total=len(orig), desc="Checking predictions")
    )

    # write out changed cases
    with open(os.path.expanduser(args.output_file), "w", encoding="utf-8") as fout:
        fout.writelines(json.dumps(rec, ensure_ascii=False) + "\n" for rec in changed)

    print(f"Saved {len(changed)} changed-prediction samples to {args.output_file}")


if __name__ == "__main__":
    main()
