import os
import json
import time
import asyncio
import random
from typing import List, Dict, Any, Tuple, Optional
from openai import AsyncOpenAI
from pydantic import BaseModel, Field
from datetime import datetime
import re
import argparse

# Set random seed for reproducibility
# NOTE: this seeds the process-wide `random` module as an import-time side effect.
random.seed(42)

# OpenAI API Configuration
API_CONFIG = {
    "MODEL_NAME": "gpt-4.1-2025-04-14",  # passed as `model=` on every chat.completions.create call
    "MAX_RETRIES": 3,  # used both as the client's `max_retries` and as retry_async_call's limit
    "RETRY_DELAY": 2,  # base value (seconds, fed to asyncio.sleep) for retry_async_call's backoff
    "BATCH_SIZE": 10,  # not referenced in the portion of the file reviewed here
    "TIMEOUT": 180,  # passed as `timeout=` to AsyncOpenAI
    "TEMPERATURE": 0.2,  # passed as `temperature=` on every request
    "MAX_TOKENS": 16384  # passed as `max_tokens=` on every request
}

# File configuration
# NOTE(review): these values are placeholder strings rather than real paths, and
# OUTPUT_FILE_BASE and ANOMALIES_FILE_BASE are currently identical — confirm the
# intended locations before running.
INPUT_FILE = "stage1 output path"
OUTPUT_FILE_BASE = "stage2 output path"
ANOMALIES_FILE_BASE = "stage2 output path"

# Gate for debug_print(): nothing is printed while this is False.
DEBUG = False

def debug_print(message):
    """Write *message* to stdout with a ``DEBUG:`` prefix; no-op unless DEBUG is set."""
    if not DEBUG:
        return
    print(f"DEBUG: {message}")

# Timestamp utility functions
def generate_timestamp() -> str:
    """Return the current local time rendered as ``YYYYMMDD_HHMMSS``."""
    now = datetime.now()
    return f"{now:%Y%m%d_%H%M%S}"

def create_timestamped_filename(base_name: str, extension: str = ".json") -> str:
    """
    Create filename with timestamp

    Args:
        base_name: File name or path, with or without a trailing ``extension``
        extension: Extension (including the dot) to place after the timestamp

    Returns:
        ``<base without trailing extension>_<YYYYMMDD_HHMMSS><extension>``
    """
    timestamp = generate_timestamp()
    # Strip the extension only when it is a real suffix. A blanket
    # str.replace() would also delete matches earlier in the path
    # (e.g. "run.json_v2/out.json" -> "run_v2/out"), silently changing
    # where the file is written.
    if extension and base_name.endswith(extension):
        name_without_ext = base_name[:-len(extension)]
    else:
        name_without_ext = base_name
    return f"{name_without_ext}_{timestamp}{extension}"

# Initialize the AsyncOpenAI client
# The key is read from OPENAI_API_KEY_IMAGE; os.environ.get returns None when
# that variable is unset, so the client is then constructed with api_key=None.
# NOTE(review): the SDK is given max_retries here AND calls are wrapped in
# retry_async_call, so the two retry layers presumably multiply — confirm this
# is intended.
openai_client = AsyncOpenAI(
    api_key=os.environ.get("OPENAI_API_KEY_IMAGE"),
    timeout=API_CONFIG["TIMEOUT"],
    max_retries=API_CONFIG["MAX_RETRIES"]
)

# Pydantic models for input/output validation
class SceneMathInformation(BaseModel):
    """Schema for scene math information"""
    # One math fact attached to a scene. These five string fields are the same
    # keys that transform_input_to_processing_format copies from each entry of
    # a scene's "scene_math_information" list. None of them declares a default.
    # NOTE: the field name `object` shadows the builtin inside this class body;
    # it is a data-schema key, so it is kept as is.
    object: str = Field(description="The concrete or abstract object")
    math_value: str = Field(description="Mathematical value")
    semantic: str = Field(description="Contextual meaning")
    use_strategy: str = Field(description="Visual expression strategy for this math information")
    use_meta_description: str = Field(description="Meta description module to be triggered for this math information")

class InputScene(BaseModel):
    """Schema for input scene data"""
    # One scene of the model input: its id plus its list of math facts.
    # NOTE(review): scene_id is typed int with no default, while
    # transform_input_to_processing_format fills it with scene.get("scene_id"),
    # which is None when the key is absent — presumably that fails validation;
    # confirm that is the desired behaviour.
    scene_id: int = Field(description="Scene identifier")
    scene_math_information: List[SceneMathInformation] = Field(description="Math information for this scene")

class ProcessingInput(BaseModel):
    """Schema for processing input data"""
    # Built in process_scene_description_generation_async from {"scenes": ...}
    # to validate the payload; its model_dump() is then serialized with
    # json.dumps into the final user message.
    scenes: List[InputScene] = Field(description="List of scenes with math information")

class GeneratedScene(BaseModel):
    """Schema for generated scene output"""
    # One scene description as expected back from the model. It is the element
    # type of ProcessingOutput.scenes, whose model_json_schema() is interpolated
    # into the system prompt.
    # NOTE(review): pydantic is expected to carry this class docstring and the
    # Field descriptions into that schema, so rewording them would change the
    # prompt text — confirm before editing.
    scene_id: int = Field(description="Scene identifier")
    object: str = Field(description="Detailed list of all visible objects and text elements")
    composition: str = Field(description="Spatial layout description")
    action: str = Field(description="Semantic action or gesture linking math information")

class ProcessingOutput(BaseModel):
    """Schema for processing output data"""
    # Used twice in process_scene_description_generation_async: its
    # model_json_schema() is embedded in the system prompt, and the parsed model
    # reply is validated by constructing ProcessingOutput(**generation_result).
    # NOTE(review): the docstring and Field descriptions are expected to appear
    # in that schema, so editing them changes the prompt — confirm before editing.
    scenes: List[GeneratedScene] = Field(description="Generated scene descriptions")
    
class TokenUsage(BaseModel):
    """Schema for API token usage information"""
    # The three counters carry the same names as the attributes read from
    # response.usage (prompt_tokens / completion_tokens / total_tokens) in the
    # generation functions below.
    prompt_tokens: int = Field(description="Number of tokens in the prompt")
    completion_tokens: int = Field(description="Number of tokens in the completion")
    total_tokens: int = Field(description="Total number of tokens used")
    
class CombinedTokenUsage(BaseModel):
    """Schema for combined token usage from multiple stages"""
    # Returned by combine_token_usage. Every part defaults to None, so a stage
    # without usage data can simply be left out.
    stage1_tokens: Optional[TokenUsage] = Field(description="Token usage from stage 1 (math extraction)", default=None)
    stage2_tokens: Optional[TokenUsage] = Field(description="Token usage from stage 2 (scene generation)", default=None)
    total_tokens: Optional[TokenUsage] = Field(description="Combined total token usage", default=None)


class FinalOutput(BaseModel):
    """Schema for final output combining input and generated descriptions"""
    # NOTE(review): not instantiated in the portion of the file reviewed here;
    # presumably this is the per-question record written to the stage-2 output
    # file — confirm against the code that builds it.
    question_id: str = Field(description="Question identifier")
    original_question: str = Field(description="Original math question")
    math_ground_truth: str = Field(description="Math ground truth answer")
    # Scenes are plain dicts here, so their inner structure is not validated.
    scenes: List[Dict[str, Any]] = Field(description="Combined scenes with math info and descriptions")
    processing_status: str = Field(description="Processing status")
    token_usage: Optional[CombinedTokenUsage] = Field(description="Combined token usage from all stages", default=None)


class BatchSceneInput(BaseModel):
    """Schema for batch scene generation input data"""
    # Wraps the per-question dicts assembled in
    # process_scene_description_generation_batch_async (keys: "question_id",
    # "scenes", "meta_modules"). Entries are untyped dicts, so only the outer
    # list-of-dicts shape is validated.
    questions: List[Dict[str, Any]] = Field(description="List of questions with scenes and math information")

