#!/usr/bin/env python3
"""
Script to fix the answer field in the SmolTraces-R1 dataset by stripping the \boxed{} notation.
"""

import os
import json
import re
import logging
import sys
import argparse

# Emit INFO-and-above records to both the console (stdout) and a log file
# written to the current working directory.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    handlers=[
        logging.StreamHandler(stream=sys.stdout),
        logging.FileHandler(filename="fix_boxed_answers.log"),
    ],
)

def extract_boxed_content(text):
    r"""
    Extract the content from inside \boxed{} notation.

    The text is scanned for ``\boxed{`` and the matching closing brace is
    located by counting brace depth, so arbitrarily nested groups such as
    ``\boxed{\frac{\sqrt{2}}{2}}`` are returned intact. A backslash escapes
    the character after it, so ``\{`` and ``\}`` do not affect the depth.

    Args:
        text: The text containing boxed notation

    Returns:
        The content of the last balanced \boxed{...} group (typically the
        final answer). If there is no balanced group, or ``text`` is not a
        string, ``text`` is returned unchanged.
    """
    if not isinstance(text, str):
        return text

    marker = "\\boxed{"
    length = len(text)
    last_content = None
    search_from = 0

    while True:
        start = text.find(marker, search_from)
        if start == -1:
            break

        content_start = start + len(marker)
        depth = 1
        i = content_start
        while i < length:
            ch = text[i]
            if ch == "\\":
                # Skip the escaped character so "\{" / "\}" are not counted.
                i += 2
                continue
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    break
            i += 1

        if depth == 0:
            last_content = text[content_start:i]
            search_from = i + 1
        else:
            # Unbalanced group: ignore it, but keep looking for a later
            # (possibly nested) \boxed{...} that is well-formed.
            search_from = content_start

    # Distinguish "no box found" (None) from an empty box ("").
    return text if last_content is None else last_content

def fix_samples(input_dir, output_dir=None):
    """
    Process all sample files in the input directory and fix the answer field.

    Handles the per-sample files ``sample_*.json`` as well as the combined
    ``samples.json`` (a JSON array) and ``samples.jsonl`` (one JSON object per
    line) when they exist. Every processed file is written to ``output_dir``
    under the same name. Failures are logged and the remaining files are
    still processed.

    Args:
        input_dir: Directory containing the sample files
        output_dir: Directory to save fixed files (defaults to input_dir)
    """
    if output_dir is None:
        output_dir = input_dir

    os.makedirs(output_dir, exist_ok=True)

    # Sorted so that processing order (and the log) is deterministic.
    sample_files = sorted(
        name for name in os.listdir(input_dir)
        if name.startswith("sample_") and name.endswith(".json")
    )
    for filename in sample_files:
        try:
            _fix_single_sample_file(
                os.path.join(input_dir, filename),
                os.path.join(output_dir, filename),
                filename,
            )
        except Exception as e:
            # Best-effort batch job: report the failure and move on.
            logging.error("Error processing %s: %s", filename, e)

    combined_handlers = {
        "samples.json": _fix_combined_json_file,
        "samples.jsonl": _fix_combined_jsonl_file,
    }
    for filename, handler in combined_handlers.items():
        input_path = os.path.join(input_dir, filename)
        if not os.path.exists(input_path):
            continue
        try:
            handler(input_path, os.path.join(output_dir, filename), filename)
        except Exception as e:
            logging.error("Error processing %s: %s", filename, e)

    logging.info("Finished fixing samples in %s", input_dir)


def _strip_boxed_answer(sample, label):
    """Strip boxed notation from sample["answer"] in place; return True if it changed."""
    if not isinstance(sample, dict):
        return False
    answer = sample.get("answer")
    if not isinstance(answer, str) or "\\boxed" not in answer:
        return False

    fixed = extract_boxed_content(answer)
    if fixed == answer:
        # Nothing could be extracted; do not claim the answer was fixed.
        logging.warning("Could not extract boxed content in %s: '%s'", label, answer)
        return False

    sample["answer"] = fixed
    logging.info("Fixed answer in %s: '%s' -> '%s'", label, answer, fixed)
    return True


def _fix_single_sample_file(input_path, output_path, filename):
    """Fix one per-sample JSON file and write the result to output_path."""
    with open(input_path, "r", encoding="utf-8") as f:
        sample = json.load(f)

    _strip_boxed_answer(sample, filename)

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(sample, f, indent=2)


def _fix_combined_json_file(input_path, output_path, filename):
    """Fix every sample in a JSON-array file and write the result to output_path."""
    with open(input_path, "r", encoding="utf-8") as f:
        samples = json.load(f)

    # Iterating a dict would silently replace the samples with its keys.
    if not isinstance(samples, list):
        raise ValueError(
            f"expected a JSON array of samples, got {type(samples).__name__}"
        )

    for i, sample in enumerate(samples):
        _strip_boxed_answer(sample, f"samples.json[{i}]")

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(samples, f, indent=2)


def _fix_combined_jsonl_file(input_path, output_path, filename):
    """Fix every record in a JSON Lines file and write the result to output_path."""
    fixed_lines = []
    # Read everything before writing: input and output may be the same file.
    with open(input_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            raw = line.strip()
            if not raw:
                # Blank lines carry no record; drop them instead of logging an error.
                continue
            try:
                sample = json.loads(raw)
            except ValueError as e:
                # Keep unparseable lines verbatim so no data is lost.
                logging.error("Error processing line %s in %s: %s", i, filename, e)
                fixed_lines.append(raw)
                continue
            _strip_boxed_answer(sample, f"samples.jsonl[{i}]")
            fixed_lines.append(json.dumps(sample))

    with open(output_path, "w", encoding="utf-8") as f:
        # Terminate every record with a newline, including the last one.
        f.writelines(record + "\n" for record in fixed_lines)

def main():
    """Parse the command-line options and run the fixer on the chosen directories."""
    cli = argparse.ArgumentParser(
        description="Fix answer field in SmolTraces-R1 dataset",
    )
    cli.add_argument(
        "--input_dir",
        default="datasets/SmolTraces-R1",
        help="Directory containing sample files",
    )
    cli.add_argument(
        "--output_dir",
        help="Directory to save fixed files (defaults to input_dir)",
    )
    options = cli.parse_args()

    # output_dir is None when the flag is omitted; fix_samples receives it as-is.
    fix_samples(options.input_dir, options.output_dir)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main() 