import json
import os
import re

# Directory holding the ground-truth dataset files. It is joined with each
# file name via os.path.join, so a relative value like "./" resolves against
# the current working directory, not this script's location.
DATA_DIR = "./"
# Ground-truth files read by load_gt_data. Each must be a JSON list of objects
# with at least "question" and "backward_correct" keys.
SOURCE_FILES = [
    "addsub_50_test.json",
    "aqua_50_test.json",
    "du_50_test.json",
    "gsm8k_50_test.json",
    "lastletter_50_test.json"
]

def load_gt_data(data_dir=None, source_files=None):
    """Build a lookup from question text to its ground-truth backward label.

    Each source file is read as a JSON list of objects carrying at least a
    ``"question"`` and a ``"backward_correct"`` key. Questions are
    whitespace-stripped so they can be matched against stripped questions
    from the model-output file.

    Args:
        data_dir: Directory containing the source files. Defaults to the
            module-level ``DATA_DIR``, looked up at call time.
        source_files: File names to load, relative to ``data_dir``.
            Defaults to the module-level ``SOURCE_FILES``, looked up at
            call time.

    Returns:
        dict mapping stripped question text to its ``backward_correct``
        value. If a question occurs more than once, the last occurrence
        wins; a warning is printed when the duplicated labels disagree.

    Raises:
        OSError: if a source file cannot be opened.
        json.JSONDecodeError: if a source file is not valid JSON.
        KeyError: if an item lacks ``"question"`` or ``"backward_correct"``.
    """
    # Resolve defaults at call time (not definition time) so rebinding the
    # module constants is honoured.
    if data_dir is None:
        data_dir = DATA_DIR
    if source_files is None:
        source_files = SOURCE_FILES

    gt_dict = {}
    for fname in source_files:
        path = os.path.join(data_dir, fname)
        with open(path, "r", encoding="utf-8") as f:
            dataset = json.load(f)

        for item in dataset:
            q = item["question"].strip()
            label = item["backward_correct"]
            # The same question in two datasets would otherwise be
            # overwritten silently; surface it when the labels conflict.
            if q in gt_dict and gt_dict[q] != label:
                print(f"[Warning] Conflicting GT labels (keeping {fname}): {q}")
            gt_dict[q] = label

    return gt_dict


def extract_yes_no(output_str):
    """Return the last standalone "yes"/"no" verdict found in a model response.

    The text is lowercased and scanned for whole-word occurrences of "yes"
    or "no"; the final occurrence is taken as the verdict (presumably
    because models state their conclusion last).

    NOTE(review): incidental phrases such as "there is no error" also match
    "no", so a trailing remark after the real verdict can flip the result.

    Args:
        output_str: Raw model output. ``None`` (e.g. a JSON ``null``) or
            any non-string value is treated as containing no verdict.

    Returns:
        ``"yes"`` or ``"no"`` (always lowercase), or ``None`` if neither
        word appears.
    """
    if not isinstance(output_str, str):
        return None
    matches = re.findall(r"\b(yes|no)\b", output_str.lower())
    return matches[-1] if matches else None


def main(input_path="one_shot_backward_250.json",
         out_path="one_shot_backward_250_acc.json"):
    """Score backward-verification model outputs against ground-truth labels.

    Reads a JSON list of items, each with a ``"question"`` and optionally a
    ``"model_output"``. For every item it parses a yes/no verdict from the
    output, compares it with the ground-truth ``backward_correct`` label
    from :func:`load_gt_data`, and writes the annotated items to
    ``out_path``.

    Each written item gains three keys:
        ``backward_label``  -- ground-truth label from the GT files;
        ``backward_output`` -- True for "yes", False for "no", None if no
                               verdict could be parsed;
        ``match``           -- True only when a verdict was parsed and it
                               equals the label.

    Items whose question is missing from the GT data are reported with a
    warning. They are left out of the output file and out of the accuracy
    denominator.

    Args:
        input_path: JSON file holding the model outputs to score.
        out_path: Destination path for the annotated JSON list.
    """
    gt_dict = load_gt_data()

    with open(input_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    new_data = []
    correct = 0
    skipped = 0
    unparsed = 0

    for item in data:
        q = item["question"].strip()

        if q not in gt_dict:
            print(f"[Warning] Question not found in GT dataset: {q}")
            skipped += 1
            continue

        backward_label = gt_dict[q]

        # `.get(key, "")` still yields None when the key holds a JSON null.
        model_text = item.get("model_output") or ""
        pred = extract_yes_no(model_text)

        if pred is None:
            backward_output = None
            unparsed += 1
        else:
            backward_output = (pred == "yes")

        # An unparseable output must never count as correct, even if the
        # GT label itself happens to be None.
        is_match = backward_output is not None and backward_output == backward_label

        if is_match:
            correct += 1

        item["backward_label"] = backward_label
        item["backward_output"] = backward_output
        item["match"] = is_match

        new_data.append(item)

    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(new_data, f, indent=2, ensure_ascii=False)

    # Denominator is the number of items actually scored (and written out),
    # so skipped items do not silently count as wrong answers.
    total = len(new_data)
    accuracy = correct / total if total > 0 else 0
    print("====================================")
    print(f"Total samples: {total}")
    print(f"Skipped (question not in GT): {skipped}")
    print(f"Unparsed model outputs: {unparsed}")
    print(f"Correct matches: {correct}")
    print(f"Backward Verification Accuracy: {accuracy:.4f}")
    print(f"Saved new JSON → {out_path}")
    print("====================================")


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