class BatchSceneOutput(BaseModel):
    """Schema for batch scene generation output data"""
    # Its model_json_schema() is interpolated into the batch system prompt and
    # the parsed reply is validated with BatchSceneOutput(**batch_result).
    # Results are untyped dicts: the presence of "question_id"/"scenes" in each
    # result and the number of results are NOT checked by this model.
    # NOTE(review): the docstring and Field description are expected to appear
    # in that schema, so editing them changes the prompt — confirm before editing.
    results: List[Dict[str, Any]] = Field(description="List of scene generation results with question_id and scenes")

# Import prompts and examples from step2_prompt.py
from step2_prompt import (
    global_strategy,
    process_prompt,
    example1_input, example1_output,
    object_measurement_meta_description, measurement_input, measurement_output, measurement_input2, measurement_output2,
    day_meta_description, calendar_day_example_input, calendar_day_example_output,
    time_span_meta_description, time_span_input, time_span_output,
    weight_meta_description, weight_and_time_input, weight_and_time_output,
    icon_ratio_meta_description, icon_ratio_input, icon_ratio_output,
    graph_ratio_meta_description, graph_ratio_input, graph_ratio_output, graph_ratio_cot,
    year_meta_description,
    month_meta_description,
    week_meta_description,
    distance_between_locations_meta_description,
    cross_scene_clock_meta_description,
    dashboard_meta_description, dashboard_input,  dashboard_output,
    
)

def transform_input_to_processing_format(item: Dict[str, Any]) -> Dict[str, Any]:
    """
    Reduce one input record to the scene payload that is sent to the model

    Args:
        item: Single item from input file

    Returns:
        Dict with a single "scenes" key holding the trimmed scenes
    """
    # Only these keys are forwarded from each math entry (missing ones become
    # ""); anything else, such as raw_math_information, is dropped.
    forwarded_keys = ("object", "math_value", "semantic", "use_strategy", "use_meta_description")

    payload_scenes = []
    for scene in item.get("scenes", []):
        math_entries = [
            {key: entry.get(key, "") for key in forwarded_keys}
            for entry in scene.get("scene_math_information", [])
        ]
        if not math_entries:
            # Scenes without any math information are left out entirely.
            continue

        # Just scene_id and the math entries are kept; other scene-level
        # fields (e.g. 'interfere') are deliberately not forwarded.
        payload_scenes.append({
            "scene_id": scene.get("scene_id"),
            "scene_math_information": math_entries,
        })

    return {"scenes": payload_scenes}

def determine_meta_modules(input_item: Dict[str, Any]) -> List[str]:
    """
    Determine which meta description modules to load based on input data

    The module names are read from the "use_meta_description" value of every
    entry in each scene's "scene_math_information" list.

    Args:
        input_item: Input item with a "scenes" list

    Returns:
        The recognised module names that were requested, without duplicates,
        always in the fixed order of the allow-list below. Unknown names,
        "none", missing values and non-string values are ignored.
    """
    # Allow-list; its order also defines the order of the returned list.
    valid_modules = [
        "year_meta_description",
        "month_meta_description",
        "week_meta_description",
        "day_meta_description",
        "distance_between_locations_meta_description",
        "object_measurement_meta_description",
        "time_span_meta_description",
        "weight_meta_description",
        "icon_ratio_meta_description",
        "graph_ratio_meta_description",
        "cross_scene_clock_meta_description",
        "dashboard_meta_description"
    ]

    requested = set()
    # `or []` tolerates an explicit null for either list in the input JSON.
    for scene in input_item.get("scenes") or []:
        for math_info in scene.get("scene_math_information") or []:
            module_name = math_info.get("use_meta_description")
            # Non-strings (None, lists, ...) can never be valid module names
            # and unhashable ones would make set.add() raise.
            if isinstance(module_name, str):
                requested.add(module_name)

    # Walk the allow-list instead of the set: set iteration order for strings
    # changes from run to run (hash randomisation), which would reorder the
    # prompt built from this list and defeat reproducibility.
    return [module for module in valid_modules if module in requested]

def get_meta_description_content(modules: List[str]) -> str:
    """
    Render the requested meta description modules as one prompt fragment

    Args:
        modules: List of meta description module names

    Returns:
        One "## <module name>" section per recognised module, in the order
        given, separated by blank lines; an empty string if none is recognised
    """
    module_mapping = {
        "year_meta_description": year_meta_description,
        "month_meta_description": month_meta_description,
        "week_meta_description": week_meta_description,
        "day_meta_description": day_meta_description,
        "distance_between_locations_meta_description": distance_between_locations_meta_description,
        "object_measurement_meta_description": object_measurement_meta_description,
        "time_span_meta_description": time_span_meta_description,
        "weight_meta_description": weight_meta_description,
        "icon_ratio_meta_description": icon_ratio_meta_description,
        "graph_ratio_meta_description": graph_ratio_meta_description,
        "cross_scene_clock_meta_description": cross_scene_clock_meta_description,
        "dashboard_meta_description": dashboard_meta_description
    }

    sections = []
    for name in modules:
        if name not in module_mapping:
            # Unknown module names are skipped silently.
            continue
        body = module_mapping[name]
        # Dict-valued modules are pretty-printed as JSON; everything else is
        # inserted as-is.
        rendered = json.dumps(body, indent=2) if isinstance(body, dict) else body
        sections.append(f"## {name}\n{rendered}")

    return "\n\n".join(sections)

def get_examples(modules: List[str]) -> Tuple[List[Tuple[Dict, Dict]], Optional[str]]:
    """
    Get examples based on loaded modules

    Args:
        modules: List of meta description module names

    Returns:
        A 2-tuple ``(examples, cot)``:
        - examples: list of (input, output) example pairs, without repeats
        - cot: optional chain-of-thought text to append after the examples;
          currently always None because the graph-ratio CoT is disabled below
    """
    examples: List[Tuple[Dict, Dict]] = []
    cot = None

    def add_example(example_input, example_output):
        # Several modules share an example (time_span and weight both use the
        # weight_and_time pair). Adding it twice only wastes prompt tokens, so
        # skip pairs that are already present. Identity is used because the
        # examples are module-level objects and dicts are not hashable.
        already_present = any(
            seen_input is example_input and seen_output is example_output
            for seen_input, seen_output in examples
        )
        if not already_present:
            examples.append((example_input, example_output))

    # Always included, regardless of the loaded modules
    add_example(example1_input, example1_output)
    add_example(dashboard_input, dashboard_output)
    add_example(calendar_day_example_input, calendar_day_example_output)

    # Module-specific examples
    if "object_measurement_meta_description" in modules:
        add_example(measurement_input, measurement_output)
        add_example(measurement_input2, measurement_output2)

    # day_meta_description needs no branch: its calendar example is always
    # included above.

    if "time_span_meta_description" in modules:
        add_example(weight_and_time_input, weight_and_time_output)
        add_example(time_span_input, time_span_output)

    if "weight_meta_description" in modules:
        add_example(weight_and_time_input, weight_and_time_output)

    if "icon_ratio_meta_description" in modules:
        add_example(icon_ratio_input, icon_ratio_output)

    if "graph_ratio_meta_description" in modules:
        add_example(graph_ratio_input, graph_ratio_output)
        # Chain-of-thought for this example is currently switched off:
        # cot = graph_ratio_cot

    return examples, cot

