#!/usr/bin/env python3
"""
Script to add \boxed{} formatting to final answers in physics reasoning outputs.
Extracts the final answer(s) from `prediction` and formats them with \boxed{}.
"""

import json
import os
import logging
from typing import Dict, List, Any
import openai
from openai import OpenAI
import time
import requests
from dotenv import load_dotenv
import argparse

# Configure root logging once at import time (INFO level, timestamped lines)
# and create the module-level logger used by every function below.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


def create_answer_extraction_prompt(context: str, question: str, prediction: str) -> str:
    r"""
    Build the instruction prompt that asks the model to box the final answer(s).

    The prompt embeds the problem context, the question and the full solution
    text, followed by a fixed list of requirements and steps telling the model
    to return the solution unchanged except for \boxed{} around final answers.

    Args:
        context: Background text of the physics problem.
        question: The specific question being answered.
        prediction: The solution text that should receive \boxed{} markers.

    Returns:
        The complete prompt as a single string.
    """
    boxed = "\\boxed{}"

    intro = (
        "You are a physics problem answer extraction expert. "
        f"Your task is to add {boxed} formatting to the final answers in the given solution "
        "while keeping the complete solution process intact."
    )

    inputs = [
        f"**Physics Problem Context:** {context}",
        f"**Specific Question:** {question}",
        f"**Solution Process:** {prediction}",
    ]

    requirements = "\n".join([
        "**Requirements:**",
        "1. Keep the ENTIRE solution process exactly as provided",
        "2. Identify where the final answer(s) appear in the solution",
        f"3. Add {boxed} formatting around the final answer(s) ONLY",
        "4. Do NOT remove any content from the original solution",
        "5. Do NOT add extra explanations or modify the solution steps",
        f"6. If there are multiple sub-questions, add {boxed} to each final answer",
        "7. Maintain the accuracy of mathematical expressions and original formatting",
    ])

    steps = "\n".join([
        "**What to do:**",
        "- Read through the complete solution",
        "- Identify the final numerical values, expressions, or conclusions that answer the question",
        f"- Add {boxed} around those final answers while keeping everything else unchanged",
        f"- Output the complete solution with {boxed} added to the appropriate places",
    ])

    closing = f"Please provide the complete solution with {boxed} formatting added to the final answers:"

    # Sections are separated by exactly one blank line.
    return "\n\n".join([intro, *inputs, requirements, steps, closing])


def call_openai_api(prompt: str, api_key: str, base_url: str = None, model: str = "gpt-4", max_retries: int = 3) -> str:
    """
    Send `prompt` as a single user message to an OpenAI-compatible chat endpoint.

    The request is posted directly with `requests`. Failed attempts are retried
    with exponential backoff (1s, 2s, 4s, ...).

    Args:
        prompt: Text sent as the only user message.
        api_key: Bearer token placed in the Authorization header.
        base_url: Optional API root. It may be a bare host, a URL ending in
            `/v1`, or the full `/chat/completions` URL; trailing slashes are
            ignored. When empty, the public OpenAI endpoint is used.
        model: Model name placed in the request payload.
        max_retries: Maximum number of attempts before giving up.

    Returns:
        The stripped text of the first choice's message, or an empty string
        when `max_retries` is zero or negative (no request is made).

    Raises:
        Exception: Whatever the last attempt raised (HTTP error, timeout,
            malformed response body) once all attempts are exhausted.
    """
    # Strip trailing slashes before inspecting the suffix; otherwise a value
    # such as "https://host/v1/" would miss both checks below and end up as
    # "https://host/v1//v1/chat/completions".
    normalized_base = base_url.rstrip('/') if base_url else ""

    if not normalized_base:
        api_url = "https://api.openai.com/v1/chat/completions"
    elif normalized_base.endswith('/chat/completions'):
        api_url = normalized_base
    elif normalized_base.endswith('/v1'):
        api_url = f"{normalized_base}/chat/completions"
    else:
        api_url = f"{normalized_base}/v1/chat/completions"

    logger.info(f"API URL: {api_url}")

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }

    # NOTE(review): the OpenAI Chat Completions API documents `max_tokens` /
    # `max_completion_tokens` as the output limit; `max_output_tokens` looks
    # like the Responses API name. Left unchanged because the target proxy may
    # rely on it - confirm the endpoint actually honours this key.
    payload = {
        "model": model,
        "messages": [
            {"role": "user", "content": prompt}
        ],
        "temperature": 0.1,
        "max_output_tokens": 16384
    }

    for attempt in range(max_retries):
        try:
            response = requests.post(api_url, headers=headers, json=payload, timeout=120)
            response.raise_for_status()

            result = response.json()
            return result["choices"][0]["message"]["content"].strip()

        except Exception as e:
            logger.warning(f"API call failed, attempt {attempt + 1}: {str(e)}")
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)  # Exponential backoff
            else:
                # Out of attempts: re-raise with the original traceback intact.
                raise

    # Only reached when the loop body never ran (max_retries <= 0).
    return ""


