import os
import json
import time
import asyncio
import random
from typing import List, Dict, Any, Tuple, Optional
from openai import AsyncOpenAI
from pydantic import BaseModel, Field
from datetime import datetime
import re
import argparse

# Set random seed for reproducibility
# NOTE(review): no call into `random` is visible in this part of the file;
# presumably the seed matters for code further down — confirm.
random.seed(42)

# OpenAI API Configuration
API_CONFIG = {
    "MODEL_NAME": "gpt-4.1-2025-04-14",  # model passed to every chat.completions.create call
    "MAX_RETRIES": 3,  # used by the SDK client AND by retry_async_call (the two retry layers stack)
    "RETRY_DELAY": 2,  # base delay in seconds for retry_async_call's exponential backoff
    "BATCH_SIZE": 10,  # default `batch_size` of main_async, which sizes the concurrency semaphore
    "TIMEOUT": 180,  # handed to AsyncOpenAI(timeout=...)
    "TEMPERATURE": 0.25,  # sampling temperature for the extraction calls
    "MAX_TOKENS": 16384,  # `max_tokens` completion cap for the extraction calls
    "ENABLE_BATCH_PROCESSING": False,  # read by main_async to choose the processing mode
    "BATCH_API_SIZE": 2  # questions sent per API call by process_batch_with_batch_api_async
}

# File configuration
INPUT_FILE = "data/input/stage1_input.json"
# NOTE: both "base" paths already end in ".json"; main_async appends another
# ".json" before handing them to create_timestamped_filename.
OUTPUT_FILE_BASE = "data/output/stage1/left.json"
ANOMALIES_FILE_BASE = "data/output/stage1/anomalies_left.json"

# Gates the output of debug_print.
DEBUG = False

def debug_print(message):
    """Write *message* to stdout with a ``DEBUG:`` prefix; silent unless DEBUG is truthy."""
    if not DEBUG:
        return
    print(f"DEBUG: {message}")

# Timestamp utility functions
def generate_timestamp() -> str:
    """Return the current local time rendered as ``YYYYMMDD_HHMMSS``."""
    current_moment = datetime.now()
    return f"{current_moment:%Y%m%d_%H%M%S}"

def create_timestamped_filename(base_name: str, extension: str = ".json") -> str:
    """
    Build ``<stem>_<timestamp><extension>`` from a base path.

    Only *trailing* occurrences of ``extension`` are removed from ``base_name``
    (repeatedly, so ``"out.json.json"`` still yields the stem ``"out"``).
    Occurrences elsewhere in the path, e.g. a directory called
    ``runs.json_cache``, are left untouched.

    Args:
        base_name: File path or name, with or without the extension.
        extension: Extension to strip from the end and re-append after the timestamp.

    Returns:
        The timestamped file name.
    """
    stem = base_name
    # An empty extension would make endswith() always true; skip stripping then.
    if extension:
        while stem.endswith(extension):
            stem = stem[: -len(extension)]
    return f"{stem}_{generate_timestamp()}{extension}"

# Initialize the AsyncOpenAI client
# Built once at import time and shared by every coroutine in this module.
# The key is read from OPENAI_API_KEY_IMAGE; test_connection() checks that the
# variable is set before the first real request.
# NOTE: `max_retries` here is the SDK's own retry layer; retry_async_call adds a
# second loop on top of it using the same MAX_RETRIES value.
openai_client = AsyncOpenAI(
    api_key=os.environ.get("OPENAI_API_KEY_IMAGE"),
    timeout=API_CONFIG["TIMEOUT"],
    max_retries=API_CONFIG["MAX_RETRIES"]
)

# Pydantic models for input/output validation
# One extracted math fact; every field is a required string.
# NOTE: this model is nested (via ExtractedScene) inside MathExtractionOutput,
# whose model_json_schema() is interpolated into the system prompt. Pydantic
# places the class docstring and the Field descriptions in that schema, so
# editing either changes the prompt text — TODO confirm before rewording.
class MathInformation(BaseModel):
    """Schema for math information"""
    raw_math_information: str = Field(description="Original math statement")
    object: str = Field(description="The concrete or abstract object")
    math_value: str = Field(description="Mathematical value")
    semantic: str = Field(description="Contextual meaning")
    use_strategy: str = Field(description="Visual expression strategy for this math information")
    use_meta_description: str = Field(description="Meta description module to be triggered for this math information")

# Same three counters the processing functions copy out of `response.usage`.
# Referenced as the type of FinalMathExtractionOutput.token_usage.
class TokenUsage(BaseModel):
    """Schema for API token usage information"""
    prompt_tokens: int = Field(description="Number of tokens in the prompt")
    completion_tokens: int = Field(description="Number of tokens in the completion")
    total_tokens: int = Field(description="Total number of tokens used")

# Built by process_math_extraction_batch_async and serialised into the user
# message. The value type is `str`, so every question_id / original_question
# placed in `questions` has to be a string to pass validation.
class BatchMathExtractionInput(BaseModel):
    """Schema for batch math extraction input data"""
    questions: List[Dict[str, str]] = Field(description="List of questions with question_id and original_question")

# Used to validate the parsed batch response and to render the schema shown in
# the batch system prompt. Validation is shallow: each result only has to be a
# dict, so the nested `scenes` are NOT checked against ExtractedScene here.
class BatchMathExtractionOutput(BaseModel):
    """Schema for batch math extraction output data"""
    results: List[Dict[str, Any]] = Field(description="List of extraction results with question_id and scenes")

# One scene of a question together with its math facts.
# NOTE: `interfere` is a plain `str`; the three values named in its description
# are not enforced by this model.
class ExtractedScene(BaseModel):
    """Schema for extracted scene data"""
    scene_id: int = Field(description="Scene identifier")
    interfere: str = Field(description="Interference type: perception, semantic, or none")
    scene_math_information: List[MathInformation] = Field(description="Math information for this scene")

# Validates the single-question payload before it is dumped to JSON for the
# user message in process_math_extraction_async (question_id is deliberately
# not part of it).
class MathExtractionInput(BaseModel):
    """Schema for math extraction input data"""
    original_question: str = Field(description="Original math question")

