import asyncio
import json
import os
import time
from typing import List, Dict, Any, Optional

import aiohttp

from utils import read_json, write_json

# OpenAI API key, read from the OPENAI_API_KEY environment variable so the
# secret does not have to be written into (and committed with) this file.
# Falls back to "" when the variable is unset.
API_KEY = os.environ.get("OPENAI_API_KEY", "")

class AsyncGPTProcessor:
    """Concurrently request tutoring hints from the OpenAI chat-completions API.

    Must be used as an async context manager so that the shared
    ``aiohttp.ClientSession`` is opened and closed around the work::

        async with AsyncGPTProcessor(max_concurrent=10) as processor:
            results = await processor.process_batch(data)

    Args:
        max_concurrent: Maximum number of HTTP requests in flight at once.
        max_retries: Total number of attempts per prompt (including the first).
        retry_delay: Base back-off in seconds; after failed attempt ``k``
            (0-based) the next attempt waits ``retry_delay * 2**k``.

    Raises:
        ValueError: if ``max_concurrent`` is less than 1.
    """

    def __init__(self, max_concurrent: int = 10, max_retries: int = 3, retry_delay: float = 1.0):
        if max_concurrent < 1:
            # A semaphore with zero permits would make every request wait forever.
            raise ValueError(f"max_concurrent must be at least 1, got {max_concurrent}")
        self.max_concurrent = max_concurrent
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.semaphore = asyncio.Semaphore(max_concurrent)
        # Opened in __aenter__; closed and reset to None in __aexit__.
        self.session: Optional[aiohttp.ClientSession] = None

    async def __aenter__(self) -> "AsyncGPTProcessor":
        self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        if self.session is not None:
            await self.session.close()
            # Drop the closed session so a later call fails with a clear error.
            self.session = None

    @staticmethod
    def _is_retryable(exc: BaseException) -> bool:
        """Return whether a failed API call is worth another attempt.

        ``RuntimeError`` (e.g. no open session) and HTTP 4xx responses other
        than 408 (request timeout) and 429 (rate limited) cannot succeed on a
        retry. The HTTP status is read from an ``status`` attribute, which is
        what ``aiohttp.ClientResponseError`` (raised by ``raise_for_status``)
        carries; exceptions without it are treated as transient.
        """
        if isinstance(exc, RuntimeError):
            return False
        status = getattr(exc, "status", None)
        if isinstance(status, int) and 400 <= status < 500:
            return status in (408, 429)
        return True

    async def call_gpt_api_with_retry(self, prompt: str, model: str = "o4-mini") -> Optional[str]:
        """GPT API call with retry mechanism.

        Args:
            prompt: User message sent to the model.
            model: Chat-completions model name.

        Returns:
            The reply text, or ``None`` if all ``max_retries`` attempts failed
            or a non-retryable error occurred (failures are printed, not raised).
        """
        for attempt in range(self.max_retries):
            try:
                # Hold the semaphore only for the request itself, so a coroutine
                # sleeping between retries does not block other requests.
                async with self.semaphore:
                    return await self._call_gpt_api(prompt, model)
            except Exception as e:
                if not self._is_retryable(e):
                    print(f"API call failed with non-retryable error: {e}")
                    return None
                if attempt < self.max_retries - 1:
                    print(f"API call failed (attempt {attempt + 1}/{self.max_retries}): {e}")
                    # Exponential backoff: retry_delay, 2*retry_delay, 4*retry_delay, ...
                    await asyncio.sleep(self.retry_delay * (2 ** attempt))
                else:
                    print(f"API call finally failed: {e}")
        return None

    async def _call_gpt_api(self, prompt: str, model: str = "o4-mini") -> str:
        """Send one chat-completions request and return the first choice's text.

        Raises:
            RuntimeError: if no session is open (not inside ``async with``).
            aiohttp.ClientResponseError: on an HTTP error status.
            KeyError / IndexError: if the response JSON lacks the expected fields.
        """
        if self.session is None:
            raise RuntimeError(
                "AsyncGPTProcessor has no open session; use it as "
                "'async with AsyncGPTProcessor(...) as processor:'"
            )

        url = "https://api.openai.com/v1/chat/completions"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {API_KEY}",
        }
        data = {
            "model": model,
            "messages": [
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
        }

        async with self.session.post(url, headers=headers, json=data) as response:
            response.raise_for_status()
            response_data = await response.json()
            return response_data['choices'][0]['message']['content']

    async def process_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a hint for one problem and store it in ``item['gpt_hint']``.

        Reads ``item['problem']`` and ``item['gpt_response']`` (the reference
        solution); a missing key raises ``KeyError``. The dict is mutated in
        place and also returned; ``gpt_hint`` is ``None`` if the API call failed.
        """
        prompt = (
            "You are a tutor. You are given a problem and a reference solution. "
            "Your job is to provide a concise and helpful hint for this problem. "
            "The hint should help the student learn the core concept (e.g. formula, lemma, or necessary knowledge) needed to solve this problem. "
            "The hint should be concise, to the point, but high level. Do not include any detailed steps or calculations or the final answer. "
            f"Here is the problem: {item['problem']}\n\n"
            f"Here is the reference solution: {item['gpt_response']}\n\n"
            "Now, please provide a concise hint for this problem."
        )
        item['gpt_hint'] = await self.call_gpt_api_with_retry(prompt)
        return item

    @staticmethod
    def _temp_output_path(output_file: str) -> str:
        """Return the checkpoint path for *output_file* (``out.json`` -> ``out_temp.json``).

        Only the final extension is split off, so ``.json`` elsewhere in the path
        is untouched, and a path without ``.json`` still gets a distinct name
        instead of overwriting *output_file*.
        """
        root, ext = os.path.splitext(output_file)
        return f"{root}_temp{ext}"

    async def process_batch(
        self,
        data: List[Dict[str, Any]],
        batch_size: int = 20,
        output_file: Optional[str] = None,
        batch_delay: float = 0.5,
    ) -> List[Dict[str, Any]]:
        """Process data in batches, saving a checkpoint after each batch.

        Items within a batch run concurrently (bounded by ``max_concurrent``).
        An item whose processing raised gets ``gpt_hint = None`` instead of
        aborting the run.

        Args:
            data: Items to process; each is mutated in place by ``process_item``.
            batch_size: Number of items per batch; must be positive.
            output_file: If given, after each batch the processed items followed
                by the not-yet-processed remainder are written to the checkpoint
                path derived from it (``<stem>_temp<ext>``).
            batch_delay: Seconds to pause between batches (not after the last).

        Returns:
            The processed items, in input order.

        Raises:
            ValueError: if ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        processed_items: List[Dict[str, Any]] = []
        total_batches = (len(data) + batch_size - 1) // batch_size
        temp_output_file = self._temp_output_path(output_file) if output_file else None

        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            batch_num = start // batch_size + 1
            has_more = start + batch_size < len(data)

            print(f"Processing batch {batch_num}/{total_batches} ({len(batch)} items)")

            batch_results = await asyncio.gather(
                *(self.process_item(item) for item in batch),
                return_exceptions=True,
            )

            for offset, result in enumerate(batch_results):
                # BaseException, not Exception: gather may return a CancelledError.
                if isinstance(result, BaseException):
                    print(f"Error occurred while processing item {start + offset}: {result}")
                    batch[offset]['gpt_hint'] = None
                else:
                    batch[offset] = result

            processed_items.extend(batch)

            if temp_output_file:
                # Checkpoint keeps the input's length and order: processed items
                # first, then the untouched remainder.
                write_json(processed_items + data[start + batch_size:], temp_output_file)
                print(f"Batch {batch_num} completed, saved to: {temp_output_file}")

            if has_more and batch_delay > 0:
                await asyncio.sleep(batch_delay)

        return processed_items

async def main(
    input_file: str = "./data/math_science_with_gpt_responses_all_correct_8k.json",
    output_file: str = "./data/math_science_final_8k.json",
) -> None:
    """Generate a hint for every item in *input_file* and write all items to *output_file*.

    During processing, a checkpoint is written after each batch to
    ``<output stem>_temp<ext>`` next to *output_file*; the complete result is
    written to *output_file* at the end, followed by summary statistics.
    """
    if not API_KEY:
        print("Warning: API_KEY is empty; API requests will likely fail authentication.")

    # Read data
    print("Reading data...")
    data = read_json(input_file)

    print(f"Starting to process {len(data)} items...")
    # perf_counter is monotonic, so the elapsed time is immune to clock changes.
    start_time = time.perf_counter()

    async with AsyncGPTProcessor(max_concurrent=10, max_retries=3) as processor:
        processed_data = await processor.process_batch(data, batch_size=20, output_file=output_file)

    # Calculate statistics
    total_time = time.perf_counter() - start_time
    total = len(processed_data)
    successful_responses = sum(1 for item in processed_data if item.get('gpt_hint') is not None)

    write_json(processed_data, output_file)

    # Checkpoint name: only the final extension is split off, e.g. out.json -> out_temp.json.
    root, ext = os.path.splitext(output_file)
    temp_output_file = f"{root}_temp{ext}"

    # Display statistics
    print("\n=== Processing Complete ===")
    print(f"Total processing time: {total_time:.2f} seconds")
    print(f"Successfully processed: {successful_responses}/{total} items")
    if total:
        print(f"Success rate: {successful_responses/total*100:.1f}%")
    else:
        # Avoid ZeroDivisionError when the input file held no items.
        print("Success rate: n/a (no items)")
    print(f"Final results saved to: {output_file}")
    print(f"Temporary file format: {temp_output_file}")


if __name__ == "__main__":
    asyncio.run(main())