def fix_json_structure(response_text: str) -> str:
    """
    Fix common JSON structural issues

    Repairs applied, in order:
    1. strip surrounding whitespace and a markdown code fence (with or
       without a ``json`` language tag);
    2. insert commas that are missing at line breaks between properties,
       array elements and closing brackets;
    3. normalise backslashes: valid JSON escapes are kept, ``\\$`` becomes
       ``$``, and any other lone backslash is doubled.

    Args:
        response_text: Raw API response text

    Returns:
        Fixed JSON string (not guaranteed to be parseable)
    """
    # Remove markdown formatting. Strip first so that a fence preceded or
    # followed by whitespace is still recognised.
    response_text = response_text.strip()
    if response_text.startswith('```'):
        response_text = response_text[3:]
        if response_text[:4].lower() == 'json':
            response_text = response_text[4:]
    if response_text.endswith('```'):
        response_text = response_text[:-3]
    response_text = response_text.strip()

    # Fix missing commas between object properties
    response_text = re.sub(r'"\s*\n\s*"', '",\n"', response_text)

    # Fix missing commas between array elements (objects)
    response_text = re.sub(r'}\s*\n\s*{', '},\n{', response_text)

    # Fix missing commas after closing brackets/braces before new properties
    response_text = re.sub(r']\s*\n\s*"', '],\n"', response_text)
    response_text = re.sub(r'}\s*\n\s*"', '},\n"', response_text)

    # Fix escape sequence issues in a single left-to-right pass. Each match
    # consumes a whole escape, so the second backslash of a valid "\\\\" pair
    # is never mistaken for a lone backslash (a look-ahead-only pattern would
    # turn a valid "\\\\sqrt" into an invalid sequence), and "\\$" is resolved
    # before any doubling can interfere with it.
    def _repair_escape(match):
        if match.group(1) is not None:
            return match.group(0)   # valid JSON escape: keep untouched
        if match.group(2) is not None:
            return '$'              # "\$" is not a JSON escape: drop the backslash
        return '\\\\'               # lone backslash: escape it

    response_text = re.sub(
        r'\\(?:(["\\/bfnrt]|u[0-9a-fA-F]{4})|(\$))?',
        _repair_escape,
        response_text,
    )

    return response_text

def _extract_first_json_object(text: str, respect_strings: bool) -> Optional[str]:
    """Return the first brace-balanced ``{...}`` span of *text*, or None.

    With ``respect_strings`` set, braces inside double-quoted string literals
    are not counted, so values such as "\\frac{1" cannot unbalance the scan.
    """
    start_idx = text.find('{')
    if start_idx == -1:
        return None

    depth = 0
    in_string = False
    escaped = False
    for idx in range(start_idx, len(text)):
        char = text[idx]
        if respect_strings and in_string:
            if escaped:
                escaped = False
            elif char == '\\':
                escaped = True
            elif char == '"':
                in_string = False
            continue
        if respect_strings and char == '"':
            in_string = True
        elif char == '{':
            depth += 1
        elif char == '}':
            depth -= 1
            if depth == 0:
                return text[start_idx:idx + 1]
    return None


def attempt_json_repair(response_text: str, question_id: str) -> dict:
    """
    Attempt multiple strategies to repair and parse JSON

    Strategies, tried in order until one parses:
    1. the text exactly as received (repairs are lossy, so valid JSON must
       never be passed through them first);
    2. the whole text after fix_json_structure;
    3. the first balanced ``{...}`` object found in the text, as-is and after
       fix_json_structure. The object is located first with a scan that
       ignores braces inside string literals, then with a plain brace count
       as a fallback for replies whose quotes are themselves broken.

    Args:
        response_text: Raw API response text
        question_id: Question ID for debugging

    Returns:
        Parsed JSON data

    Raises:
        Exception if all repair attempts fail (chained to the last
        JSONDecodeError)
    """
    def candidates():
        yield "Direct parse", response_text
        yield "Basic cleaning", fix_json_structure(response_text)
        seen_fragments = set()
        for respect_strings in (True, False):
            fragment = _extract_first_json_object(response_text, respect_strings)
            if fragment is None or fragment in seen_fragments:
                continue
            seen_fragments.add(fragment)
            yield "Object extraction", fragment
            yield "Object extraction with cleaning", fix_json_structure(fragment)

    last_error = None
    for label, candidate in candidates():
        try:
            return json.loads(candidate)
        except json.JSONDecodeError as e:
            last_error = e
            debug_print(f"{label} failed for {question_id}: {e}")

    # All strategies failed
    debug_print(f"All JSON repair strategies failed for {question_id}")
    debug_print(f"Original text: {response_text[:1000]}...")
    raise Exception(f"Unable to repair JSON for question {question_id} after multiple attempts") from last_error

async def retry_async_call(func, *args, **kwargs):
    """
    Enhanced retry function with exponential backoff and specific error handling

    Errors are classified by keywords found in their lower-cased message:
    network and rate-limit/quota errors are retried, authentication errors
    fail immediately, and anything else is retried and finally re-raised
    unchanged. Up to API_CONFIG["MAX_RETRIES"] retries follow the first
    attempt.

    Args:
        func: The async function to call
        *args: Positional arguments to pass to the function
        **kwargs: Keyword arguments to pass to the function

    Returns:
        The result of the function call

    Raises:
        Exception: on authentication errors, or when network / rate-limit
            errors persist past the retry limit (chained to the original)
        The original exception, for any other error past the retry limit
    """
    network_keywords = ('timeout', 'ssl', 'connection', 'network', 'connect', 'proxy', 'dns', 'socket')
    rate_limit_keywords = ('rate limit', '429', 'quota', 'billing', 'limit_requests')
    auth_keywords = ('401', 'unauthorized', 'api key')

    retries = 0
    while True:
        try:
            debug_print(f"Making async API call, attempt {retries + 1}")
            return await func(*args, **kwargs)
        except Exception as e:
            retries += 1
            error_msg = str(e).lower()
            out_of_retries = retries > API_CONFIG["MAX_RETRIES"]

            # Check for specific error types
            if any(keyword in error_msg for keyword in network_keywords):
                debug_print(f"Network error detected: {e}")
                if out_of_retries:
                    raise Exception(f"Network connection failed after {API_CONFIG['MAX_RETRIES']} retries: {e}") from e
            elif any(keyword in error_msg for keyword in rate_limit_keywords):
                debug_print(f"Rate limit or quota error: {e}")
                if out_of_retries:
                    raise Exception(f"Rate limit/quota exceeded after {API_CONFIG['MAX_RETRIES']} retries: {e}") from e
            elif any(keyword in error_msg for keyword in auth_keywords):
                # Retrying cannot fix bad credentials.
                raise Exception(f"Authentication error: {e}") from e
            elif out_of_retries:
                raise

            # Progressive backoff with jitter. The jitter is the fractional
            # part of the clock (0 <= x < 1) scaled to at most 10% of the base
            # wait. Using the raw epoch value here would dwarf the base wait
            # and pin every retry to the 60 s cap. The clock is used rather
            # than `random` so the module's seeded RNG stream is not consumed.
            base_wait = API_CONFIG["RETRY_DELAY"] * (2 ** (retries - 1))
            jitter = base_wait * 0.1 * (time.time() % 1)
            wait_time = min(base_wait + jitter, 60)  # Max 60 seconds

            debug_print(f"Error: {e}. Retrying in {wait_time:.1f}s... (Attempt {retries}/{API_CONFIG['MAX_RETRIES']})")
            await asyncio.sleep(wait_time)

