import json
import re
import argparse
from tqdm import tqdm

# Verbose-mode regex that locates a multiple-choice answer letter (A-E).
# Whitespace and `#` comments inside the literal are ignored by re.VERBOSE.
# The literal contains only two `|` separators, so it has three alternatives:
#   1. a bracketed letter such as "[A]"                    -> star_answer
#   2. an "[ANSWER]" or "[ANSWER: B]" tag (-> colon_answer), optionally
#      followed by up to two asterisks, then whitespace, then an optionally
#      parenthesised letter, e.g. "[ANSWER]**(C)" or "[ANSWER] (C)"
#                                                          -> paren_answer
#   3. the word "ANSWER" followed by an optionally parenthesised letter,
#      e.g. "ANSWER (D)"                                   -> bare_answer
# NOTE: the second and third lines of the literal are concatenated (there is
# no `|` between them). Asterisks are matched before whitespace, so a tag
# written as "[ANSWER] **(C)" matches but captures no letter.
pattern = r"""
    \[\s*(?P<star_answer>[A-E])\s*\]                            |  # [A]
    \[ANSWER(?::\s*(?P<colon_answer>[A-E]))?\]                 # [ANSWER: B] or [ANSWER]
    \*{0,2}\s*\(?(?P<paren_answer>[A-E])?\)?                   |  # **(C), *(D)
    ANSWER\s*\(?(?P<bare_answer>[A-E])\)?                         # ANSWER (D)
"""

# Compiled once at import time so each call reuses the same pattern object.
_ANSWER_RE = re.compile(pattern, re.VERBOSE)

# Named groups in the order they are consulted when reading a match.
_ANSWER_GROUPS = ("star_answer", "colon_answer", "paren_answer", "bare_answer")


def extract_answer(text):
    """Return the answer letter captured by the last pattern match in *text*.

    Args:
        text: The string to search. ``None`` is accepted and treated as
            "no answer" so that missing/null fields do not raise.

    Returns:
        The captured letter ("A"-"E") of the LAST match found in *text*, or
        ``None`` when *text* is ``None``, when nothing matches, or when the
        last match captured no letter (for example a trailing bare
        "[ANSWER]" tag) -- earlier matches are not consulted in that case.
    """
    if text is None:
        return None

    # Only the final match matters; walk the iterator and keep the last one.
    last = None
    for last in _ANSWER_RE.finditer(text):
        pass
    if last is None:
        return None

    # A group is either None or a single letter, so the first truthy group
    # is the captured answer.
    return next(
        (last.group(name) for name in _ANSWER_GROUPS if last.group(name)),
        None,
    )

def main():
    """Score extracted answers from an inference-results JSON file.

    Command-line arguments:
        --input    Path to a JSON file holding a sequence of dict-like items
                   (required).
        --key      Name of the item field whose output is scored; defaults
                   to "with_steering_output_1.0".
        --verbose  Additionally print the baseline correct count and the
                   correct/wrong transition counts.

    For every item, an answer letter is extracted from the
    "baseline_output" field and from the ``--key`` field, and each is
    compared with the item's "answerKey". An output from which no answer
    could be extracted is never counted as correct.

    Prints the total number of items, the number whose ``--key`` answer is
    correct, and the resulting accuracy (0.00% for an empty dataset).
    """
    parser = argparse.ArgumentParser(description="Evaluate reflection accuracy from inference results.")
    parser.add_argument("--input", type=str, required=True, help="Path to input JSON file (inference results).")
    parser.add_argument(
        "--key",
        type=str,
        default="with_steering_output_1.0",
        help="Name of the per-item field holding the output to score.",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Also print baseline and transition counts.",
    )
    args = parser.parse_args()

    with open(args.input, "r", encoding="utf-8") as f:
        dataset = json.load(f)

    key_name = args.key
    initial_correct = 0
    final_correct = 0
    correct_to_wrong = 0
    wrong_to_correct = 0

    for item in tqdm(dataset, desc="Processing..."):
        correct_answer = item.get("answerKey")
        # `or ""` also covers fields that are present but null in the JSON,
        # for which `.get(key, "")` would hand None to the extractor.
        answer_1 = extract_answer(item.get("baseline_output") or "")
        answer_2 = extract_answer(item.get(key_name) or "")

        # An unparsed answer (None) must not count as correct; without the
        # explicit check, a missing "answerKey" would make None == None pass.
        was_correct = answer_1 is not None and answer_1 == correct_answer
        is_correct = answer_2 is not None and answer_2 == correct_answer

        initial_correct += was_correct
        final_correct += is_correct

        if was_correct and not is_correct:
            correct_to_wrong += 1
        elif is_correct and not was_correct:
            wrong_to_correct += 1

    total = len(dataset)
    # Guard the division so an empty input file reports 0% instead of raising.
    accuracy = final_correct / total if total else 0.0

    print(f"Total samples: {total}")
    if args.verbose:
        print(f"No reflection correct count: {initial_correct}")
    print(f"With reflection correct count: {final_correct}")
    print(f"Accuracy after reflection: {accuracy:.2%}")
    if args.verbose:
        print(f"Correct → Wrong: {correct_to_wrong}")
        print(f"Wrong → Correct: {wrong_to_correct}")

# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()