import os
import json
import asyncio
import aiofiles
from tqdm.asyncio import tqdm
from tenacity import retry, stop_after_attempt, wait_exponential

def _report_retries_exhausted(retry_state):
    """Tenacity ``retry_error_callback``: log the final failure and return ``None``.

    Tenacity calls this with a single ``RetryCallState`` once the stop condition
    is reached, and whatever it returns becomes the return value of the
    decorated call. A request that keeps failing therefore yields ``None``
    instead of raising.
    """
    exception = retry_state.outcome.exception()
    print(f"Retry attempt {retry_state.attempt_number} failed: {type(exception).__name__} - {str(exception)}")
    return None


class ChatCompletionRequester:
    """Send many chat-completion requests concurrently and append results to JSONL.

    Args:
        model: Model name passed to ``client.chat.completions.create``.
        client: Object whose ``chat.completions.create`` is awaited. Presumably
            an async OpenAI-compatible client (TODO confirm against callers).
        max_concurrent: Maximum number of requests in flight at the same time.
    """

    def __init__(self, model, client, max_concurrent=10):
        self.model = model
        self.client = client
        self.max_concurrent = max_concurrent

    def retry_error_callback(self, retry_state):
        """Log that retries for *retry_state* were exhausted and return ``None``.

        Kept for backward compatibility. It delegates to the module-level,
        one-argument callback that the retry decorator actually uses.
        """
        return _report_retries_exhausted(retry_state)

    async def save_temp_results(self, results, save_path, current_count):
        """Append *results* to ``./results/{save_path}.jsonl``, one JSON object per line.

        Args:
            results: JSON-serializable items to append. The file is opened in
                append mode, so callers should pass only items not yet written.
            save_path: File stem (may include sub-directories) under ``./results``.
            current_count: Progress label (a count or ``"final"``). It is not
                used when writing.
        """
        out_file = f"./results/{save_path}.jsonl"
        os.makedirs(os.path.dirname(out_file), exist_ok=True)
        # Serialize everything up front and issue one write instead of one await per line.
        payload = "".join(json.dumps(item, ensure_ascii=False) + "\n" for item in results)
        async with aiofiles.open(out_file, 'a', encoding='utf-8') as f:
            await f.write(payload)

    @retry(
        stop=stop_after_attempt(10),
        wait=wait_exponential(multiplier=1, min=1, max=15),
        # Tenacity calls this as callback(retry_state). It must therefore accept
        # exactly one argument; a function expecting (self, retry_state) would
        # raise TypeError instead of returning None once retries run out.
        retry_error_callback=_report_retries_exhausted,
    )
    async def get_chat_completion(self, message: str, semaphore) -> "str | None":
        """Request one completion for *message*, retrying on any exception.

        Up to 10 attempts are made with exponential backoff (1-15 s). The
        semaphore is acquired inside each attempt, so it is not held while
        waiting between retries.

        Returns:
            The content of the first choice, or ``None`` if every attempt failed.
        """
        try:
            async with semaphore:
                response = await self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        # NOTE(review): the same text is sent as both the system and
                        # the user message. Confirm this duplication is intended.
                        {"role": "system", "content": message},
                        {"role": "user", "content": message}
                    ],
                    max_tokens=2048,
                )
                return response.choices[0].message.content
        except Exception as e:
            print(f"Error in get_chat_completion for message: {type(e).__name__} - {str(e)}")
            raise

    async def request_model(self, prompts, init_variables, python_codes, prefix, difficulties,
                            answers, save_path="results", checkpoint_every=1000):
        """Query the model for every prompt and persist the records as JSONL.

        The i-th entries of *prompts*, *init_variables*, *python_codes*,
        *difficulties* and *answers* describe one item. ``zip`` stops at the
        shortest sequence. Each request sends ``prefix + prompt``.

        Every *checkpoint_every* completed items, the records not yet on disk
        are appended to ``./results/{save_path}.jsonl``. Any remainder is
        appended at the end. A falsy *checkpoint_every* disables intermediate
        checkpoints.

        Returns:
            list[dict]: One record per completed item, in completion order (not
            input order). ``"response"`` is ``None`` when all retries failed.
        """
        semaphore = asyncio.Semaphore(self.max_concurrent)
        results = []
        completed_count = 0
        saved_count = 0  # number of leading entries of `results` already written to disk
        save_lock = asyncio.Lock()

        async def flush(count_label):
            # Append only the records that have not been written yet. The lock
            # prevents two overlapping flushes from writing the same slice.
            # `saved_count` advances only after a successful write, so records
            # from a failed checkpoint are picked up by the next flush.
            nonlocal saved_count
            async with save_lock:
                pending = results[saved_count:]
                await self.save_temp_results(pending, save_path, count_label)
                saved_count += len(pending)

        async def wrapped_get_chat_completion(prompt, init_var, python_code, difficulty, answer, index):
            nonlocal completed_count
            infer_prompt = prefix + prompt
            try:
                result = await self.get_chat_completion(infer_prompt, semaphore)
            except Exception as e:
                print(f"Task failed after all retries with error: {e}")
                return index, None
            results.append({
                "model": self.model,
                "prompt": prompt,
                "difficulty": difficulty,
                "init_variables": init_var,
                "python_code": python_code,
                "response": result,
                "answer": answer
            })
            completed_count += 1
            count = completed_count  # snapshot: other tasks may bump the counter during the await
            if checkpoint_every and count % checkpoint_every == 0:
                try:
                    await flush(count)
                except Exception as e:
                    # A failed checkpoint is not fatal: the unsaved records stay
                    # pending and are written by a later flush.
                    print(f"Failed to save checkpoint at {count} results: {e}")
            return index, result

        tasks = [
            wrapped_get_chat_completion(prompt, init_var, python_code, difficulty, answer, i)
            for i, (prompt, init_var, python_code, difficulty, answer) in enumerate(
                zip(prompts, init_variables, python_codes, difficulties, answers)
            )
        ]

        for future in tqdm.as_completed(tasks, total=len(tasks), desc="Processing prompts"):
            await future

        await flush("final")

        return results