# Contract for a single-question response: its JSON schema is embedded in the
# system prompt and the parsed model reply is validated against it.
class MathExtractionOutput(BaseModel):
    """Schema for math extraction output data"""
    scenes: List[ExtractedScene] = Field(description="Extracted scenes with math information")

# Shape of one finished record (input fields + extraction + status).
# NOTE(review): nothing in the visible part of the file instantiates this
# model; the result dicts are assembled by hand and, on failure, carry an extra
# "processing_error" key that is not declared here — confirm whether that is
# intended.
class FinalMathExtractionOutput(BaseModel):
    """Schema for final output combining input and extraction results"""
    question_id: str = Field(description="Question identifier")
    original_question: str = Field(description="Original math question")
    math_ground_truth: str = Field(description="Math ground truth answer")
    scenes: List[ExtractedScene] = Field(description="Extracted scenes with math information")
    processing_status: str = Field(description="Processing status")
    token_usage: Optional[TokenUsage] = Field(description="API token usage information", default=None)

# Import prompts and examples from step1_prompt.py
from step1_prompt import (
    PROMPT,
    example1_input, example1_output,
    example2_input, example2_output,
    weight_input, weight_output,
    example3_input, example3_output,
    example4_input, example4_output,
    example5_input, example5_output,
    example6_input, example6_output,
)

def transform_input_to_processing_format(item: Dict[str, Any]) -> Dict[str, Any]:
    """
    Reduce a raw input record to the fields the extraction step consumes.

    Args:
        item: One record from the input file.

    Returns:
        A new dict holding only ``original_question`` (empty string when the
        record has no such key).
    """
    question_text = item.get("original_question", "")
    processing_payload = {"original_question": question_text}
    return processing_payload

def fix_json_structure(response_text: str) -> str:
    """
    Apply best-effort textual repairs to a model reply that should be JSON.

    The repairs are heuristic and purely textual; the result is not guaranteed
    to parse. Steps, in order:

    1. Strip surrounding whitespace and a Markdown code fence (```json or a
       bare ```), if present.
    2. Insert commas that are missing at a line break between two strings,
       two objects, or a closing bracket/brace and the next key.
    3. Repair backslash escapes: valid JSON escapes (including an escaped
       backslash) are kept as they are, ``\\$`` becomes ``$``, and any other
       lone backslash is doubled so it survives parsing as a literal backslash.

    Note that unescaped LaTeX such as ``\\frac`` or ``\\times`` starts with a
    sequence that is a *valid* JSON escape (``\\f``, ``\\t``) and is therefore
    left alone.

    Args:
        response_text: Raw API response text

    Returns:
        Fixed JSON string
    """
    # Remove markdown formatting. Strip first so a fence preceded by whitespace
    # or a newline is still recognised.
    response_text = response_text.strip()
    if response_text.startswith('```'):
        response_text = re.sub(r'^```(?:json)?', '', response_text, count=1, flags=re.IGNORECASE)
    if response_text.endswith('```'):
        response_text = response_text[:-3]
    response_text = response_text.strip()

    # Fix missing commas between object properties
    response_text = re.sub(r'"\s*\n\s*"', '",\n"', response_text)

    # Fix missing commas between array elements (objects)
    response_text = re.sub(r'}\s*\n\s*{', '},\n{', response_text)

    # Fix missing commas after closing brackets/braces before new properties
    response_text = re.sub(r']\s*\n\s*"', '],\n"', response_text)
    response_text = re.sub(r'}\s*\n\s*"', '},\n"', response_text)

    def _repair_escape(match):
        tail = match.group(1)
        if tail is None:
            # Lone backslash in front of a character JSON does not allow.
            return '\\\\'
        if tail == '$':
            # "\$" is a common LaTeX/Markdown artefact; the backslash is noise.
            return '$'
        # Valid escape sequence: keep it untouched.
        return match.group(0)

    # Fix escape sequence issues. The pattern consumes a backslash together
    # with the character(s) it escapes, so the second half of an escaped
    # backslash ("\\") is never re-examined as the start of a new escape.
    response_text = re.sub(r'\\(u[0-9a-fA-F]{4}|["\\/bfnrt]|\$)?', _repair_escape, response_text)

    return response_text

def attempt_json_repair(response_text: str, question_id: str) -> dict:
    """
    Parse a model reply as JSON, falling back to heuristic repairs.

    Two candidates are tried in order: the full text (Strategy 1) and the first
    balanced ``{...}`` span found in it (Strategy 2). Each candidate is first
    parsed verbatim, because the textual repairs are lossy and must never be
    applied to text that is already valid JSON; only if that fails is it passed
    through fix_json_structure and parsed again.

    The brace matching for Strategy 2 is a plain counter and does not account
    for braces that occur inside string values.

    Args:
        response_text: Raw API response text
        question_id: Question ID for debugging

    Returns:
        Parsed JSON data

    Raises:
        ValueError: if every attempt fails (callers that catch ``Exception``
            are unaffected).
    """
    # Locate the first balanced {...} span for Strategy 2.
    json_part = None
    start_idx = response_text.find('{')
    if start_idx != -1:
        brace_count = 0
        for i in range(start_idx, len(response_text)):
            char = response_text[i]
            if char == '{':
                brace_count += 1
            elif char == '}':
                brace_count -= 1
                if brace_count == 0:
                    json_part = response_text[start_idx:i + 1]
                    break

    for label, candidate in (("Strategy 1", response_text), ("Strategy 2", json_part)):
        if candidate is None:
            continue

        # Verbatim first: valid JSON must not be touched by the heuristics.
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            pass

        try:
            return json.loads(fix_json_structure(candidate))
        except json.JSONDecodeError as e:
            debug_print(f"{label} failed for {question_id}: {e}")

    # All strategies failed
    debug_print(f"All JSON repair strategies failed for {question_id}")
    debug_print(f"Original text: {response_text[:1000]}...")
    raise ValueError(f"Unable to repair JSON for question {question_id} after multiple attempts")

