#!/usr/bin/env python3
"""
Add inference results to the problem data file.

This script will:
1. Read the specified JSON problem file
2. Read the `final_solution_<id>.txt` files in the specified log directory
3. Match problems by ID and add the txt content as a "prediction" field; when no txt file exists for a problem, fall back to the last non-empty "Corrected solution" found in `<id>.log`
4. Save the processed data to the specified output file (placed inside the output directory)
"""

import json
import os
import glob
from pathlib import Path
import re
import argparse

def extract_prediction_from_log(log_file_path):
    """
    Extract the prediction content after the last '>>>>>>> Corrected solution:' in the log file.

    Every '>>>>>>> Corrected solution:' marker in the file is located. Starting
    from the last marker and moving backwards, the candidate solution is the
    text between the marker and the next line starting with '>>>>>>>' (or the
    end of the file). It is stripped and unwrapped from one pair of surrounding
    double quotes. The first candidate that is not blank is returned.

    Args:
        log_file_path: Path to the log file (read as UTF-8)

    Returns:
        str: Extracted prediction content, or None if the file cannot be read,
        contains no marker, or every marked solution is empty
    """
    marker = '>>>>>>> Corrected solution:'
    separator = '\n>>>>>>>'

    # Only the I/O can legitimately fail here; keep the handler narrow so
    # parsing bugs are not misreported as unreadable files.
    try:
        with open(log_file_path, 'r', encoding='utf-8') as f:
            content = f.read()
    except (OSError, UnicodeDecodeError) as e:
        print(f"Error: Cannot read log file {log_file_path} - {e}")
        return None

    # Offsets just past each marker, in file order.
    solution_starts = [m.end() for m in re.finditer(re.escape(marker), content)]
    if not solution_starts:
        print(f"Warning: No '>>>>>>> Corrected solution:' marker found in file {log_file_path}")
        return None

    total = len(solution_starts)
    # Prefer the last marker; only fall back to earlier ones when it is empty.
    for position, start_pos in reversed(list(enumerate(solution_starts, start=1))):
        # Search from the offset instead of slicing, to avoid copying the tail of the file.
        end_pos = content.find(separator, start_pos)
        if end_pos == -1:
            end_pos = len(content)
        prediction_content = content[start_pos:end_pos].strip()

        # Remove surrounding quotes if present
        if prediction_content.startswith('"') and prediction_content.endswith('"'):
            prediction_content = prediction_content[1:-1]

        if prediction_content.strip():
            return prediction_content
        print(f"Found empty solution at position {position}/{total}, continuing backwards...")

    # Every '>>>>>>> Corrected solution:' entry was empty
    print(f"Warning: All '>>>>>>> Corrected solution:' entries are empty in file {log_file_path}")
    return None

def parse_arguments():
    """
    Build the command-line parser and parse sys.argv.

    Returns:
        argparse.Namespace with attributes json_file, log_dir, output and
        output_dir
    """
    usage_examples = """
Example usage:
  python process_predictions.py -j ../problems/HiPhO25/PanPhO_2025.json -l ../logs/gemini-2.5-flash-thinking825/PanPhO_2025 -o PanPhO_2025_with_predictions.json
  python process_predictions.py --json-file ../problems/HiPhO25/APhO_2025.json --log-dir ../logs/claude-sonnet/APhO_2025 --output result.json
        """
    cli = argparse.ArgumentParser(
        description="Add inference results to the problem data file",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )

    # (flags, keyword options) registered in this order so --help output is unchanged.
    option_specs = (
        (('-j', '--json-file'),
         {'required': True, 'help': 'Path to the input JSON problem file (required)'}),
        (('-l', '--log-dir'),
         {'required': True, 'help': 'Path to the log directory containing inference results (required)'}),
        (('-o', '--output'),
         {'default': 'output_with_predictions.json',
          'help': 'Path to the output file (default: output_with_predictions.json)'}),
        (('--output-dir',),
         {'default': 'infer_result_with_predictions',
          'help': 'Output directory (default: infer_result_with_predictions)'}),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)

    return cli.parse_args()