def process_prediction(item: Dict[str, Any], api_key: str, base_url: str = None, model: str = "gpt-4") -> Dict[str, Any]:
    r"""
    Return a copy of `item` whose 'prediction' has \boxed{} around the final answer(s).

    The entry's 'context', 'question' and 'prediction' values are turned into a
    prompt and sent to the API. The input dict is never modified.

    Args:
        item: One dataset entry; missing fields default to an empty string.
        api_key: API key forwarded to `call_openai_api`.
        base_url: Optional API base URL forwarded to `call_openai_api`.
        model: Model name forwarded to `call_openai_api`.

    Returns:
        A shallow copy of `item` with the reformatted 'prediction'. The original
        `item` is returned instead when 'prediction' is empty or when any step
        fails (the failure is logged, not raised). If the API returns an empty
        string, the copy keeps the original prediction text.
    """
    try:
        context, question, prediction = (
            item.get(field, '') for field in ('context', 'question', 'prediction')
        )

        # Nothing to reformat: hand the entry back untouched.
        if not prediction:
            logger.warning(f"Item {item.get('id', 'unknown')} has no 'prediction' field")
            return item

        formatted_answer = call_openai_api(
            create_answer_extraction_prompt(context, question, prediction),
            api_key,
            base_url,
            model,
        )

        # Fall back to the original text when the model returned nothing.
        new_item = item.copy()
        new_item.update(prediction=formatted_answer or prediction)

        logger.info(f"Successfully processed item: {item.get('id', 'unknown')}")
        return new_item

    except Exception as e:
        logger.error(f"Error processing item {item.get('id', 'unknown')}: {str(e)}")
        return item