async def process_scene_description_generation_async(input_data: Dict[str, Any], meta_modules: List[str]) -> Dict[str, Any]:
    """
    Generate scene descriptions based on math information and predefined rules using GPT model

    Builds a multi-turn conversation (system prompt, few-shot examples, the
    actual request), sends it through retry_async_call, then parses and
    validates the JSON reply.

    Args:
        input_data: Scene description generation input data; must contain
            "scenes", may contain "question_id" (used for logging only)
        meta_modules: List of meta description modules to load

    Returns:
        Dict with "generation_result" (the parsed reply, validated against
        ProcessingOutput) and "token_usage" (dict of prompt/completion/total
        token counts, or None when the response carries no usage data)

    Raises:
        ValueError: if the model returns no choices or no text content
        Exception: any input-validation, API, parsing or output-validation
            error is printed and re-raised
    """
    question_id = input_data.get('question_id', 'unknown')
    debug_print(f"Processing scene description generation for question ID: {question_id}")

    try:
        # Create input for the prompt (exclude question_id from API call)
        prompt_input = {
            "scenes": input_data["scenes"]
        }

        # Validate input data
        processing_input = ProcessingInput(**prompt_input)
        input_json = json.dumps(processing_input.model_dump(), ensure_ascii=False, indent=2)

        # Get meta description content
        meta_content = get_meta_description_content(meta_modules)

        # Build system prompt with meta descriptions
        system_prompt = global_strategy

        if meta_content:
            meta_content_prompt = f"""
## META DESCRIPTION MODULES

The following meta_description_modules are loaded for this processing:

{meta_content}

These modules provide specialized templates and rules for handling specific types of mathematical information. You MUST apply these modules when they are relevant to the scene content.
"""
            system_prompt += meta_content_prompt

        system_prompt += process_prompt
        # NOTE(review): the schema is interpolated as a Python dict repr
        # (single quotes), not as JSON text. Left unchanged because it alters
        # the prompt; consider json.dumps(...) if strict JSON is wanted.
        system_prompt += f"""

You must respond with a valid JSON object that matches this schema:
{ProcessingOutput.model_json_schema()}

CRITICAL JSON FORMATTING REQUIREMENTS:
1. Use proper comma separation between all object properties
2. Use proper comma separation between all array elements
3. Ensure all strings are properly quoted with double quotes
4. Do not use trailing commas
5. Ensure proper nesting and bracket matching
6. Response must be ONLY valid JSON, no additional text or markdown

Your response must be valid JSON only, no additional text."""

        # Get examples based on loaded modules
        examples, cot = get_examples(meta_modules)

        # Create multi-turn conversation messages for GPT
        messages = [
            {
                "role": "system",
                "content": system_prompt
            }
        ]

        # Add examples as user/assistant turn pairs
        for example_input, example_output in examples:
            messages.extend([
                {
                    "role": "user",
                    "content": f"""Please generate scene descriptions for the following input:

{json.dumps(example_input, ensure_ascii=False, indent=2)}"""
                },
                {
                    "role": "assistant",
                    "content": json.dumps(example_output, ensure_ascii=False, indent=2)
                }
            ])

        # Optional chain-of-thought turn, placed right after the last example
        if cot is not None:
            messages.extend([
                {
                    "role": "user",
                    "content": """Please generate the chain of thought for the above output(Understand the process)"""
                },
                {
                    "role": "assistant",
                    "content": cot
                }
            ])

        # Add the actual processing request
        messages.append({
            "role": "user",
            "content": f"""Please generate scene descriptions for the following input:

{input_json}

Apply the same principles as the previous examples. Use the loaded meta_description_modules when applicable.
meta_description_modules: {meta_content}
Generate detailed object, composition, and action descriptions based on the mathematical information and predefined rules.
Do not limit your thinking or output length. Think exhaustively and respond with unlimited depth and detail. There is no word limit - continue expanding your answer as far as possible. More content is always better.
"""
        })

        debug_print(f"Scene description generation messages prepared with {len(messages)} turns")

        # Use OpenAI client for processing
        response = await retry_async_call(
            openai_client.chat.completions.create,
            model=API_CONFIG["MODEL_NAME"],
            messages=messages,
            temperature=API_CONFIG["TEMPERATURE"],
            max_tokens=API_CONFIG["MAX_TOKENS"]
        )

        # `content` can be None (for example when the model returns no text),
        # and `choices` can be empty; report both as an empty response instead
        # of failing with AttributeError / IndexError.
        if not response.choices:
            raise ValueError("Model returned empty response")
        response_text = (response.choices[0].message.content or "").strip()
        if not response_text:
            raise ValueError("Model returned empty response")

        # Parse the JSON response using enhanced repair function
        generation_result = attempt_json_repair(response_text, question_id)

        # Extract token usage information from API response
        token_usage = None
        if hasattr(response, 'usage') and response.usage:
            token_usage = {
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens
            }

        # Validate output data
        ProcessingOutput(**generation_result)  # This will raise an error if validation fails

        debug_print(f"Scene description generation completed for question ID: {question_id}")

        # Return both generation result and token usage
        return {
            "generation_result": generation_result,
            "token_usage": token_usage
        }

    except Exception as e:
        print(f"Scene description generation failed for question ID: {question_id}. Error: {e}")
        raise