async def retry_async_call(func, *args, **kwargs):
    """
    Await ``func(*args, **kwargs)``, retrying on failure with exponential backoff.

    Errors are classified by keywords found in the lower-cased exception
    message:

    * network-like and rate-limit/quota errors are retried and, once the retry
      budget is exhausted, re-raised wrapped in an ``Exception`` with a
      descriptive message (the original error is chained as ``__cause__``);
    * authentication errors are raised immediately, without retrying;
    * anything else is retried and finally re-raised unchanged.

    The wait before retry ``k`` (1-based) is ``RETRY_DELAY * 2**(k-1)`` seconds
    plus up to 10% jitter, capped at 60 seconds.

    Args:
        func: The async function to call
        *args: Positional arguments to pass to the function
        **kwargs: Keyword arguments to pass to the function

    Returns:
        The result of the function call
    """
    max_retries = API_CONFIG["MAX_RETRIES"]
    retries = 0
    while True:
        try:
            debug_print(f"Making async API call, attempt {retries + 1}")
            return await func(*args, **kwargs)
        except Exception as e:
            retries += 1
            error_msg = str(e).lower()

            # Check for specific error types
            if any(keyword in error_msg for keyword in [
                'timeout', 'ssl', 'connection', 'network', 'connect', 'proxy', 'dns', 'socket'
            ]):
                debug_print(f"Network error detected: {e}")
                if retries > max_retries:
                    raise Exception(f"Network connection failed after {max_retries} retries: {e}") from e
            elif any(keyword in error_msg for keyword in ['rate limit', '429', 'quota', 'billing', 'limit_requests']):
                debug_print(f"Rate limit or quota error: {e}")
                if retries > max_retries:
                    raise Exception(f"Rate limit/quota exceeded after {max_retries} retries: {e}") from e
            elif any(keyword in error_msg for keyword in ['401', 'unauthorized', 'api key']):
                # Retrying cannot fix bad credentials.
                raise Exception(f"Authentication error: {e}") from e
            else:
                if retries > max_retries:
                    raise

            # Progressive backoff with jitter. The jitter factor is the
            # fractional part of the wall clock (0 <= f < 1), i.e. at most 10%
            # of the base wait; it deliberately does not draw from `random`.
            base_wait = API_CONFIG["RETRY_DELAY"] * (2 ** (retries - 1))
            jitter = base_wait * 0.1 * (time.time() % 1.0)
            wait_time = min(base_wait + jitter, 60)  # Max 60 seconds

            debug_print(f"Error: {e}. Retrying in {wait_time:.1f}s... (Attempt {retries}/{max_retries})")
            await asyncio.sleep(wait_time)

async def process_math_extraction_async(input_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extract math information for a single question with one chat-completion call.

    The request consists of the system prompt (PROMPT plus the JSON schema of
    MathExtractionOutput and formatting rules), the few-shot examples as
    alternating user/assistant turns, and finally the question itself. The
    reply is parsed with attempt_json_repair and validated against
    MathExtractionOutput.

    Args:
        input_data: Dict with a required ``original_question`` and an optional
            ``question_id`` (used for logging only, never sent to the model).

    Returns:
        ``{"extraction_result": <parsed reply>, "token_usage": <dict or None>}``

    Raises:
        Exception: any failure (missing key, API error, empty reply,
            unparseable or schema-invalid JSON) is logged and re-raised.
    """
    question_id = input_data.get('question_id', 'unknown')
    debug_print(f"Processing math extraction for question ID: {question_id}")

    try:
        # Create input for the prompt (exclude question_id from API call)
        prompt_input = {
            "original_question": input_data["original_question"]
        }

        # Validate input data
        processing_input = MathExtractionInput(**prompt_input)
        input_json = json.dumps(processing_input.model_dump(), ensure_ascii=False, indent=2)

        # Build system prompt
        system_prompt = PROMPT

        system_prompt += f"""

You must respond with a valid JSON object that matches this schema:
{MathExtractionOutput.model_json_schema()}

CRITICAL JSON FORMATTING REQUIREMENTS:
1. Use proper comma separation between all object properties
2. Use proper comma separation between all array elements
3. Ensure all strings are properly quoted with double quotes
4. Do not use trailing commas
5. Ensure proper nesting and bracket matching
6. Response must be ONLY valid JSON, no additional text or markdown

Your response must be valid JSON only, no additional text."""

        # Create multi-turn conversation messages for GPT with examples
        messages = [
            {
                "role": "system",
                "content": system_prompt
            }
        ]

        # Add examples
        examples = [
            (example1_input, example1_output),
            (example2_input, example2_output),
            (weight_input, weight_output),
            (example3_input, example3_output),
            (example4_input, example4_output),
            (example5_input, example5_output),
            (example6_input, example6_output),
        ]

        for example_input, example_output in examples:
            messages.extend([
                {
                    "role": "user",
                    "content": f"""Input:
                        {json.dumps(example_input, ensure_ascii=False, indent=2)}"""
                },
                {
                    "role": "assistant",
                    "content": json.dumps(example_output, ensure_ascii=False, indent=2)
                }
            ])

        # Add the actual processing request
        messages.append({
            "role": "user",
            "content": f"""Input:\n {input_json} \n################## \n Apply the same principles as the previous examples. Extract mathematical information, classify it according to the definitions, and determine appropriate meta modules. Do not limit your thinking or output length. Think exhaustively and respond with unlimited depth and detail. There is no word limit - continue expanding your answer as far as possible. Think Procedure must follow the global prompt rule."""
        })

        debug_print(f"Math extraction messages prepared with {len(messages)} turns")

        # Use OpenAI client for processing
        response = await retry_async_call(
            openai_client.chat.completions.create,
            model=API_CONFIG["MODEL_NAME"],
            messages=messages,
            temperature=API_CONFIG["TEMPERATURE"],
            max_tokens=API_CONFIG["MAX_TOKENS"]
        )

        # `content` can be None; normalise it so that case is reported as an
        # empty response instead of an AttributeError on .strip().
        response_text = (response.choices[0].message.content or "").strip()

        if not response_text:
            raise ValueError("Model returned empty response")

        # Parse the JSON response using enhanced repair function
        extraction_result = attempt_json_repair(response_text, question_id)

        # Extract token usage information from API response
        token_usage = None
        usage = getattr(response, 'usage', None)
        if usage:
            token_usage = {
                "prompt_tokens": usage.prompt_tokens,
                "completion_tokens": usage.completion_tokens,
                "total_tokens": usage.total_tokens
            }

        # Validate output data
        MathExtractionOutput(**extraction_result)  # This will raise an error if validation fails

        debug_print(f"Math extraction completed for question ID: {question_id}")

        # Return both extraction result and token usage
        return {
            "extraction_result": extraction_result,
            "token_usage": token_usage
        }

    except Exception as e:
        print(f"Math extraction failed for question ID: {question_id}. Error: {e}")
        raise

