import json
import re
from tqdm import tqdm

# ========== JSON I/O Utilities ==========

def load_json(path):
    """Read *path* as UTF-8 text and return the decoded JSON value.

    The return type is whatever the document's top-level value is
    (dict, list, str, number, bool or None).
    """
    with open(path, mode="r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def save_json(data, path):
    """Serialise *data* to *path* as UTF-8 JSON with 4-space indentation.

    Non-ASCII characters are written literally rather than as \\u escapes.
    Any existing file at *path* is overwritten.
    """
    dump_options = {"ensure_ascii": False, "indent": 4}
    with open(path, mode="w", encoding="utf-8") as out_file:
        json.dump(data, out_file, **dump_options)


# ========== Reflection Answer Extraction ==========

# Regex (used with re.VERBOSE) for the answer formats found in reflection text.
# It has three top-level alternatives. The [ANSWER...] alternative spans two
# source lines (note the missing "|" after the first), so a parenthesised
# letter such as "(C)" is only recognised when it follows an [ANSWER] tag.
# The trailing \b guards stop a letter from being taken from the start of an
# ordinary word (e.g. "ANSWERED" -> "E", "[ANSWER] Because" -> "B").
ANSWER_PATTERN = r"""
    \*\*\[(?P<star_answer>[A-E])\]                            |  # **[A]
    \[ANSWER(?::\s*(?P<colon_answer>[A-E]))?\]                   # [ANSWER: B] or bare [ANSWER], then...
    \s*\*{0,2}\s*\(?(?:(?P<paren_answer>[A-E])\b)?\)?         |  # ...optional " (C)", " C", or " **(C)"
    ANSWER\s*\(?(?P<bare_answer>[A-E])\b\)?                      # ANSWER (D) or ANSWER D
"""

def extract_reflection_answer(text, pattern=ANSWER_PATTERN):
    """
    Extract the final answer letter (A-E) from reflection output text.

    Every match of ``pattern`` (searched with ``re.VERBOSE``) is scanned, and
    the LAST match in the text decides the result. Within that match, the
    named groups are consulted in the order star_answer, colon_answer,
    paren_answer, bare_answer, and the first non-empty one is returned.

    Args:
        text: Reflection text to search. ``None`` or ``""`` yields ``None``.
        pattern: Regex source string defining the four named groups above.
            It must be a string, not a compiled pattern, because
            ``re.VERBOSE`` is passed as a flag.

    Returns:
        The extracted letter, or ``None`` if nothing matched or the last
        match captured no letter (e.g. a bare "[ANSWER]" tag).
    """
    if not text:
        # Guard against a null/empty reflection: re.finditer raises
        # TypeError on None.
        return None

    last = None
    # Only the final match matters, so keep the most recent one instead of
    # building a list of every match.
    for last in re.finditer(pattern, text, re.VERBOSE):
        pass
    if last is None:
        return None
    return (
        last.group("star_answer")
        or last.group("colon_answer")
        or last.group("paren_answer")
        or last.group("bare_answer")
    )


# ========== Data Filtering ==========

def filter_correct_reflections(dataset, answer_field="correct_answer", reflection_field="reflection"):
    """
    Keep only the items whose reflection answer matches the ground-truth answer.

    For each item, an answer letter is extracted from ``item[reflection_field]``
    with ``extract_reflection_answer``. The item is kept when that letter
    equals ``item[answer_field]`` under plain ``==``, with no case or
    whitespace normalisation. Items from which no letter can be extracted are
    always dropped. This matters when the ground-truth field is also missing,
    because ``None == None`` would otherwise count as a match.

    Args:
        dataset: Iterable of dict-like items.
        answer_field: Key holding the ground-truth answer.
        reflection_field: Key holding the reflection text. A missing or null
            value is treated as empty text.

    Returns:
        A new list containing the matching items themselves (not copies).
        Each kept item is mutated in place: its ``"reflection_answer"`` key is
        set to the extracted letter.
    """
    filtered = []
    for item in tqdm(dataset, desc="Filtering reflections..."):
        # `or ""` also covers a key that is present but null.
        reflection_text = item.get(reflection_field) or ""
        reflection_answer = extract_reflection_answer(reflection_text)
        if reflection_answer is None:
            # Nothing extracted: never count this as correct.
            continue
        if reflection_answer == item.get(answer_field):
            item["reflection_answer"] = reflection_answer
            filtered.append(item)
    return filtered