#!/usr/bin/env python3

import json
import sys
from typing import Optional

def last_boxed_only_string(string: str) -> Optional[str]:
    r"""Return the last ``\boxed``/``\fbox`` expression contained in *string*.

    Two notations are recognised:

    * Space form (``\boxed 42``): when the text contains ``\boxed `` anywhere,
      the result is ``\boxed `` followed by everything after its last
      occurrence, cut at the first ``$`` (if any).  This form is checked first.
    * Brace form (``\boxed{...}`` or ``\fbox{...}``): scanning starts at the
      last ``\boxed`` (or, failing that, the last ``\fbox``) and stops at the
      ``}`` that brings the brace depth back to zero.

    Returns ``None`` when no marker is present or the braces never balance.
    """
    spaced_marker = "\\boxed "
    if spaced_marker in string:
        tail = string.split(spaced_marker)[-1]
        return spaced_marker + tail.split("$")[0]

    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return None

    # Walk forward from the marker, tracking brace nesting, until the brace
    # that closes the outermost group is reached.
    depth = 0
    for pos in range(start, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[start : pos + 1]

    return None

def remove_boxed(s: str) -> str:
    r"""Strip the ``\boxed``/``\fbox`` wrapper and return the inner content.

    Accepted shapes:

    * ``\boxed <content>`` (space form) -> ``<content>``
    * ``\boxed{<content>}`` -> ``<content>``
    * ``\fbox{<content>}`` -> ``<content>``

    Raises:
        AssertionError: if *s* has none of the shapes above.  The exception is
            raised explicitly rather than via ``assert`` so the validation
            still runs when Python is started with ``-O``.
    """
    spaced_prefix = "\\boxed "
    if s.startswith(spaced_prefix):
        return s[len(spaced_prefix):]

    # Brace-delimited forms: the wrapper is the prefix plus one closing brace.
    for prefix in ("\\boxed{", "\\fbox{"):
        if s.startswith(prefix) and s.endswith("}"):
            return s[len(prefix):-1]

    raise AssertionError(f"not a boxed expression: {s!r}")

def extract_answer_from_solution(solution: str) -> Optional[str]:
    """Return the content of the last boxed expression in *solution*.

    Combines ``last_boxed_only_string`` and ``remove_boxed``.  Returns
    ``None`` when *solution* is not a string, contains no boxed expression,
    or the boxed expression cannot be unwrapped.
    """
    # Non-string input (e.g. a JSON null) simply has no extractable answer.
    if not isinstance(solution, str):
        return None

    boxed = last_boxed_only_string(solution)
    if boxed is None:
        return None

    try:
        return remove_boxed(boxed)
    except AssertionError:
        # remove_boxed rejects wrappers it does not recognise.
        return None

def process_hendrycks_math_dataset(input_file: str, output_file: str) -> None:
    """Add an ``answer`` field to every entry of a JSONL dataset.

    Reads *input_file* one JSON object per line, extracts the answer from
    each entry's ``solution`` field via ``extract_answer_from_solution`` and
    writes the entries for which an answer was found to *output_file*, again
    one JSON object per line.  Lines that are blank are skipped; lines that
    fail to parse, lack a ``solution`` field, or yield no answer are reported
    on stdout and left out of the output.  Progress, summary statistics and up
    to three sample entries are printed as well.

    Args:
        input_file: Path of the JSONL file to read (UTF-8).
        output_file: Path of the JSONL file to write (UTF-8).

    Raises:
        OSError: if either file cannot be opened (includes
            ``FileNotFoundError`` for a missing input file).
    """
    print(f"Loading dataset from {input_file}...")

    processed_entries = []
    total_entries = 0
    successful_extractions = 0

    with open(input_file, 'r', encoding='utf-8') as f:
        for line_num, raw_line in enumerate(f, 1):
            line = raw_line.strip()
            if not line:
                # Blank lines (such as a trailing newline) hold no entry.
                continue

            try:
                entry = json.loads(line)
                total_entries += 1

                # Explicit raises instead of ``assert`` so the checks are not
                # stripped when Python runs with -O.
                if 'solution' not in entry:
                    raise ValueError(f"Entry {line_num} has no 'solution' field")

                # Extract answer from solution field
                answer = extract_answer_from_solution(entry['solution'])
                if answer is None:
                    raise ValueError(f"Entry {line_num} has no 'answer' field")

                entry['answer'] = answer
                successful_extractions += 1
                processed_entries.append(entry)

                # Progress update every 1000 entries
                if total_entries % 1000 == 0:
                    print(f"Processed {total_entries} entries...")

            except json.JSONDecodeError as e:
                print(f"Error parsing JSON at line {line_num}: {e}")
                continue
            except Exception as e:
                # Best effort: report the bad entry and keep going.
                print(f"Error processing entry at line {line_num}: {e}")
                continue

    print(f"Total entries processed: {total_entries}")
    print(f"Successful answer extractions: {successful_extractions}")
    if total_entries:
        print(f"Success rate: {successful_extractions/total_entries*100:.2f}%")
    else:
        # Avoid dividing by zero when the input held no parsable entries.
        print("Success rate: n/a (no entries)")

    # Save processed dataset
    print(f"Saving processed dataset to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as f:
        for entry in processed_entries:
            json.dump(entry, f, ensure_ascii=False)
            f.write('\n')

    print("Dataset processing completed successfully!")

    # Show a few examples
    print("\nSample processed entries:")
    for i, entry in enumerate(processed_entries[:3]):
        print(f"\nEntry {i+1}:")
        # str() guards the slice against non-string field values.
        print(f"  Problem: {str(entry.get('problem', 'N/A'))[:100]}...")
        print(f"  Solution: {str(entry.get('solution', 'N/A'))[:100]}...")
        print(f"  Extracted Answer: {entry.get('answer', 'N/A')}")

def main():
    """Convert the bundled Hendrycks MATH test split, exiting with 1 on failure.

    Runs ``process_hendrycks_math_dataset`` on fixed paths under
    ``datasets/``.  Any error is reported on stdout and the process exits
    with status 1.
    """
    src_path = "datasets/hendrycks_math_test_all.jsonl"
    dst_path = "datasets/hendrycks_math_test_all_with_answers.jsonl"

    # Collect the failure report first so there is a single exit point.
    failure_report = None
    try:
        process_hendrycks_math_dataset(src_path, dst_path)
    except FileNotFoundError:
        failure_report = (
            f"Error: Input file '{src_path}' not found.",
            "Please make sure the file exists in the datasets directory.",
        )
    except Exception as err:
        failure_report = (f"Error processing dataset: {err}",)

    if failure_report is not None:
        for message in failure_report:
            print(message)
        sys.exit(1)

# Run the conversion only when this file is executed as a script, so that
# importing the module (e.g. to reuse the extraction helpers) has no side effects.
if __name__ == "__main__":
    main()