async def process_math_extraction_batch_async(batch_data: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Extract math information for several questions with one chat-completion call.

    Every question is sent with a string ``question_id`` (the item's own id, or
    ``item-<index>`` when it has none) and the model is asked to echo that id
    in the matching entry of ``results``. The few-shot examples are wrapped in
    the same batch envelope, including a ``question_id``, so they demonstrate
    exactly the output shape that is requested.

    Args:
        batch_data: List of math extraction input data

    Returns:
        ``{"batch_result": <parsed reply>, "token_usage": <dict or None>,
        "batch_size": len(batch_data)}``

    Raises:
        Exception: any failure (API error, empty reply, unparseable or
            schema-invalid JSON) is logged and re-raised.
    """
    # Ids are stringified: BatchMathExtractionInput only accepts string values.
    batch_question_ids = [str(item.get('question_id', f'item-{i}')) for i, item in enumerate(batch_data)]
    debug_print(f"Processing batch math extraction for question IDs: {batch_question_ids}")

    try:
        # Create batch input for the prompt, using the same ids as the logs
        questions_for_prompt = [
            {
                "question_id": question_id,
                "original_question": item.get("original_question", "")
            }
            for question_id, item in zip(batch_question_ids, batch_data)
        ]

        # Validate batch input data
        batch_input = BatchMathExtractionInput(questions=questions_for_prompt)
        input_json = json.dumps(batch_input.model_dump(), ensure_ascii=False, indent=2)

        # Build system prompt for batch processing
        system_prompt = PROMPT

        system_prompt += f"""

BATCH PROCESSING INSTRUCTIONS:
You will receive multiple questions at once. Process each question independently and return results for all questions.

You must respond with a valid JSON object that matches this schema:
{BatchMathExtractionOutput.model_json_schema()}

The "results" array should contain one result object for each input question, in the same order.
Each result object should have:
- question_id: The ID of the question being processed
- scenes: Array of extracted scenes following the same format as single question processing

CRITICAL JSON FORMATTING REQUIREMENTS:
1. Use proper comma separation between all object properties
2. Use proper comma separation between all array elements
3. Ensure all strings are properly quoted with double quotes
4. Do not use trailing commas
5. Ensure proper nesting and bracket matching
6. Response must be ONLY valid JSON, no additional text or markdown

Your response must be valid JSON only, no additional text."""

        # Create multi-turn conversation messages for GPT with examples
        messages = [
            {
                "role": "system",
                "content": system_prompt
            }
        ]

        # Add examples, each wrapped in the batch envelope
        examples = [
            (example1_input, example1_output),
            (example2_input, example2_output),
            (weight_input, weight_output),
            (example3_input, example3_output),
            (example4_input, example4_output),
            (example5_input, example5_output),
            (example6_input, example6_output),
        ]

        # Convert examples to batch format for demonstration
        for example_number, (example_input, example_output) in enumerate(examples, start=1):
            # The id must appear on both sides: results are later matched to
            # questions by question_id, so the examples have to show it echoed.
            example_question_id = f"example-{example_number}"
            batch_example_input = {
                "questions": [
                    {
                        "question_id": example_question_id,
                        "original_question": example_input.get("original_question", "")
                    }
                ]
            }
            batch_example_output = {
                "results": [
                    {
                        "question_id": example_question_id,
                        "scenes": example_output.get("scenes", [])
                    }
                ]
            }

            messages.extend([
                {
                    "role": "user",
                    "content": f"""Input:
                        {json.dumps(batch_example_input, ensure_ascii=False, indent=2)}"""
                },
                {
                    "role": "assistant",
                    "content": json.dumps(batch_example_output, ensure_ascii=False, indent=2)
                }
            ])

        # Add the actual batch processing request
        messages.append({
            "role": "user",
            "content": f"""Input:\n {input_json} \n################## \n Process all questions in this batch. Apply the same principles as the previous examples to each question. Extract mathematical information, classify it according to the definitions, and determine appropriate meta modules for each question. Return results for all questions in the results array. Do not limit your thinking or output length. Think exhaustively and respond with unlimited depth and detail for each question."""
        })

        debug_print(f"Batch math extraction messages prepared with {len(messages)} turns for {len(batch_data)} questions")
        # The full prompt is large; only dump it in debug mode.
        debug_print(f"Batch messages: {messages}")

        # Use OpenAI client for processing
        response = await retry_async_call(
            openai_client.chat.completions.create,
            model=API_CONFIG["MODEL_NAME"],
            messages=messages,
            temperature=API_CONFIG["TEMPERATURE"],
            max_tokens=API_CONFIG["MAX_TOKENS"]
        )

        # `content` can be None; normalise it so that case is reported as an
        # empty response instead of an AttributeError on .strip().
        response_text = (response.choices[0].message.content or "").strip()

        if not response_text:
            raise ValueError("Model returned empty response")

        # Parse the JSON response using enhanced repair function
        batch_result = attempt_json_repair(response_text, f"batch-{'-'.join(batch_question_ids)}")

        # Validate output data
        BatchMathExtractionOutput(**batch_result)  # This will raise an error if validation fails

        # Extract token usage information from API response
        token_usage = None
        usage = getattr(response, 'usage', None)
        if usage:
            token_usage = {
                "prompt_tokens": usage.prompt_tokens,
                "completion_tokens": usage.completion_tokens,
                "total_tokens": usage.total_tokens
            }

        debug_print(f"Batch math extraction completed for question IDs: {batch_question_ids}")

        # Return both batch result and token usage
        return {
            "batch_result": batch_result,
            "token_usage": token_usage,
            "batch_size": len(batch_data)
        }

    except Exception as e:
        print(f"Batch math extraction failed for question IDs: {batch_question_ids}. Error: {e}")
        raise

def _batch_failure_result(item: Dict[str, Any], question_id: Any, status: str, error: str) -> Dict[str, Any]:
    """Build the failure record for one item of a batch."""
    return {
        "question_id": question_id,
        "original_question": item.get("original_question", ""),
        "math_ground_truth": item.get("math_ground_truth", ""),
        "scenes": [],
        "processing_status": status,
        "processing_error": error
    }