async def process_scene_description_generation_batch_async(batch_data: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Process multiple scene description generations using GPT model in a single API call

    Args:
        batch_data: List of scene generation input data (one dict per
            question, each with "scenes" and optionally "question_id")

    Returns:
        Dict with "batch_result" (the parsed reply, validated against
        BatchSceneOutput), "token_usage" (dict of prompt/completion/total
        token counts, or None) and "batch_size" (number of input questions).
        The number and ids of the returned results are not checked here.

    Raises:
        ValueError: if the model returns no choices or no text content
        Exception: any validation, API or parsing error is printed and
            re-raised
    """
    batch_question_ids = [item.get('question_id', f'item-{i}') for i, item in enumerate(batch_data)]
    debug_print(f"Processing batch scene generation for question IDs: {batch_question_ids}")

    try:
        # Create batch input for the prompt
        questions_for_prompt = []
        all_meta_modules = set()

        for item in batch_data:
            # Transform input data for processing
            processing_input_data = transform_input_to_processing_format(item)

            # Determine meta modules for this item. Sorted so that the prompt
            # text does not depend on set iteration order.
            meta_modules = sorted(determine_meta_modules(item))
            all_meta_modules.update(meta_modules)

            questions_for_prompt.append({
                "question_id": item.get('question_id', 'unknown'),
                "scenes": processing_input_data.get("scenes", []),
                "meta_modules": meta_modules
            })

        # Validate batch input data
        batch_input = BatchSceneInput(questions=questions_for_prompt)
        input_json = json.dumps(batch_input.model_dump(), ensure_ascii=False, indent=2)

        # Union of the modules used in the batch, in a stable order: iterating
        # the set directly would reorder the prompt from run to run.
        batch_modules = sorted(all_meta_modules)

        # Get meta description content for all modules used in batch
        meta_content = get_meta_description_content(batch_modules)

        # Build system prompt for batch processing
        system_prompt = global_strategy

        if meta_content:
            meta_content_prompt = f"""
## META DESCRIPTION MODULES

The following meta_description_modules are loaded for this batch processing:

{meta_content}

These modules provide specialized templates and rules for handling specific types of mathematical information. You MUST apply these modules when they are relevant to the scene content.
"""
            system_prompt += meta_content_prompt

        system_prompt += process_prompt

        # NOTE(review): the schema is interpolated as a Python dict repr
        # (single quotes), not as JSON text. Left unchanged because it alters
        # the prompt; consider json.dumps(...) if strict JSON is wanted.
        system_prompt += f"""

BATCH PROCESSING INSTRUCTIONS:
You will receive multiple questions at once. Process each question independently and return scene descriptions for all questions.

You must respond with a valid JSON object that matches this schema:
{BatchSceneOutput.model_json_schema()}

The "results" array should contain one result object for each input question, in the same order.
Each result object should have:
- question_id: The ID of the question being processed
- scenes: Array of generated scene descriptions following the same format as single question processing

CRITICAL JSON FORMATTING REQUIREMENTS:
1. Use proper comma separation between all object properties
2. Use proper comma separation between all array elements
3. Ensure all strings are properly quoted with double quotes
4. Do not use trailing commas
5. Ensure proper nesting and bracket matching
6. Response must be ONLY valid JSON, no additional text or markdown

Your response must be valid JSON only, no additional text."""

        # Get examples based on all loaded modules
        examples, cot = get_examples(batch_modules)

        # Create multi-turn conversation messages for GPT with examples
        messages = [
            {
                "role": "system",
                "content": system_prompt
            }
        ]

        # Add examples (use fewer examples for batch to save tokens)
        for example_input, example_output in examples[:2]:
            # Wrap each single-question example in the batch envelope
            batch_example_input = {
                "questions": [
                    {
                        "question_id": "example_1",
                        "scenes": example_input.get("scenes", []),
                        "meta_modules": []
                    }
                ]
            }
            batch_example_output = {
                "results": [
                    {
                        "question_id": "example_1",
                        "scenes": example_output.get("scenes", [])
                    }
                ]
            }

            messages.extend([
                {
                    "role": "user",
                    "content": f"""Please generate scene descriptions for the following batch input:

{json.dumps(batch_example_input, ensure_ascii=False, indent=2)}"""
                },
                {
                    "role": "assistant",
                    "content": json.dumps(batch_example_output, ensure_ascii=False, indent=2)
                }
            ])

        # Optional chain-of-thought turn, added once after the examples (as in
        # the single-question path) rather than repeated after every example.
        if cot is not None:
            messages.extend([
                {
                    "role": "user",
                    "content": """Please generate the chain of thought for the above output(Understand the process)"""
                },
                {
                    "role": "assistant",
                    "content": cot
                }
            ])

        # Add the actual batch processing request
        messages.append({
            "role": "user",
            "content": f"""Please generate scene descriptions for the following batch input:

{input_json}

Apply the same principles as the previous examples to each question. Use the loaded meta_description_modules when applicable for each question.
Generate detailed object, composition, and action descriptions based on the mathematical information and predefined rules for each question.
Return results for all questions in the results array. Do not limit your thinking or output length. Think exhaustively and respond with unlimited depth and detail for each question."""
        })

        debug_print(f"Batch scene generation messages prepared with {len(messages)} turns for {len(batch_data)} questions")

        # Use OpenAI client for processing
        response = await retry_async_call(
            openai_client.chat.completions.create,
            model=API_CONFIG["MODEL_NAME"],
            messages=messages,
            temperature=API_CONFIG["TEMPERATURE"],
            max_tokens=API_CONFIG["MAX_TOKENS"]
        )

        # `content` can be None (for example when the model returns no text),
        # and `choices` can be empty; report both as an empty response instead
        # of failing with AttributeError / IndexError.
        if not response.choices:
            raise ValueError("Model returned empty response")
        response_text = (response.choices[0].message.content or "").strip()
        if not response_text:
            raise ValueError("Model returned empty response")

        # Parse the JSON response using enhanced repair function
        batch_result = attempt_json_repair(response_text, f"batch-{'-'.join(map(str, batch_question_ids))}")

        # Validate output data
        BatchSceneOutput(**batch_result)  # This will raise an error if validation fails

        # Extract token usage information from API response
        token_usage = None
        if hasattr(response, 'usage') and response.usage:
            token_usage = {
                "prompt_tokens": response.usage.prompt_tokens,
                "completion_tokens": response.usage.completion_tokens,
                "total_tokens": response.usage.total_tokens
            }

        debug_print(f"Batch scene generation completed for question IDs: {batch_question_ids}")

        # Return both batch result and token usage
        return {
            "batch_result": batch_result,
            "token_usage": token_usage,
            "batch_size": len(batch_data)
        }

    except Exception as e:
        print(f"Batch scene generation failed for question IDs: {batch_question_ids}. Error: {e}")
        raise

def combine_token_usage(stage1_tokens: Optional[Dict], stage2_tokens: Optional[Dict]) -> Optional[CombinedTokenUsage]:
    """
    Merge the token counts of stage 1 and stage 2 into one record

    Args:
        stage1_tokens: Token counts from stage 1 (math extraction), or None
        stage2_tokens: Token counts from stage 2 (scene generation), or None

    Returns:
        CombinedTokenUsage carrying each stage's usage (None for a stage that
        supplied nothing) plus the summed totals, or None when neither stage
        supplied any counts
    """
    if not (stage1_tokens or stage2_tokens):
        return None

    counter_names = ("prompt_tokens", "completion_tokens", "total_tokens")

    def _to_usage(raw_counts):
        # A missing or empty dict means the stage reported nothing;
        # counters absent from a non-empty dict default to 0.
        if not raw_counts:
            return None
        return TokenUsage(**{name: raw_counts.get(name, 0) for name in counter_names})

    first_stage = _to_usage(stage1_tokens)
    second_stage = _to_usage(stage2_tokens)

    # Only stages that actually reported usage contribute to the totals
    reporting_stages = [usage for usage in (first_stage, second_stage) if usage]
    grand_total = TokenUsage(
        prompt_tokens=sum(usage.prompt_tokens for usage in reporting_stages),
        completion_tokens=sum(usage.completion_tokens for usage in reporting_stages),
        total_tokens=sum(usage.total_tokens for usage in reporting_stages)
    )

    return CombinedTokenUsage(
        stage1_tokens=first_stage,
        stage2_tokens=second_stage,
        total_tokens=grand_total
    )

async def process_single_item(item: Dict[str, Any]) -> Dict[str, Any]:
    """
    Run one input item through scene description generation

    Any exception raised while processing is caught and reported as a
    "failed" record instead of being propagated.

    Args:
        item: Single item from input file

    Returns:
        On success, the item's scenes merged with the generated descriptions,
        "processing_status" set to "success" and the meta modules used.
        On failure, the item's original scenes with "processing_status" set to
        "failed" and the error text. Either record carries combined token
        usage when any is available.
    """
    question_id = item.get('question_id', 'unknown')

    def _base_record(scenes, status):
        # Fields shared by the success and the failure record, in output order
        return {
            "question_id": question_id,
            "original_question": item.get("original_question", ""),
            "math_ground_truth": item.get("math_ground_truth", ""),
            "scenes": scenes,
            "processing_status": status
        }

    try:
        # Build the processing payload and tag it with the question_id for tracking
        payload = transform_input_to_processing_format(item)
        payload['question_id'] = question_id

        # Module selection is given the whole item, not just its scenes
        meta_modules = determine_meta_modules(item)
        debug_print(f"Meta modules for {question_id}: {meta_modules}")

        outcome = await process_scene_description_generation_async(payload, meta_modules)
        generation_result = outcome.get("generation_result", {})
        stage2_usage = outcome.get("token_usage")

        # Index the generated scenes by scene_id (a later duplicate id wins)
        generated_by_id = {}
        for generated in generation_result.get("scenes", []):
            generated_by_id[generated.get("scene_id")] = generated

        merged_scenes = []
        for source_scene in item.get("scenes", []):
            scene_id = source_scene.get("scene_id")
            merged = {
                "scene_id": scene_id,
                "scene_math_information": source_scene.get("scene_math_information", []),
                "interfere": source_scene.get("interfere", "none")
            }

            # Scenes without a generated counterpart keep only their input fields
            if scene_id in generated_by_id:
                description = generated_by_id[scene_id]
                merged["object"] = description.get("object", "")
                merged["composition"] = description.get("composition", "")
                merged["action"] = description.get("action", "")

            merged_scenes.append(merged)

        # Stage 1 usage travels with the input item; stage 2 comes from this call
        usage_summary = combine_token_usage(item.get("token_usage"), stage2_usage)

        final_result = _base_record(merged_scenes, "success")
        final_result["meta_modules_used"] = meta_modules
        if usage_summary:
            final_result["token_usage"] = usage_summary.model_dump()

        debug_print(f"Scene description generation processing completed for question ID: {question_id}")
        return final_result

    except Exception as e:
        print(f"Scene description generation processing failed for question ID: {question_id}. Error: {e}")

        # Keep the stage 1 token usage even though stage 2 produced nothing
        usage_summary = combine_token_usage(item.get("token_usage"), None)

        error_result = _base_record(item.get("scenes", []), "failed")
        error_result["processing_error"] = str(e)
        if usage_summary:
            error_result["token_usage"] = usage_summary.model_dump()

        return error_result

async def process_batch_async(batch_data: List[Dict[str, Any]], semaphore: asyncio.Semaphore) -> List[Dict[str, Any]]:
    """
    Process a batch of items using async parallel processing

    Every item is handed to process_single_item in its own coroutine; the
    semaphore bounds how many of them run at the same time.

    Args:
        batch_data: List of items to process
        semaphore: Semaphore to control concurrency

    Returns:
        List of processed results, one per input item and in input order.
        An item whose coroutine raised (or was cancelled) is represented by an
        error record with status "failed_parallel_processing".
    """
    debug_print(f"Processing batch of {len(batch_data)} items in parallel")

    async def process_with_semaphore(item_data):
        async with semaphore:
            return await process_single_item(item_data)

    # Execute all items in parallel; exceptions come back as result values
    results = await asyncio.gather(
        *(process_with_semaphore(item) for item in batch_data),
        return_exceptions=True
    )

    processed_results = []
    for i, (item, result) in enumerate(zip(batch_data, results)):
        # BaseException rather than Exception: a cancelled child coroutine is
        # returned as asyncio.CancelledError, which does not derive from
        # Exception and would otherwise be passed on as if it were a result.
        if not isinstance(result, BaseException):
            processed_results.append(result)
            continue

        question_id = item.get('question_id', f'item-{i}')
        print(f"Error processing item {question_id}: {result}")
        processed_results.append({
            "question_id": question_id,
            "original_question": item.get("original_question", ""),
            "math_ground_truth": item.get("math_ground_truth", ""),
            "scenes": [],
            "processing_status": "failed_parallel_processing",
            "processing_error": str(result)
        })

    return processed_results

def _find_batch_item_result(batch_results: List[Dict[str, Any]], question_id: Any) -> Optional[Dict[str, Any]]:
    """Return the batch-response entry belonging to question_id, or None if there is none."""
    wanted_text = str(question_id)
    text_match = None
    for result in batch_results:
        result_id = result.get("question_id")
        if result_id == question_id:
            return result
        # Remember the first id that differs only by type (e.g. "12" vs 12), in
        # case the response echoed the id back with a different JSON type.
        if text_match is None and str(result_id) == wanted_text:
            text_match = result
    return text_match


def _merge_generated_scenes(item: Dict[str, Any], item_result: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Combine the item's original scenes with the generated descriptions, matched by scene_id."""
    generated_scene_map = {scene.get("scene_id"): scene for scene in item_result.get("scenes", [])}

    combined_scenes = []
    for original_scene in item.get("scenes", []):
        scene_id = original_scene.get("scene_id")
        combined_scene = {
            "scene_id": scene_id,
            "scene_math_information": original_scene.get("scene_math_information", []),
            "interfere": original_scene.get("interfere", "none")
        }

        # Add generated descriptions if available
        if scene_id in generated_scene_map:
            generated_scene = generated_scene_map[scene_id]
            combined_scene.update({
                "object": generated_scene.get("object", ""),
                "composition": generated_scene.get("composition", ""),
                "action": generated_scene.get("action", "")
            })

        combined_scenes.append(combined_scene)

    return combined_scenes