def parse_arguments():
    """
    Parse command-line arguments.

    Returns:
        An `argparse.Namespace` with the attributes `input`, `output_dir`,
        `output_file`, `model`, `api_key`, `base_url` and `env_file`.
        `output_file`, `model`, `api_key` and `base_url` are None when the
        option is omitted, so the caller can fall back to environment
        variables.
    """
    parser = argparse.ArgumentParser(
        description="Add \\boxed{} formatting to final answers in physics reasoning results.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""Examples:
  python add_boxed_answers.py -i predictions.json -o result.json
  python add_boxed_answers.py --input infer_result_with_predictions/PanPhO_2025_with_predictions.json --output-dir boxed_results --output-file final.json
  python add_boxed_answers.py -i data.json -o output.json --model gpt-4 --api-key your_key --base-url https://api.example.com
        """
    )

    parser.add_argument(
        '-i', '--input',
        required=True,
        help='Path to the input JSON file (must contain "prediction" fields) (required)'
    )

    parser.add_argument(
        '--output-dir',
        default='infer_result_with_predictions_boxed',
        help='Output directory (default: infer_result_with_predictions_boxed)'
    )

    parser.add_argument(
        '-o', '--output-file',
        help='Output file name (default: generated based on input file name)'
    )

    # No argparse-level default here: a hard default would always be truthy and
    # make the OPENAI_MODEL fallback in main() unreachable. main() supplies
    # 'gemini-2.5-flash' when neither the flag nor the variable is set.
    parser.add_argument(
        '--model',
        default=None,
        help='Model to use (if not provided, will read from environment variable OPENAI_MODEL, or use gemini-2.5-flash)'
    )

    parser.add_argument(
        '--api-key',
        help='API key (if not provided, will read from environment variable OPENAI_API_KEY)'
    )

    parser.add_argument(
        '--base-url',
        help='API Base URL (if not provided, will read from environment variable OPENAI_API_BASE, or use default)'
    )

    parser.add_argument(
        '--env-file',
        default='.env',
        help='Path to the environment variables file (default: .env)'
    )

    return parser.parse_args()


def main():
    """
    Main entry point: parse arguments, box every prediction, save the result.

    Steps:
      1. Parse CLI arguments and load the env file named by `--env-file`.
      2. Resolve the output path: `--output-file` joined to `--output-dir`, or
         `<input name>_boxed<ext>` when no file name is given.
      3. Resolve API key, base URL and model from the arguments, falling back
         to OPENAI_API_KEY, OPENAI_API_BASE and OPENAI_MODEL.
      4. Pass every dict entry containing a 'prediction' key through
         `process_prediction`; copy every other entry through unchanged.
      5. Write a checkpoint (`<output>.tmp_<n>`) after a processed entry whose
         1-based position is a multiple of 10, then write the final file and
         delete the checkpoints.

    A missing input file, a missing API key, or a JSON document whose top level
    is not a list is logged as an error and ends the run early without raising.
    """
    # Parse command-line arguments
    args = parse_arguments()

    # Load .env file
    load_dotenv(args.env_file)

    # Configure input file
    input_file = args.input

    # Configure output file paths
    output_dir = args.output_dir
    if args.output_file:
        output_file = os.path.join(output_dir, args.output_file)
    else:
        # Generate output file name based on input file name
        input_name, input_ext = os.path.splitext(os.path.basename(input_file))
        output_file = os.path.join(output_dir, f"{input_name}_boxed{input_ext}")

    # --output-file may contain sub-directories or be absolute (os.path.join
    # then discards output_dir), so the directory that actually receives the
    # file and its checkpoints is derived from the final path.
    target_dir = os.path.dirname(output_file) or '.'

    # Get API configuration from args or environment
    api_key = args.api_key or os.getenv('OPENAI_API_KEY')
    base_url = args.base_url or os.getenv('OPENAI_API_BASE')  # Optional; if not set, default is used
    model = args.model or os.getenv('OPENAI_MODEL', 'gemini-2.5-flash')

    # Print configuration
    print("=== Configuration ===")
    print(f"Input file: {input_file}")
    print(f"Output directory: {output_dir}")
    print(f"Output file: {output_file}")
    print(f"Model: {model}")
    if base_url:
        print(f"Base URL: {base_url}")
    print(f"Env file: {args.env_file}")
    print()

    # Validate input file existence
    if not os.path.exists(input_file):
        logger.error(f"Input file does not exist: {input_file}")
        return

    if not api_key:
        logger.error("Please provide an API key via --api-key or set OPENAI_API_KEY in the .env file")
        return

    # Ensure the directory that will hold the output file exists
    os.makedirs(target_dir, exist_ok=True)

    # Read input file
    logger.info(f"Reading input file: {input_file}")
    with open(input_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Iterating a JSON object would walk its keys and write a list of key
    # strings, silently dropping the real content - refuse instead.
    if not isinstance(data, list):
        logger.error(f"Input file must contain a JSON array of entries, got {type(data).__name__}: {input_file}")
        return

    # Process data
    total = len(data)
    logger.info(f"Starting processing of {total} entries")
    processed_data = []
    previous_checkpoint = None

    for i, item in enumerate(data):
        # Entries without a prediction are passed through untouched.
        if not (isinstance(item, dict) and 'prediction' in item):
            processed_data.append(item)
            continue

        logger.info(f"Progress: {i + 1}/{total}")
        processed_data.append(process_prediction(item, api_key, base_url, model))

        # Save intermediate results every 10 entries (to avoid data loss)
        if (i + 1) % 10 == 0:
            logger.info(f"Intermediate save: {i + 1}/{total}")
            checkpoint = output_file + f".tmp_{i + 1}"
            with open(checkpoint, 'w', encoding='utf-8') as f:
                json.dump(processed_data, f, ensure_ascii=False, indent=2)

            # Each checkpoint is a superset of the previous one; drop the older
            # copy only after the newer one is fully written so disk usage
            # stays linear in the data size.
            if previous_checkpoint is not None:
                try:
                    os.remove(previous_checkpoint)
                except OSError as e:
                    logger.warning(f"Failed to clean up temp file {previous_checkpoint}: {e}")
            previous_checkpoint = checkpoint

    # Save final results
    logger.info(f"Saving results to: {output_file}")
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(processed_data, f, ensure_ascii=False, indent=2)

    logger.info("Processing complete!")

    # Clean up temporary files (including leftovers from earlier interrupted runs)
    tmp_prefix = f"{os.path.basename(output_file)}.tmp_"
    for tmp_file in os.listdir(target_dir):
        if not tmp_file.startswith(tmp_prefix):
            continue
        tmp_path = os.path.join(target_dir, tmp_file)
        try:
            os.remove(tmp_path)
            logger.info(f"Cleaned up temp file: {tmp_path}")
        except OSError as e:
            logger.warning(f"Failed to clean up temp file {tmp_path}: {e}")

    logger.info(f"Results saved to: {output_file}")


if __name__ == "__main__":
    main()