def _assemble_sub_batch_results(sub_batch: List[Dict[str, Any]], question_ids: List[Any],
                                outcome: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Pair every item of a sub-batch with its entry in the batch response."""
    batch_result = outcome.get("batch_result", {})
    token_usage = outcome.get("token_usage")
    batch_size = outcome.get("batch_size", len(sub_batch))
    returned = batch_result.get("results", [])

    # Index the response by stringified id; the first entry wins on duplicates.
    by_id = {}
    for entry in returned:
        if isinstance(entry, dict) and entry.get("question_id") is not None:
            by_id.setdefault(str(entry["question_id"]), entry)

    # The prompt asks for results "in the same order", so position is a safe
    # fallback when the counts agree and the entry carries no id of its own.
    positional_ok = len(returned) == len(sub_batch)

    # Split the token usage of the call evenly across its items.
    per_item_usage = None
    if token_usage and batch_size > 0:
        per_item_usage = {
            key: token_usage.get(key, 0) // batch_size
            for key in ("prompt_tokens", "completion_tokens", "total_tokens")
        }

    results = []
    for position, (item, question_id) in enumerate(zip(sub_batch, question_ids)):
        entry = by_id.get(str(question_id))
        if entry is None and positional_ok:
            candidate = returned[position]
            if isinstance(candidate, dict) and candidate.get("question_id") is None:
                entry = candidate

        if not entry:
            results.append(_batch_failure_result(
                item, question_id, "failed_batch_missing_result",
                f"Result not found in batch response for question_id: {question_id}"
            ))
            continue

        final_result = {
            "question_id": question_id,
            "original_question": item.get("original_question", ""),
            "math_ground_truth": item.get("math_ground_truth", ""),
            "scenes": entry.get("scenes", []),
            "processing_status": "success"
        }
        if per_item_usage is not None:
            final_result["token_usage"] = dict(per_item_usage)
        results.append(final_result)

    return results


async def process_batch_with_batch_api_async(batch_data: List[Dict[str, Any]], semaphore: asyncio.Semaphore) -> List[Dict[str, Any]]:
    """
    Process a batch of items with multi-question API calls instead of one call per item.

    The batch is cut into sub-batches of ``API_CONFIG["BATCH_API_SIZE"]`` items
    that are sent one after another while the semaphore is held. A failure is
    contained to the sub-batch in which it happens: its items are reported with
    status ``failed_batch_api_exception`` while results already obtained from
    other sub-batches are kept. Items the model did not return a result for
    are reported with status ``failed_batch_missing_result``.

    Args:
        batch_data: List of items to process
        semaphore: Semaphore to control concurrency

    Returns:
        One result dict per input item, in input order. This function does not
        raise for per-item or per-call failures.
    """
    debug_print(f"Processing batch of {len(batch_data)} items with single batch API call")

    # Guard against a zero/negative setting, which would break range() below.
    api_batch_size = max(1, API_CONFIG["BATCH_API_SIZE"])
    all_results = []

    async with semaphore:
        for i in range(0, len(batch_data), api_batch_size):
            sub_batch = batch_data[i:i + api_batch_size]
            question_ids = [
                item.get('question_id', f'item-{i + j}')
                for j, item in enumerate(sub_batch)
            ]

            try:
                # Process sub-batch with single API call
                outcome = await process_math_extraction_batch_async(sub_batch)
                sub_results = _assemble_sub_batch_results(sub_batch, question_ids, outcome)
            except Exception as e:
                print(f"Error processing batch with batch API: {e}")
                sub_results = [
                    _batch_failure_result(item, question_id, "failed_batch_api_exception", str(e))
                    for item, question_id in zip(sub_batch, question_ids)
                ]

            all_results.extend(sub_results)

    return all_results

async def process_single_item(item: Dict[str, Any]) -> Dict[str, Any]:
    """
    Run one input record through math extraction and wrap the outcome.

    Failures are not propagated: they are printed and turned into a record
    with status ``failed`` and the error text under ``processing_error``.

    Args:
        item: Single item from input file

    Returns:
        Result record for the item (success or failure)
    """
    question_id = item.get('question_id', 'unknown')

    try:
        # Build the extraction payload; the id rides along for tracking only
        payload = transform_input_to_processing_format(item)
        payload['question_id'] = question_id

        outcome = await process_math_extraction_async(payload)
        extracted = outcome.get("extraction_result", {})
        usage = outcome.get("token_usage")

        record = dict(
            question_id=question_id,
            original_question=item.get("original_question", ""),
            math_ground_truth=item.get("math_ground_truth", ""),
            scenes=extracted.get("scenes", []),
            processing_status="success",
        )

        # Token usage is attached only when the API reported it
        if usage:
            record["token_usage"] = usage

        debug_print(f"Math extraction processing completed for question ID: {question_id}")
        return record

    except Exception as e:
        print(f"Math extraction processing failed for question ID: {question_id}. Error: {e}")
        return dict(
            question_id=question_id,
            original_question=item.get("original_question", ""),
            math_ground_truth=item.get("math_ground_truth", ""),
            scenes=[],
            processing_status="failed",
            processing_error=str(e),
        )

async def process_batch_async(batch_data: List[Dict[str, Any]], semaphore: asyncio.Semaphore) -> List[Dict[str, Any]]:
    """
    Run process_single_item over a batch concurrently.

    Each item acquires the semaphore before it is processed. Exceptions that
    escape an item's task are collected by gather and converted into a record
    with status ``failed_parallel_processing``.

    Args:
        batch_data: List of items to process
        semaphore: Semaphore to control concurrency

    Returns:
        One result per input item, in input order
    """
    debug_print(f"Processing batch of {len(batch_data)} items in parallel")

    async def _guarded(entry):
        # Hold a semaphore slot for the duration of one item
        async with semaphore:
            return await process_single_item(entry)

    outcomes = await asyncio.gather(
        *(_guarded(entry) for entry in batch_data),
        return_exceptions=True,
    )

    collected = []
    for index, (source_item, outcome) in enumerate(zip(batch_data, outcomes)):
        if not isinstance(outcome, Exception):
            collected.append(outcome)
            continue

        # The task raised: report it and substitute a failure record
        fallback_id = source_item.get('question_id', f'item-{index}')
        print(f"Error processing item {fallback_id}: {outcome}")
        collected.append({
            "question_id": fallback_id,
            "original_question": source_item.get("original_question", ""),
            "math_ground_truth": source_item.get("math_ground_truth", ""),
            "scenes": [],
            "processing_status": "failed_parallel_processing",
            "processing_error": str(outcome)
        })

    return collected

async def test_connection():
    """
    Check that the OpenAI API is reachable by sending one small real request.

    Returns:
        True if the request completed, False if the API key environment
        variable is missing or the request raised.
    """
    print("Testing OpenAI API connection...")

    if not os.environ.get("OPENAI_API_KEY_IMAGE"):
        print("Warning: OPENAI_API_KEY_IMAGE environment variable not set")
        return False

    try:
        test_response = await openai_client.chat.completions.create(
            model=API_CONFIG["MODEL_NAME"],
            messages=[
                {"role": "user", "content": "Hello! Please respond with a simple JSON object containing a 'status' field with value 'ok'."}
            ],
            max_tokens=100,
            temperature=0
        )

        # `content` can be None; slicing None would raise and a working
        # connection would wrongly be reported as failed.
        response_content = test_response.choices[0].message.content or ""
        print("OpenAI API connection successful")
        print(f"Test response: {response_content[:200]}...")
        return True
    except Exception as e:
        print(f"OpenAI API connection failed: {e}")
        return False

async def main_async(input_file: str, output_file: Optional[str] = None, anomalies_file: Optional[str] = None, 
                    batch_size: int = API_CONFIG["BATCH_SIZE"]):
    """
    Run math extraction over a JSON input file and write the results.

    The API connection is tested first. The input is then processed in
    slices of ``batch_size`` items. Results whose ``processing_status``
    is ``"success"`` go to ``output_file``; everything else goes to
    ``anomalies_file``. A summary is printed at the end.

    Args:
        input_file: Input file path (JSON array of items)
        output_file: Output file path for successful results; a
            timestamped default is generated when None
        anomalies_file: Output file path for failed results; a
            timestamped default is generated when None
        batch_size: Number of items to process in parallel per batch

    Raises:
        ValueError: If batch_size is smaller than 1.
    """
    # Fail fast with a clear message. Otherwise the same ValueError only
    # surfaces later from asyncio.Semaphore / range(), after the
    # connection probe has already been sent.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    
    debug_print("Starting math extraction processing")
    
    # Test API connection
    connection_ok = await test_connection()
    if not connection_ok:
        print("Cannot proceed without valid OpenAI API connection")
        return
    
    # Generate timestamped filenames if not provided.
    # create_timestamped_filename strips the ".json" extension itself, so
    # the base names are passed as they are.
    if output_file is None:
        output_file = create_timestamped_filename(OUTPUT_FILE_BASE)
    if anomalies_file is None:
        anomalies_file = create_timestamped_filename(ANOMALIES_FILE_BASE)
    
    # Create the parent directory of BOTH destination files; the anomalies
    # file may live somewhere other than the main output file.
    for destination in (output_file, anomalies_file):
        parent_dir = os.path.dirname(destination)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
    
    print(f"\nLoading data from {input_file}")
    print(f"Output will be saved to: {output_file}")
    if anomalies_file:
        print(f"Anomalies will be saved to: {anomalies_file}")
    
    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Error loading input file: {e}")
        return
    
    # The batching below slices the payload, so it must be a JSON array.
    if not isinstance(data, list):
        print(f"Error loading input file: expected a JSON array of items, got {type(data).__name__}")
        return
    
    print(f"Loaded {len(data)} items from input file.")
    
    start_time = time.time()
    total_items = len(data)
    total_batches = (total_items + batch_size - 1) // batch_size
    all_results = []
    successful_results = []
    anomalies = []
    
    # Create semaphore for controlling concurrency
    semaphore = asyncio.Semaphore(batch_size)
    
    # Determine processing mode
    use_batch_api = API_CONFIG["ENABLE_BATCH_PROCESSING"]
    processing_mode = "Batch API Processing" if use_batch_api else "Individual API Processing"
    
    print("\nStarting Math Information Extraction Processing")
    print(f"   Model: {API_CONFIG['MODEL_NAME']} (GPT)")
    print(f"   Processing Mode: {processing_mode}")
    print(f"   Batch Size: {batch_size} (Parallel batches)")
    if use_batch_api:
        print(f"   API Batch Size: {API_CONFIG['BATCH_API_SIZE']} (Questions per API call)")
    
    # Process data in batches using appropriate processing function
    for i in range(0, total_items, batch_size):
        batch = data[i:i+batch_size]
        batch_start = time.time()
        current_batch = i//batch_size + 1
        
        print(f"\nProcessing batch {current_batch}/{total_batches} ({len(batch)} items) in parallel")
        
        try:
            # Choose processing function based on mode
            if use_batch_api:
                batch_results = await process_batch_with_batch_api_async(batch, semaphore)
            else:
                batch_results = await process_batch_async(batch, semaphore)
            
            all_results.extend(batch_results)
            
            # Separate successful and failed results
            for result in batch_results:
                if result.get("processing_status") == "success":
                    successful_results.append(result)
                else:
                    anomalies.append(result)
                    
        except Exception as e:
            # Batch-level boundary: one failing batch must not abort the
            # run, so every item of the batch is recorded as an anomaly.
            print(f"Error processing batch {current_batch}: {e}")
            for j, item in enumerate(batch):
                error_item = {
                    "question_id": item.get('question_id', f'item-{i+j}'),
                    "original_question": item.get("original_question", ""),
                    "math_ground_truth": item.get("math_ground_truth", ""),
                    "scenes": [],
                    "processing_status": "failed_batch_exception",
                    "processing_error": str(e)
                }
                all_results.append(error_item)
                anomalies.append(error_item)
        
        # Display progress
        batch_time = time.time() - batch_start
        elapsed = time.time() - start_time
        items_processed = len(all_results)
        successful_count = len(successful_results)
        anomalies_count = len(anomalies)
        progress_percent = items_processed/total_items*100 if total_items > 0 else 0
        items_per_second = items_processed / elapsed if elapsed > 0 else 0
        
        # Linear extrapolation from the average time per processed item.
        if items_processed > 0:
            estimated_total = (elapsed / items_processed) * total_items
            estimated_remaining = estimated_total - elapsed
        else:
            estimated_remaining = float('inf')
        
        print(f"Progress: {items_processed}/{total_items} items ({progress_percent:.1f}%)")
        print(f"Successful: {successful_count}, Anomalies: {anomalies_count}")
        print(f"Speed: {items_per_second:.2f} items/sec ({processing_mode.lower()})")
        print(f"Time elapsed: {elapsed:.1f}s, estimated remaining: {'inf' if estimated_remaining == float('inf') else f'{estimated_remaining:.1f}s'}")
        print(f"Batch processing time: {batch_time:.1f}s ({batch_time/len(batch):.1f}s per item average)")
    
    # Save successful results only to output file
    print(f"\nSaving successful results to {output_file}")
    try:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(successful_results, f, ensure_ascii=False, indent=2)
        print(f"{len(successful_results)} successful results saved.")
    except Exception as e:
        print(f"Error saving results: {e}")
    
    # Save anomalies to separate file
    if anomalies:
        print(f"\nSaving anomalies to {anomalies_file}")
        try:
            with open(anomalies_file, 'w', encoding='utf-8') as f:
                json.dump(anomalies, f, ensure_ascii=False, indent=2)
            print(f"{len(anomalies)} anomalies saved for review.")
        except Exception as e:
            print(f"Error saving anomalies: {e}")
    else:
        print("No anomalies detected.")
    
    # Print summary
    print_summary_statistics(successful_results, anomalies, time.time() - start_time, total_items)

def print_summary_statistics(successful_results: List[Dict], anomalies: List[Dict], 
                           total_time: float, total_items: int):
    """
    Write the end-of-run report to stdout.

    The report covers overall counts and timing, followed (when at least
    one item succeeded) by scene-count, interference, strategy,
    meta-description and token-usage breakdowns.

    Args:
        successful_results: Items that were processed successfully
        anomalies: Items that failed processing
        total_time: Duration of the whole run in seconds
        total_items: Number of items loaded from the input
    """
    def bump(tally, key):
        """Add one occurrence of key to the tally dict."""
        tally[key] = tally.get(key, 0) + 1

    def share_of(part, whole):
        """Express part as a percentage of whole (0 when whole is not positive)."""
        return part / whole * 100 if whole > 0 else 0

    n_ok = len(successful_results)
    n_failed = len(anomalies)

    # Human-readable label for the configured processing mode
    batch_mode = API_CONFIG["ENABLE_BATCH_PROCESSING"]
    if batch_mode:
        mode_label = "Batch API processing"
    else:
        mode_label = "Individual API processing"

    closing_rule = "=================================================================================================="

    print("\n============================== Math Information Extraction Processing Summary ==============================")
    print(f"Model: {API_CONFIG['MODEL_NAME']} (GPT)")
    print(f"Total items loaded: {total_items}")
    print(f"Successfully processed items: {n_ok}")
    print(f"Items with errors (anomalies): {n_failed}")
    if total_items > 0:
        print(f"Success rate: {n_ok/total_items*100:.1f}%")
    else:
        print("Success rate: 0%")
    print(f"Total time: {total_time:.1f}s")
    if total_items > 0:
        print(f"Average time per item: {total_time/total_items:.1f}s")
    else:
        print("Average time per item: N/A")
    print(f"Processing mode: {mode_label}, batch size: {API_CONFIG['BATCH_SIZE']}")
    if batch_mode:
        print(f"API batch size: {API_CONFIG['BATCH_API_SIZE']} questions per API call")

    # Nothing to break down when no item succeeded
    if not successful_results:
        print(closing_rule)
        return

    print("\nScene Analysis:")

    scenes_per_item = {}
    by_interference = {}
    by_strategy = {}
    by_meta_description = {}

    # Running token totals over the items that carry usage data
    prompt_sum = 0
    completion_sum = 0
    overall_sum = 0
    tracked_items = 0

    for record in successful_results:
        record_scenes = record.get("scenes", [])

        usage = record.get("token_usage")
        if usage:
            prompt_sum += usage.get("prompt_tokens", 0)
            completion_sum += usage.get("completion_tokens", 0)
            overall_sum += usage.get("total_tokens", 0)
            tracked_items += 1

        bump(scenes_per_item, len(record_scenes))

        for scene in record_scenes:
            bump(by_interference, scene.get("interfere", "none"))

            for info in scene.get("scene_math_information", []):
                bump(by_strategy, info.get("use_strategy", "none"))
                bump(by_meta_description, info.get("use_meta_description", "none"))

    # Scene counts are reported relative to the number of successful items
    print("\n   Scene count distribution:")
    for n_scenes, n_records in sorted(scenes_per_item.items()):
        print(f"   {n_scenes} scenes: {n_records} items ({share_of(n_records, n_ok):.1f}%)")

    print("\n   Interference type distribution:")
    scene_total = sum(by_interference.values())
    for kind, hits in sorted(by_interference.items()):
        print(f"   {kind}: {hits} scenes ({share_of(hits, scene_total):.1f}%)")

    print("\n   Use strategy distribution:")
    strategy_total = sum(by_strategy.values())
    for label, hits in sorted(by_strategy.items()):
        print(f"   {label}: {hits} math info ({share_of(hits, strategy_total):.1f}%)")

    print("\n   Meta description distribution:")
    meta_total = sum(by_meta_description.values())
    for label, hits in sorted(by_meta_description.items()):
        print(f"   {label}: {hits} math info ({share_of(hits, meta_total):.1f}%)")

    # Token usage is only reported when at least one item carried it
    if tracked_items > 0:
        print("\n   Token Usage Statistics:")
        print(f"   Items with token info: {tracked_items}/{n_ok}")
        print(f"   Total prompt tokens: {prompt_sum}")
        print(f"   Total completion tokens: {completion_sum}")
        print(f"   Total tokens used: {overall_sum}")
        print(f"   Average prompt tokens per item: {prompt_sum / tracked_items:.1f}")
        print(f"   Average completion tokens per item: {completion_sum / tracked_items:.1f}")
        print(f"   Average total tokens per item: {overall_sum / tracked_items:.1f}")

    print(closing_rule)

def main(input_file: str, output_file: str = None, anomalies_file: str = None, 
         batch_size: int = API_CONFIG["BATCH_SIZE"]):
    """Blocking entry point: drive main_async to completion and return its result."""
    pipeline = main_async(
        input_file,
        output_file,
        anomalies_file,
        batch_size,
    )
    # asyncio.run creates the event loop, runs the coroutine, and closes the loop.
    return asyncio.run(pipeline)

def test_data_transformation():
    """
    Smoke-test transform_input_to_processing_format on a sample item.

    Prints the transformed item and a PASS/FAIL verdict. The verdict is
    PASS only when the transformed item carries a non-empty
    'original_question' value.

    Returns:
        The transformed sample item.
    """
    print("Testing data transformation...")
    
    # Sample test data from input1.json format
    test_item = {
        "question_id": "test_1",
        "original_question": "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
        "math_ground_truth": "18"
    }
    
    # Test processing transformation
    processing_input = transform_input_to_processing_format(test_item)
    question_preserved = processing_input.get('original_question', '') != ''
    print(f"   Processing input: {processing_input}")
    print(f"   Original question preserved: {question_preserved}")
    
    # Tie the verdict to the check above instead of always printing PASS.
    verdict = "PASS" if question_preserved else "FAIL"
    print(f"   Data transformation test: {verdict}")
    
    return processing_input

def run_tests():
    """Execute the built-in self-tests, framed by a header and a footer."""
    divider = "=" * 60

    print("Running Tests for Math Information Extraction Processing")
    print(divider)

    test_data_transformation()

    print(divider)
    print("All tests completed")

def parse_arguments():
    """
    Parse command line arguments.

    Returns:
        The argparse namespace with ``input``, ``output``, ``anomalies``,
        ``batch_size``, ``batch``, ``api_batch_size``, ``debug`` and
        ``test`` attributes.
    """
    def positive_int(raw: str) -> int:
        """Convert a CLI value to an int and reject anything below 1."""
        # A ValueError from int() is turned into a usage error by argparse.
        value = int(raw)
        if value < 1:
            raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
        return value
    
    parser = argparse.ArgumentParser(description='Math Information Extraction Processing Tool - GPT Version')
    
    parser.add_argument('--input', '-i', type=str, default=INPUT_FILE,
                       help=f'Input file path (default: {INPUT_FILE})')
    
    parser.add_argument('--output', '-o', type=str, default=None,
                       help='Output file path (default: auto-generated with timestamp)')
    
    parser.add_argument('--anomalies', '-a', type=str, default=None,
                       help='Anomalies file path (default: auto-generated with timestamp)')
    
    # Batch sizes are used as a range() step and a semaphore size, so zero
    # or negative values are rejected here instead of crashing mid-run.
    parser.add_argument('--batch-size', '-b', type=positive_int, default=API_CONFIG["BATCH_SIZE"],
                       help=f'Batch size for parallel processing (default: {API_CONFIG["BATCH_SIZE"]})')
    
    parser.add_argument('--batch', action='store_true',
                       help='Enable batch API processing mode (process multiple questions per API call)')
    
    parser.add_argument('--api-batch-size', type=positive_int, default=API_CONFIG["BATCH_API_SIZE"],
                       help=f'Number of questions per batch API call when batch mode is enabled (default: {API_CONFIG["BATCH_API_SIZE"]})')
    
    parser.add_argument('--debug', '-d', action='store_true',
                       help='Enable debug mode')
    
    parser.add_argument('--test', '-t', action='store_true',
                       help='Run tests only')
    
    return parser.parse_args()

# Script entry point: print the banner, apply CLI options to the module
# configuration, run the self-tests, then start processing.
if __name__ == "__main__":
    print("Math Information Extraction Processing Tool - GPT Version")
    print("=" * 80)
    print("Features:")
    print("   - Math information extraction and classification (GPT)")
    print("   - Meta module determination based on semantic analysis")
    print("   - Async parallel processing for maximum efficiency")
    print("   - Individual API calls or Batch API calls (--batch)")
    print("   - Comprehensive scene and interference analysis")
    print("   - Combines original data with extracted math information")
    print("")
    
    # Parse command line arguments
    args = parse_arguments()
    
    # Set debug mode (this block runs at module level, so this rebinds the
    # module-wide DEBUG flag)
    DEBUG = args.debug
    
    # Configure batch processing mode
    API_CONFIG["ENABLE_BATCH_PROCESSING"] = args.batch
    API_CONFIG["BATCH_API_SIZE"] = args.api_batch_size
    
    # Run tests if requested.
    # SystemExit is raised directly: the exit() builtin is injected by the
    # site module for interactive use and is absent under `python -S`.
    if args.test:
        run_tests()
        raise SystemExit(0)
    
    # Run tests first
    run_tests()
    print("")
    
    # Check if environment is set up correctly
    if not os.getenv("OPENAI_API_KEY_IMAGE"):
        print("Warning: OPENAI_API_KEY_IMAGE environment variable not set")
        print("   Please set your OpenAI API key before running processing")
        print("")
    
    # Generate timestamped output files if not provided.
    # create_timestamped_filename strips the ".json" extension itself, so
    # the base names are passed as they are.
    if args.output is None:
        args.output = create_timestamped_filename(OUTPUT_FILE_BASE)
    
    if args.anomalies is None:
        args.anomalies = create_timestamped_filename(ANOMALIES_FILE_BASE)
    
    print("Configuration:")
    print(f"   Input file: {args.input}")
    print(f"   Batch size: {args.batch_size}")
    print(f"   Batch API mode: {'Enabled' if args.batch else 'Disabled'}")
    if args.batch:
        print(f"   API batch size: {args.api_batch_size}")
    print(f"   Output file: {args.output}")
    print(f"   Anomalies file: {args.anomalies}")
    print(f"   Debug mode: {args.debug}")
    print("")
    
    # Check if input file exists
    if not os.path.exists(args.input):
        print(f"Input file not found: {args.input}")
        print("   Please ensure the input file exists before running")
        raise SystemExit(1)
    
    print(f"Input file found: {args.input}")
    processing_mode = "batch API" if args.batch else "individual API"
    print(f"Starting processing with async parallel execution ({processing_mode} calls)...")
    
    main(args.input, args.output, args.anomalies, args.batch_size)
