import asyncio
import aiohttp
import json
import time
from typing import List, Dict, Any, Optional
from utils import read_json, write_json, last_boxed_only_string, remove_boxed, is_math_verify_equiv

# API key sent as the Bearer token in the Authorization header of every
# request. Left empty in source -- presumably filled in before running
# (TODO confirm how the key is meant to be supplied).
API_KEY = ""

class AsyncGPTProcessor:
    """Concurrent client that asks a chat-completions endpoint to solve problems.

    Use it as an async context manager so the HTTP session is opened and
    closed around the work::

        async with AsyncGPTProcessor(max_concurrent=10) as processor:
            results = await processor.process_batch(items)

    Args:
        max_concurrent: Maximum number of API requests in flight at once.
        max_retries: Number of attempts made for a single API request before
            giving up.
        retry_delay: Base delay in seconds between attempts of one request;
            it is doubled after every failed attempt.
    """

    def __init__(self, max_concurrent: int = 10, max_retries: int = 3, retry_delay: float = 1.0):
        self.max_concurrent = max_concurrent
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        # Caps the number of simultaneous HTTP requests.
        self.semaphore = asyncio.Semaphore(max_concurrent)
        # Created in __aenter__, closed in __aexit__.
        self.session: Optional[aiohttp.ClientSession] = None

    async def __aenter__(self):
        """Open the HTTP session and return the processor itself."""
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the HTTP session, if one was opened."""
        if self.session:
            await self.session.close()
            # Drop the closed session so later misuse fails with a clear error.
            self.session = None

    async def call_gpt_api_with_retry(self, prompt: str, model: str = "o4-mini") -> Optional[str]:
        """Call the API, retrying failed requests with exponential backoff.

        Args:
            prompt: Text sent as the single user message.
            model: Model name passed to the API.

        Returns:
            The response text, or ``None`` when every attempt failed.
        """
        for attempt in range(self.max_retries):
            try:
                # The semaphore is held only for the request itself; it is
                # released before the backoff sleep so other items can proceed.
                async with self.semaphore:
                    return await self._call_gpt_api(prompt, model)
            except Exception as e:
                # Deliberate best-effort boundary: any failure of a single
                # request is reported and turned into a retry or ``None``.
                if attempt >= self.max_retries - 1:
                    print(f"API call finally failed: {e}")
                    return None
                print(f"API call failed (attempt {attempt + 1}/{self.max_retries}): {e}")
                # Exponential backoff: retry_delay, 2x, 4x, ...
                await asyncio.sleep(self.retry_delay * 2 ** attempt)
        return None

    async def _call_gpt_api(self, prompt: str, model: str = "o4-mini") -> str:
        """Send one chat-completions request and return the message content.

        Raises:
            RuntimeError: If the processor is used outside ``async with``
                (no HTTP session is open).
            Exception: HTTP errors raised by ``raise_for_status`` and lookup
                errors for an unexpected response payload propagate to the
                caller.
        """
        if self.session is None:
            raise RuntimeError(
                "HTTP session is not open; use 'async with AsyncGPTProcessor(...)'"
            )

        url = "https://api.openai.com/v1/chat/completions"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {API_KEY}",
        }
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": prompt}],
        }

        async with self.session.post(url, headers=headers, json=payload) as response:
            response.raise_for_status()
            response_data = await response.json()
            return response_data['choices'][0]['message']['content']

    @staticmethod
    def _build_prompt(problem: str) -> str:
        """Return the step-by-step solving prompt for one problem statement."""
        return (
            f"{problem}\nThink step by step and put the final answer in \\boxed{{}}. "
            f"You should clearly state the step number in the beginning of each step using this format: "
            f"Step 1: <step 1 reasoning>\n\n\n\nStep 2: <step 2 reasoning>...\n\n\n\nAnswer: <final answer in \\boxed{{}}>"
            f"If the question is a multiple choice question, you should provide the final answer in only one capital letter. "
            f"For example, if the answer is A, you should put the final answer in \\boxed{{A}}."
        )

    async def process_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Query the model for one item until its boxed answer is verified.

        The item is asked up to five times. After a wrong or unparsable
        answer, a short corrective note is appended to the prompt of the next
        attempt. The latest response (or ``None`` after a failed API call) is
        stored under ``item['gpt_response']``.

        Args:
            item: Mapping with at least ``'problem'`` and ``'expected_answer'``;
                ``'id'`` is used for log messages when present.

        Returns:
            The same ``item`` object, mutated in place.
        """
        max_attempts = 5
        item_id = item.get('id', 'unknown')
        # Built once; rebuilding it inside the loop would discard the
        # corrective feedback collected from the previous attempt.
        base_prompt = self._build_prompt(item['problem'])
        feedback = ""

        for attempt in range(max_attempts):
            result = await self.call_gpt_api_with_retry(base_prompt + feedback)
            item['gpt_response'] = result

            if not result:
                print(f"Item {item_id} API call failed (attempt {attempt + 1}/{max_attempts})")
                continue

            boxed = last_boxed_only_string(result)
            if not boxed:
                print(f"Item {item_id} unable to extract boxed answer (attempt {attempt + 1}/{max_attempts})")
                feedback = (
                    f"\n\nYour previous response did not contain a properly formatted answer in \\boxed{{}}. "
                    f"Please provide your final answer in \\boxed{{}} format."
                )
            elif is_math_verify_equiv(item['expected_answer'], remove_boxed(boxed)):
                # Answer is correct, no retry needed.
                print(f"Item {item_id} answer correct (attempt {attempt + 1})")
                return item
            else:
                print(f"Item {item_id} answer incorrect (attempt {attempt + 1}/{max_attempts})")
                feedback = (
                    "\n\nYour previous answer was incorrect. "
                    "Please try again and make sure your final answer is correct."
                )

            if attempt == max_attempts - 1:
                print(f"Item {item_id} reached maximum retry attempts, keeping last response")

        return item

    async def process_batch(self, data: List[Dict[str, Any]], batch_size: int = 20, output_file: Optional[str] = None) -> List[Dict[str, Any]]:
        """Process items batch by batch, checkpointing after every batch.

        Args:
            data: Items to process; each is handed to ``process_item``.
            batch_size: Number of items processed concurrently per batch.
            output_file: When given, a checkpoint containing the processed
                items followed by the still-unprocessed ones is written after
                each batch to the path obtained by replacing ``.json`` with
                ``_temp.json`` in this name.

        Returns:
            The processed items in input order. Items whose processing raised
            have ``'gpt_response'`` set to ``None``.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        processed_items: List[Dict[str, Any]] = []
        total_batches = (len(data) + batch_size - 1) // batch_size

        for batch_num, start in enumerate(range(0, len(data), batch_size), start=1):
            batch = data[start:start + batch_size]
            has_more = start + batch_size < len(data)

            print(f"Processing batch {batch_num}/{total_batches} ({len(batch)} items)")

            # Run the whole batch concurrently; failures come back as values.
            batch_results = await asyncio.gather(
                *(self.process_item(item) for item in batch),
                return_exceptions=True,
            )

            for offset, result in enumerate(batch_results):
                # BaseException also covers CancelledError, which gather can
                # return as a result and which is not an Exception subclass.
                if isinstance(result, BaseException):
                    print(f"Error occurred while processing item {start + offset}: {result}")
                    batch[offset]['gpt_response'] = None
                else:
                    batch[offset] = result

            processed_items.extend(batch)

            # Save after each batch is completed.
            if output_file:
                # Checkpoint = processed items + untouched remainder.
                temp_data = processed_items + data[start + batch_size:]
                temp_output_file = output_file.replace('.json', '_temp.json')
                write_json(temp_data, temp_output_file)
                print(f"Batch {batch_num} completed, saved to: {temp_output_file}")

            # Short pause between batches.
            if has_more:
                await asyncio.sleep(0.5)

        return processed_items

async def main():
    """Run the whole pipeline: load problems, query the model, score, save.

    Reads ``./data/training_data.json``, processes every item with
    ``AsyncGPTProcessor``, writes the annotated items to
    ``./data/training_data_with_gpt_reasoning.json``, prints summary
    statistics, and writes every item without a verified-correct boxed answer
    to a sibling ``*_incorrect.json`` file.
    """
    print("Reading data...")
    data = read_json("./data/training_data.json")

    print(f"Starting to process {len(data)} items...")
    start_time = time.time()

    output_file = "./data/training_data_with_gpt_reasoning.json"

    async with AsyncGPTProcessor(max_concurrent=10, max_retries=3) as processor:
        processed_data = await processor.process_batch(data, batch_size=20, output_file=output_file)

    total_time = time.time() - start_time

    # Persist the results before scoring so that an error in the answer
    # checker cannot lose the collected API responses.
    write_json(processed_data, output_file)

    def _is_correct(item):
        """Return True when the item's response holds a verified boxed answer."""
        response = item.get('gpt_response')
        if not response:
            return False
        boxed = last_boxed_only_string(response)
        if not boxed:
            return False
        return bool(is_math_verify_equiv(item['expected_answer'], remove_boxed(boxed)))

    total = len(processed_data)
    successful_responses = sum(1 for item in processed_data if item.get('gpt_response') is not None)
    not_correct_samples = [item for item in processed_data if not _is_correct(item)]
    correct_answers = total - len(not_correct_samples)

    # Guard the percentages: an empty dataset would otherwise divide by zero.
    success_rate = successful_responses / total * 100 if total else 0.0
    accuracy_rate = correct_answers / total * 100 if total else 0.0

    print(f"\n=== Processing Complete ===")
    print(f"Total processing time: {total_time:.2f} seconds")
    print(f"Successfully processed: {successful_responses}/{total} items")
    print(f"Success rate: {success_rate:.1f}%")
    print(f"Correct answers: {correct_answers}/{total} items")
    print(f"Accuracy rate: {accuracy_rate:.1f}%")
    print(f"Incorrect samples count: {len(not_correct_samples)}")
    print(f"Final results saved to: {output_file}")
    print(f"Temporary file format: {output_file.replace('.json', '_temp.json')}")

    # Save incorrect samples for later inspection.
    if not_correct_samples:
        incorrect_output_file = output_file.replace('.json', '_incorrect.json')
        write_json(not_correct_samples, incorrect_output_file)
        print(f"Incorrect samples saved to: {incorrect_output_file}")

# Script entry point: run the async pipeline on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())