import json
import argparse
import os


def _iter_jsonl(path):
    """Yield one parsed JSON value per non-blank line of the file at ``path``."""
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            # Tolerate blank lines (e.g. an extra newline at the end of the file)
            # instead of letting json.loads fail on an empty string.
            if line:
                yield json.loads(line)


def _yes_probability(pred, prob):
    """Return the probability of "yes".

    ``prob`` is treated as the probability of the predicted answer, so it is
    used as-is when ``pred`` is "yes" and complemented otherwise.
    """
    return prob if pred == "yes" else 1 - prob


def compare_yes_probs_detailed(baseline_path, attention_path):
    """Count how the "yes" probability moves between two JSONL result files.

    Every non-blank line of both files must be a JSON object with the keys
    ``question_id``, ``pred`` and ``probs``; records of the attention file
    additionally need ``label``. ``label`` and ``pred`` are compared after
    stripping whitespace and lower-casing.

    Records of the attention file are matched to baseline records by
    ``question_id``; attention records without a baseline counterpart are
    skipped. If the baseline file repeats a ``question_id``, the last
    occurrence wins.

    Each matched sample is bucketed by the *baseline* prediction and the label
    found in the attention record:

    * ``yes_correct``   - baseline predicted "yes", label is "yes"
    * ``yes_incorrect`` - baseline predicted "yes", label is "no"
    * ``no_correct``    - baseline predicted "no",  label is "no"
    * ``no_incorrect``  - baseline predicted "no",  label is "yes"

    Samples whose label or baseline prediction is neither "yes" nor "no" are
    not counted.

    Args:
        baseline_path: Path to the baseline JSONL file.
        attention_path: Path to the attention-modified JSONL file.

    Returns:
        A dict with the eight keys ``<bucket>_increase`` / ``<bucket>_decrease``
        mapping to sample counts. A sample whose "yes" probability is unchanged
        (delta of exactly 0) is counted under ``_decrease``.

    Raises:
        OSError: If either file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks one of the required keys.
    """
    # question_id -> (normalised baseline prediction, baseline "yes" probability)
    baseline_data = {}
    for record in _iter_jsonl(baseline_path):
        pred = record["pred"].strip().lower()
        baseline_data[record["question_id"]] = (
            pred,
            _yes_probability(pred, float(record["probs"])),
        )

    # (label, baseline prediction) -> bucket name used as the stats key prefix.
    buckets = {
        ("yes", "yes"): "yes_correct",
        ("no", "yes"): "yes_incorrect",
        ("no", "no"): "no_correct",
        ("yes", "no"): "no_incorrect",
    }
    stats = {
        f"{bucket}_{direction}": 0
        for bucket in ("yes_correct", "yes_incorrect", "no_correct", "no_incorrect")
        for direction in ("increase", "decrease")
    }

    for record in _iter_jsonl(attention_path):
        baseline_record = baseline_data.get(record["question_id"])
        if baseline_record is None:
            continue
        base_pred, base_yes_prob = baseline_record

        # Cast like the baseline side does, so string-encoded probabilities
        # do not break the arithmetic below.
        attention_yes_prob = _yes_probability(
            record["pred"].strip().lower(), float(record["probs"])
        )
        delta = attention_yes_prob - base_yes_prob

        # Bucketing deliberately uses the baseline prediction, not the
        # attention-modified one.
        bucket = buckets.get((record["label"].strip().lower(), base_pred))
        if bucket is None:
            continue

        direction = "increase" if delta > 0 else "decrease"
        stats[f"{bucket}_{direction}"] += 1

    return stats


if __name__ == "__main__":
    # Command-line entry point: both input files are mandatory.
    parser = argparse.ArgumentParser()
    for flag, description in (
        ("--baseline", "Path to the baseline JSONL file"),
        ("--attention", "Path to the attention-modified JSONL file"),
    ):
        parser.add_argument(flag, type=str, required=True, help=description)
    args = parser.parse_args()

    baseline_path, attention_path = args.baseline, args.attention

    # Fail early, baseline first, with a message naming the missing file.
    for kind, path in (("Baseline", baseline_path), ("Attention", attention_path)):
        if os.path.exists(path):
            continue
        raise FileNotFoundError(f"{kind} file not found at: {path}")

    stats = compare_yes_probs_detailed(baseline_path, attention_path)

    # One report section per bucket: a header, then the increase and decrease
    # counts; the trailing "\n" leaves a blank line after each section.
    report_sections = (
        ("【预测 yes 且正确】", "yes_correct"),
        ("【预测 yes 且错误】", "yes_incorrect"),
        ("【预测 no 且正确】", "no_correct"),
        ("【预测 no 且错误】", "no_incorrect"),
    )
    for header, bucket in report_sections:
        increased = stats[bucket + "_increase"]
        decreased = stats[bucket + "_decrease"]
        print(header)
        print(f"- yes 概率增加的样本数: {increased}")
        print(f"- yes 概率减少的样本数: {decreased}\n")