def main():
    """
    Attach predictions to the problems in a JSON file and save the result.

    For each record that is a dict with an 'id' key, the prediction is taken
    from '<log-dir>/final_solution_<id>.txt'. If no such file was read, it
    falls back to extract_prediction_from_log('<log-dir>/<id>.log'). The
    updated data is written to os.path.join(output_dir, output).

    Returns:
        int: 0 on success, 1 when input validation, reading the JSON file,
        or saving the output fails (used as the process exit status)
    """
    # Parse arguments
    args = parse_arguments()

    # Setup paths
    json_file_path = args.json_file
    txt_dir_path = args.log_dir
    output_dir = args.output_dir
    output_file_path = os.path.join(output_dir, args.output)

    # Print config
    print("=== Configuration ===")
    print(f"JSON file path: {json_file_path}")
    print(f"Log directory path: {txt_dir_path}")
    print(f"Output directory: {output_dir}")
    print(f"Output file: {output_file_path}")
    print()

    # Validate inputs
    if not os.path.exists(json_file_path):
        print(f"Error: JSON file does not exist: {json_file_path}")
        return 1

    if not os.path.exists(txt_dir_path):
        print(f"Error: Log directory does not exist: {txt_dir_path}")
        return 1

    if not os.path.isdir(txt_dir_path):
        print(f"Error: Log path is not a directory: {txt_dir_path}")
        return 1

    print(f"Reading JSON file: {json_file_path}")

    # Read JSON data
    try:
        with open(json_file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"Error: JSON file not found {json_file_path}")
        return 1
    except json.JSONDecodeError as e:
        print(f"Error: JSON format error - {e}")
        return 1
    except (OSError, UnicodeDecodeError) as e:
        # e.g. permission denied or a non-UTF-8 file; report instead of crashing.
        print(f"Error: Cannot read JSON file {json_file_path} - {e}")
        return 1
    print(f"Successfully read JSON file, containing {len(data)} records")

    # Get all txt files. Escape the directory so glob metacharacters in its
    # name ('[', '*', '?') are matched literally; sort for a stable order.
    txt_prefix = "final_solution_"
    txt_suffix = ".txt"
    txt_pattern = os.path.join(glob.escape(txt_dir_path), f"{txt_prefix}*{txt_suffix}")
    txt_files = sorted(glob.glob(txt_pattern))
    print(f"Found {len(txt_files)} txt files")

    # Build problem-id (always a str, taken from the filename) -> content mapping
    predictions = {}
    for txt_file in txt_files:
        filename = os.path.basename(txt_file)
        # Defensive re-check of the exact name before slicing out the id.
        if not (filename.startswith(txt_prefix) and filename.endswith(txt_suffix)):
            continue
        problem_id = filename[len(txt_prefix):-len(txt_suffix)]

        try:
            with open(txt_file, 'r', encoding='utf-8') as f:
                content = f.read().strip()
        except (OSError, UnicodeDecodeError) as e:
            print(f"Warning: Cannot read file {txt_file} - {e}")
            continue
        predictions[problem_id] = content
        print(f"Read prediction: {problem_id}")

    print(f"Successfully read {len(predictions)} predictions")

    # Add prediction field to JSON data
    updated_count = 0
    for item in data:
        if not isinstance(item, dict) or 'id' not in item:
            continue

        problem_id = item['id']
        # JSON ids may be numbers, while the mapping keys come from filenames
        # and are strings; compare on the string form so numeric ids match.
        prediction = predictions.get(str(problem_id))
        if prediction is not None:
            item['prediction'] = prediction
            updated_count += 1
            print(f"Added prediction for problem {problem_id}")
            continue

        # If no txt file, try extracting from log file
        log_file_path = os.path.join(txt_dir_path, f"{problem_id}.log")
        if not os.path.exists(log_file_path):
            print(f"Warning: No txt or log file found for problem {problem_id}")
            continue

        print(f"No txt file found, trying log file: {log_file_path}")
        log_prediction = extract_prediction_from_log(log_file_path)
        if log_prediction:
            item['prediction'] = log_prediction
            updated_count += 1
            print(f"Added prediction from log file for problem {problem_id}")
        else:
            print(f"Warning: Cannot extract prediction from log file {log_file_path}")

    print(f"Total predictions added: {updated_count}")

    # Save processed data. Create the directory that actually holds the output
    # file: this covers '-o sub/dir/file.json' and an empty --output-dir (for
    # which os.makedirs('') would raise).
    output_parent = os.path.dirname(output_file_path)
    try:
        if output_parent:
            os.makedirs(output_parent, exist_ok=True)
        with open(output_file_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
    except OSError as e:
        print(f"Error: Cannot save file {output_file_path} - {e}")
        return 1
    print(f"Processed data saved to: {output_file_path}")

    # Print stats
    print("\n=== Processing Completed ===")
    print(f"Original data records: {len(data)}")
    print(f"Prediction files found: {len(predictions)}")
    print(f"Predictions successfully added: {updated_count}")
    print(f"Output file: {output_file_path}")
    return 0

if __name__ == "__main__":
    # Propagate main()'s status so callers/pipelines can detect failures.
    raise SystemExit(main())