def _batch_error_result(item: Dict[str, Any], fallback_index: int, status: str, message: str) -> Dict[str, Any]:
    """Build the failure record for one batch item, preserving its stage 1 token usage."""
    error_result = {
        "question_id": item.get('question_id', f'item-{fallback_index}'),
        "original_question": item.get("original_question", ""),
        "math_ground_truth": item.get("math_ground_truth", ""),
        "scenes": item.get("scenes", []),
        "processing_status": status,
        "processing_error": message
    }

    combined_token_usage = combine_token_usage(item.get("token_usage"), None)
    if combined_token_usage:
        error_result["token_usage"] = combined_token_usage.model_dump()

    return error_result


async def _process_api_sub_batch(sub_batch: List[Dict[str, Any]], offset: int) -> List[Dict[str, Any]]:
    """Send one sub-batch in a single API call and build one result record per item."""
    batch_processing_result = await process_scene_description_generation_batch_async(sub_batch)
    batch_result = batch_processing_result.get("batch_result", {})
    stage2_token_usage = batch_processing_result.get("token_usage")
    batch_size = batch_processing_result.get("batch_size", len(sub_batch))
    batch_results = batch_result.get("results", [])

    # The API reports usage for the whole call, so every item is charged an
    # equal share (integer division; the remainder is not attributed).
    per_item_stage2_tokens = None
    if stage2_token_usage and batch_size > 0:
        per_item_stage2_tokens = {
            key: int(stage2_token_usage.get(key, 0) // batch_size)
            for key in ("prompt_tokens", "completion_tokens", "total_tokens")
        }

    sub_results = []
    for j, item in enumerate(sub_batch):
        question_id = item.get('question_id', f'item-{offset + j}')
        item_result = _find_batch_item_result(batch_results, question_id)

        if not item_result:
            sub_results.append(_batch_error_result(
                item,
                offset + j,
                "failed_batch_missing_result",
                f"Result not found in batch response for question_id: {question_id}"
            ))
            continue

        final_result = {
            "question_id": question_id,
            "original_question": item.get("original_question", ""),
            "math_ground_truth": item.get("math_ground_truth", ""),
            "scenes": _merge_generated_scenes(item, item_result),
            "processing_status": "success"
        }

        # Combine token usage from both stages
        combined_token_usage = combine_token_usage(item.get("token_usage"), per_item_stage2_tokens)
        if combined_token_usage:
            final_result["token_usage"] = combined_token_usage.model_dump()

        sub_results.append(final_result)

    return sub_results


async def process_batch_with_batch_api_async(batch_data: List[Dict[str, Any]], semaphore: asyncio.Semaphore) -> List[Dict[str, Any]]:
    """
    Process a batch of items using batch API call instead of individual calls

    The batch is cut into sub-batches of API_CONFIG["BATCH_API_SIZE"] items and
    each sub-batch is sent in one API call. A failure is contained to the
    sub-batch it happened in: its items are reported with status
    "failed_batch_api_exception" while the results of the other sub-batches
    are kept.

    Args:
        batch_data: List of items to process
        semaphore: Semaphore to control concurrency (held for the whole batch)

    Returns:
        List of processed results, one per input item and in input order
    """
    debug_print(f"Processing batch of {len(batch_data)} items with single batch API call")

    async with semaphore:
        try:
            api_batch_size = API_CONFIG["BATCH_API_SIZE"]
            if api_batch_size < 1:
                # A negative step would produce an empty range and silently drop every item
                raise ValueError(f"BATCH_API_SIZE must be at least 1, got {api_batch_size}")
            sub_batch_starts = range(0, len(batch_data), api_batch_size)
        except Exception as e:
            print(f"Error processing batch with batch API: {e}")
            return [
                _batch_error_result(item, i, "failed_batch_api_exception", str(e))
                for i, item in enumerate(batch_data)
            ]

        all_results = []
        for start in sub_batch_starts:
            sub_batch = batch_data[start:start + api_batch_size]
            try:
                sub_results = await _process_api_sub_batch(sub_batch, start)
            except Exception as e:
                print(f"Error processing batch with batch API: {e}")
                # Only this sub-batch is marked as failed; earlier ones keep their results
                sub_results = [
                    _batch_error_result(item, start + j, "failed_batch_api_exception", str(e))
                    for j, item in enumerate(sub_batch)
                ]
            all_results.extend(sub_results)

        return all_results

async def test_connection():
    """
    Test OpenAI API connection

    Sends one small chat completion request with the configured model.

    Returns:
        True when the request succeeded, False when the OPENAI_API_KEY_IMAGE
        environment variable is not set or the request raised
    """
    print("Testing OpenAI API connection...")

    if not os.environ.get("OPENAI_API_KEY_IMAGE"):
        print("Warning: OPENAI_API_KEY_IMAGE environment variable not set")
        return False

    try:
        test_response = await openai_client.chat.completions.create(
            model=API_CONFIG["MODEL_NAME"],
            messages=[
                {"role": "user", "content": "Hello! Please respond with a simple JSON object containing a 'status' field with value 'ok'."}
            ],
            max_tokens=100,
            temperature=0
        )

        # The message content can be None; slicing None would raise and
        # misreport a working connection as failed, so fall back to "".
        response_content = test_response.choices[0].message.content or ""
        print("OpenAI API connection successful")
        print(f"Test response: {response_content[:200]}...")
        return True
    except Exception as e:
        print(f"OpenAI API connection failed: {e}")
        return False

async def main_async(input_file: str, output_file: str = None, anomalies_file: str = None, 
                    batch_size: int = API_CONFIG["BATCH_SIZE"]):
    """
    Main function for scene description generation processing

    Loads the input items, processes them batch by batch, writes the
    successful results and the failed ones (anomalies) to separate JSON files
    and prints summary statistics.

    Args:
        input_file: Input file path
        output_file: Output file path for successful results
        anomalies_file: Output file path for failed results
        batch_size: Number of items to process in parallel per batch

    Raises:
        ValueError: If batch_size is smaller than 1
    """
    # Fail before any API call or file access; a step below 1 cannot drive the batch loop
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    debug_print("Starting scene description generation processing")

    # Test API connection
    connection_ok = await test_connection()
    if not connection_ok:
        print("Cannot proceed without valid OpenAI API connection")
        return

    # Generate timestamped filenames if not provided
    if output_file is None:
        output_file = create_timestamped_filename(OUTPUT_FILE_BASE + ".json")
    if anomalies_file is None:
        anomalies_file = create_timestamped_filename(ANOMALIES_FILE_BASE + ".json")

    # Both paths can resolve to the same file (e.g. identical base names and a
    # timestamp from the same second); writing anomalies there would overwrite
    # the successful results, so give the anomalies file a distinct name.
    if anomalies_file and os.path.abspath(anomalies_file) == os.path.abspath(output_file):
        anomalies_root, anomalies_ext = os.path.splitext(anomalies_file)
        anomalies_file = f"{anomalies_root}_anomalies{anomalies_ext}"

    # Create the directories of both output files if needed
    for target_path in (output_file, anomalies_file):
        target_dir = os.path.dirname(target_path) if target_path else ""
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)

    print(f"\nLoading data from {input_file}")
    print(f"Output will be saved to: {output_file}")
    if anomalies_file:
        print(f"Anomalies will be saved to: {anomalies_file}")

    try:
        with open(input_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except Exception as e:
        print(f"Error loading input file: {e}")
        return

    # The batch loop slices the data, which only works on a JSON array
    if not isinstance(data, list):
        print(f"Error loading input file: expected a JSON array of items, got {type(data).__name__}")
        return

    print(f"Loaded {len(data)} items from input file.")

    start_time = time.time()
    total_items = len(data)
    total_batches = (total_items + batch_size - 1) // batch_size
    successful_results = []
    anomalies = []

    # Create semaphore for controlling concurrency
    semaphore = asyncio.Semaphore(batch_size)

    # Determine processing mode
    # NOTE(review): the API_CONFIG literal at the top of the file lists neither
    # ENABLE_BATCH_PROCESSING nor BATCH_API_SIZE - confirm they are set before this runs.
    use_batch_api = API_CONFIG["ENABLE_BATCH_PROCESSING"]
    processing_mode = "Batch API Processing" if use_batch_api else "Individual API Processing"

    print("\nStarting Scene Description Generation Processing")
    print(f"   Model: {API_CONFIG['MODEL_NAME']} (GPT)")
    print(f"   Processing Mode: {processing_mode}")
    print(f"   Batch Size: {batch_size} (Parallel batches)")
    if use_batch_api:
        print(f"   API Batch Size: {API_CONFIG['BATCH_API_SIZE']} (Questions per API call)")

    # Process data in batches using appropriate processing function
    for i in range(0, total_items, batch_size):
        batch = data[i:i+batch_size]
        batch_start = time.time()
        current_batch = i//batch_size + 1

        print(f"\nProcessing batch {current_batch}/{total_batches} ({len(batch)} items) in parallel")

        try:
            # Choose processing function based on mode
            if use_batch_api:
                batch_results = await process_batch_with_batch_api_async(batch, semaphore)
            else:
                batch_results = await process_batch_async(batch, semaphore)

            # Partition into local lists first, so an error halfway through
            # cannot leave part of the batch recorded twice.
            batch_successes = []
            batch_failures = []
            for result in batch_results:
                if result.get("processing_status") == "success":
                    batch_successes.append(result)
                else:
                    batch_failures.append(result)

        except Exception as e:
            print(f"Error processing batch {current_batch}: {e}")
            # The whole batch is reported as anomalies
            batch_successes = []
            batch_failures = []
            for j, item in enumerate(batch):
                stage1_token_usage = item.get("token_usage")
                combined_token_usage = combine_token_usage(stage1_token_usage, None)

                error_item = {
                    "question_id": item.get('question_id', f'item-{i+j}'),
                    "original_question": item.get("original_question", ""),
                    "math_ground_truth": item.get("math_ground_truth", ""),
                    "scenes": item.get("scenes", []),
                    "processing_status": "failed_batch_exception",
                    "processing_error": str(e)
                }

                if combined_token_usage:
                    error_item["token_usage"] = combined_token_usage.model_dump()

                batch_failures.append(error_item)

        successful_results.extend(batch_successes)
        anomalies.extend(batch_failures)

        # Display progress
        batch_time = time.time() - batch_start
        elapsed = time.time() - start_time
        successful_count = len(successful_results)
        anomalies_count = len(anomalies)
        items_processed = successful_count + anomalies_count
        progress_percent = items_processed/total_items*100 if total_items > 0 else 0
        items_per_second = items_processed / elapsed if elapsed > 0 else 0

        if items_processed > 0:
            # Extrapolate the remaining time from the average time per item so far
            estimated_remaining = (elapsed / items_processed) * total_items - elapsed
            remaining_text = f"{estimated_remaining:.1f}s"
        else:
            remaining_text = "inf"

        print(f"Progress: {items_processed}/{total_items} items ({progress_percent:.1f}%)")
        print(f"Successful: {successful_count}, Anomalies: {anomalies_count}")
        print(f"Speed: {items_per_second:.2f} items/sec ({processing_mode.lower()})")
        print(f"Time elapsed: {elapsed:.1f}s, estimated remaining: {remaining_text}")
        print(f"Batch processing time: {batch_time:.1f}s ({batch_time/len(batch):.1f}s per item average)")

    # Save successful results only to output file
    print(f"\nSaving successful results to {output_file}")
    try:
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(successful_results, f, ensure_ascii=False, indent=2)
        print(f"{len(successful_results)} successful results saved.")
    except Exception as e:
        print(f"Error saving results: {e}")

    # Save anomalies to separate file
    if anomalies:
        print(f"\nSaving anomalies to {anomalies_file}")
        try:
            with open(anomalies_file, 'w', encoding='utf-8') as f:
                json.dump(anomalies, f, ensure_ascii=False, indent=2)
            print(f"{len(anomalies)} anomalies saved for review.")
        except Exception as e:
            print(f"Error saving anomalies: {e}")
    else:
        print("No anomalies detected.")

    # Print summary
    print_summary_statistics(successful_results, anomalies, time.time() - start_time, total_items)

def print_summary_statistics(successful_results: List[Dict], anomalies: List[Dict], 
                           total_time: float, total_items: int):
    """
    Print the end-of-run report: counts, timing, meta module usage and token usage

    Module usage and token usage are computed from the successful results
    only. Token averages are taken over the items that carry combined token
    totals.

    Args:
        successful_results: List of successfully processed items
        anomalies: List of failed items
        total_time: Total processing time in seconds
        total_items: Total number of input items
    """
    successful_count = len(successful_results)
    failed_count = len(anomalies)
    has_items = total_items > 0

    batch_mode_enabled = API_CONFIG["ENABLE_BATCH_PROCESSING"]
    if batch_mode_enabled:
        processing_mode = "Batch API processing"
    else:
        processing_mode = "Individual API processing"

    print("\n============================== Scene Description Generation Processing Summary ==============================")
    print(f"Model: {API_CONFIG['MODEL_NAME']} (GPT)")
    print(f"Total items loaded: {total_items}")
    print(f"Successfully processed items: {successful_count}")
    print(f"Items with errors (anomalies): {failed_count}")
    print(f"Success rate: {successful_count/total_items*100:.1f}%" if has_items else "Success rate: 0%")
    print(f"Total time: {total_time:.1f}s")
    print(f"Average time per item: {total_time/total_items:.1f}s" if has_items else "Average time per item: N/A")
    print(f"Processing mode: {processing_mode}, batch size: {API_CONFIG['BATCH_SIZE']}")
    if batch_mode_enabled:
        print(f"API batch size: {API_CONFIG['BATCH_API_SIZE']} questions per API call")

    if successful_results:
        print("\nMeta Module Usage Analysis:")

        counter_names = ("prompt_tokens", "completion_tokens", "total_tokens")
        # Key inside an item's "token_usage" -> heading printed for that section
        sections = (
            ("stage1_tokens", "\n   Stage 1 (Math Extraction) Tokens:"),
            ("stage2_tokens", "\n   Stage 2 (Scene Generation) Tokens:"),
            ("total_tokens", "\n   Combined (Both Stages) Tokens:"),
        )
        token_sums = {section_key: dict.fromkeys(counter_names, 0) for section_key, _ in sections}
        module_stats = {}
        items_with_token_info = 0

        for record in successful_results:
            # Records without any module are tallied under "no_modules"
            for module_name in (record.get("meta_modules_used", []) or ["no_modules"]):
                module_stats[module_name] = module_stats.get(module_name, 0) + 1

            usage = record.get("token_usage")
            if not usage:
                continue

            for section_key, _ in sections:
                counts = usage.get(section_key)
                if not counts:
                    continue
                for counter in counter_names:
                    token_sums[section_key][counter] += counts.get(counter, 0)
                # Only records with combined totals count toward the averages
                if section_key == "total_tokens":
                    items_with_token_info += 1

        print("   Module usage distribution:")
        for module_name, count in sorted(module_stats.items()):
            percentage = count / successful_count * 100 if successful_count > 0 else 0
            print(f"   {module_name}: {count} items ({percentage:.1f}%)")

        if items_with_token_info > 0:
            print("\n   Token Usage Statistics:")
            print(f"   Items with token info: {items_with_token_info}/{successful_count}")

            for section_key, heading in sections:
                sums = token_sums[section_key]
                avg_prompt = sums["prompt_tokens"] / items_with_token_info
                avg_completion = sums["completion_tokens"] / items_with_token_info
                avg_total = sums["total_tokens"] / items_with_token_info

                print(heading)
                print(f"   Total prompt tokens: {sums['prompt_tokens']}")
                print(f"   Total completion tokens: {sums['completion_tokens']}")
                print(f"   Total tokens: {sums['total_tokens']}")
                print(f"   Average per item: {avg_prompt:.1f}/{avg_completion:.1f}/{avg_total:.1f}")

    print("================================================================================================")

def main(input_file: str, output_file: str = None, anomalies_file: str = None, 
         batch_size: int = API_CONFIG["BATCH_SIZE"]):
    """Run main_async to completion from synchronous code and return its result"""
    pipeline = main_async(
        input_file,
        output_file=output_file,
        anomalies_file=anomalies_file,
        batch_size=batch_size,
    )
    return asyncio.run(pipeline)

def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command line arguments.

    Args:
        argv: Argument strings to parse. When None (the default), argparse
            reads the process command line, which is the original behavior.

    Returns:
        The parsed namespace with attributes ``input``, ``output``,
        ``anomalies``, ``batch_size``, ``batch``, ``api_batch_size``,
        ``debug`` and ``test``.
    """
    # API_CONFIG as declared at the top of the file has no "BATCH_API_SIZE"
    # entry, so indexing it directly can raise KeyError before any argument
    # is parsed. Use the configured value when present, otherwise fall back.
    # NOTE(review): falling back to BATCH_SIZE is a stand-in -- confirm the
    # intended default number of questions per batch API call.
    default_api_batch_size = API_CONFIG.get("BATCH_API_SIZE", API_CONFIG["BATCH_SIZE"])

    parser = argparse.ArgumentParser(description='Scene Description Generation Processing Tool - GPT Version')

    parser.add_argument('--input', '-i', type=str, default=INPUT_FILE,
                        help=f'Input file path (default: {INPUT_FILE})')

    parser.add_argument('--output', '-o', type=str, default=None,
                        help='Output file path (default: auto-generated with timestamp)')

    parser.add_argument('--anomalies', '-a', type=str, default=None,
                        help='Anomalies file path (default: auto-generated with timestamp)')

    parser.add_argument('--batch-size', '-b', type=int, default=API_CONFIG["BATCH_SIZE"],
                        help=f'Batch size for parallel processing (default: {API_CONFIG["BATCH_SIZE"]})')

    parser.add_argument('--batch', action='store_true',
                        help='Enable batch API processing mode (process multiple questions per API call)')

    parser.add_argument('--api-batch-size', type=int, default=default_api_batch_size,
                        help=f'Number of questions per batch API call when batch mode is enabled (default: {default_api_batch_size})')

    parser.add_argument('--debug', '-d', action='store_true',
                        help='Enable debug mode')

    parser.add_argument('--test', '-t', action='store_true',
                        help='Run tests only')

    return parser.parse_args(argv)

if __name__ == "__main__":
    # Script entry point: print a banner, read the CLI options, apply them to
    # the module-level configuration, resolve the output paths, then run the
    # processing pipeline through main().
    print("Scene Description Generation Processing Tool - GPT Version")
    print("=" * 80)
    print("Features:")
    print("   - Scene descriptions generation based on math information and rules (GPT)")
    print("   - Individual API calls or Batch API calls (--batch)")
    print("   - Token usage tracking and combining from multiple stages")
    print("   - Async parallel processing for maximum efficiency")
    print("   - Comprehensive meta description modules support")
    print("   - Combines original data with generated scene descriptions")
    print("")

    # Parse command line arguments
    # NOTE(review): args.test ("Run tests only") is parsed but never acted on
    # in this block -- confirm whether a test-only path is still intended.
    args = parse_arguments()

    # Set debug mode (rebinds the module-level DEBUG flag read by debug_print)
    DEBUG = args.debug

    # Configure batch processing mode
    API_CONFIG["ENABLE_BATCH_PROCESSING"] = args.batch
    API_CONFIG["BATCH_API_SIZE"] = args.api_batch_size

    # Check if environment is set up correctly. The message names the variable
    # that is actually checked here (and read by the OpenAI client).
    if not os.getenv("OPENAI_API_KEY_IMAGE"):
        print("Warning: OPENAI_API_KEY_IMAGE environment variable not set")
        print("   Please set your OpenAI API key before running processing")
        print("")

    # Generate timestamped output files if not provided
    if args.output is None:
        args.output = create_timestamped_filename(OUTPUT_FILE_BASE + ".json")

    if args.anomalies is None:
        args.anomalies = create_timestamped_filename(ANOMALIES_FILE_BASE + ".json")
        # OUTPUT_FILE_BASE and ANOMALIES_FILE_BASE can hold the same value, in
        # which case both auto-generated names are identical whenever the two
        # timestamps fall in the same second. Keep the two files distinct so
        # one cannot overwrite the other.
        if args.anomalies == args.output:
            args.anomalies = args.anomalies[:-len(".json")] + "_anomalies.json"

    print("Configuration:")
    print(f"   Input file: {args.input}")
    print(f"   Batch size: {args.batch_size}")
    print(f"   Batch API mode: {'Enabled' if args.batch else 'Disabled'}")
    if args.batch:
        print(f"   API batch size: {args.api_batch_size}")
    print(f"   Output file: {args.output}")
    print(f"   Anomalies file: {args.anomalies}")
    print(f"   Debug mode: {args.debug}")
    print("")

    # Check if input file exists; bail out with a non-zero status otherwise.
    if not os.path.exists(args.input):
        print(f"Input file not found: {args.input}")
        print("   Please ensure the input file exists before running")
        # SystemExit is raised directly: the bare exit() helper is injected by
        # the site module and is not guaranteed to exist in every runtime.
        raise SystemExit(1)

    print(f"Input file found: {args.input}")
    processing_mode = "batch API" if args.batch else "individual API"
    print(f"Starting processing with async parallel execution ({processing_mode} calls)...")

    main(args.input, args.output, args.anomalies, args.batch_size)
