from openai import AsyncOpenAI
from openai.types import CompletionUsage
from openai.types.completion_usage import CompletionTokensDetails, PromptTokensDetails
from openai._types import NOT_GIVEN
from openai.types.chat.chat_completion import ChatCompletion
from volcenginesdkarkruntime import AsyncArk
from google import genai
import os
import json
import asyncio
import time
import tempfile
from datetime import datetime
from olym_gen.utils.utils import get_logger
from qcloud_cos import CosConfig, CosS3Client
import requests
from hashlib import sha512
from collections import defaultdict
from abc import ABC, abstractmethod

from pathlib import Path
from functools import cached_property
from typing import Literal, Any, get_args

# Module-wide logger obtained from the project's logging helper.
logger = get_logger()

# Closed set of provider identifiers that GeneratorBase accepts.
Provider = Literal[
    "dummy",
    "openai",
    "openai_batch",
    "deepseek",
    "ark",
    "ark_batch",
    "siliconflow_batch",
    "tencent_batch",
    "ali_batch",
    "gemini",
    "gemini_batch",
    "local_vllm",
    "dashscope",
]
# Runtime tuple of the same identifiers, used for membership validation.
provider_list = get_args(Provider)

class GeneratorBase:
    def __init__(
        self,
        provider: Provider = "dummy",
        model: str | None = None,
        extra_model_paras: dict[str, Any] | None = None,
    ) -> None:
        """
        Create the API client and the bookkeeping state for one LLM provider.

        Args:
            provider (Provider): The provider of the LLM API; must be one of `provider_list`.
                Default is "dummy", which creates no client and provides no API functionality.
            model (str | None): The model name to use. If None, a per-provider default is used:
                "o3" for openai / openai_batch, "deepseek-reasoner" for deepseek,
                "deepseek-r1-250528" for ark, "Pro/deepseek-ai/DeepSeek-R1" for siliconflow_batch,
                "deepseek-r1-0528" for tencent_batch, "deepseek-r1" for ali_batch,
                "gemini-2.5-flash" for gemini / gemini_batch and
                "qwen3-next-80b-a3b-thinking" for dashscope. For ark_batch the model is always
                read from the `ARK_BATCH_ENDPOINT` environment variable and this argument is
                ignored. For local_vllm the model must be given explicitly.
            extra_model_paras (dict[str, Any] | None): Accepted for interface compatibility but
                neither stored nor used by this constructor; per-request extra parameters are
                passed to `single_turn_request` instead.

        Raises:
            ValueError: If `provider` is not supported, if `ARK_BATCH_ENDPOINT` is unset for
                "ark_batch", or if `model` is None for "local_vllm".

        Note:
            The constructor mutates process-wide environment variables: `OPENAI_API_KEY`
            (deepseek, local_vllm), `OPENAI_BASE_URL` (local_vllm) and
            `MAX_NUM_PARALLEL_RESPONSE` (deepseek, ark, gemini).
        """
        if provider not in provider_list:
            raise ValueError(
                f"Unsupported provider: {provider}. Supported providers are {provider_list}."
            )
        self.provider = provider

        # some providers use different environment variables for api key
        deepseek_api_key = None
        if self.provider == "deepseek":
            deepseek_api_key = os.getenv("DEEPSEEK_API_KEY")
            if deepseek_api_key is None:
                # Fallback variable; it is also handed to the client below so that it
                # actually takes effect.
                deepseek_api_key = os.getenv("DEEPSEEK_OPENAI_API_KEY")
            if deepseek_api_key is not None:
                os.environ["OPENAI_API_KEY"] = deepseek_api_key
        vllm_base_url = None
        if self.provider == "local_vllm":
            vllm_base_url = os.getenv("VLLM_HOST", "http://localhost:8000/v1")
            os.environ["OPENAI_BASE_URL"] = vllm_base_url
            if os.getenv("LOCAL_VLLM_API_KEY") is not None:
                os.environ["OPENAI_API_KEY"] = os.getenv("LOCAL_VLLM_API_KEY") # type: ignore

        # Batch coordination system
        self.batch_requests = []
        self.batch_id = None
        self.pending_responses = defaultdict(list)  # Maps request_id to a list of asyncio.Future
        self.batch_coordinator_task = None
        self.last_request_time = None
        self.batch_upload_delay = 15  # seconds
        self._batch_lock = asyncio.Lock()

        # create the client
        if provider == "dummy":
            self.client = None
            self.model_name = "dummy"
            # Dummy provider does not support any model, so this is not applicable
        elif provider in ["openai", "openai_batch"]:
            self.client = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1"), timeout=1800)
            self.model_name = "o3" if model is None else model
            # Uncomment the following line to use a specific model version
            # self.model_name = "o3-2025-04-16"
        elif provider == "deepseek":
            self.client = AsyncOpenAI(
                api_key=deepseek_api_key,
                base_url="https://api.deepseek.com",
            )
            self.model_name = "deepseek-reasoner" if model is None else model
            os.environ['MAX_NUM_PARALLEL_RESPONSE'] = '1'
            # deepseek does not support versioning, so this is not applicable
        elif provider == "ark":
            self.client = AsyncOpenAI(
                api_key=os.getenv("ARK_API_KEY"),
                base_url="https://ark.cn-beijing.volces.com/api/v3/",
                timeout = 1800,
            )
            self.model_name = "deepseek-r1-250528" if model is None else model
            os.environ['MAX_NUM_PARALLEL_RESPONSE'] = '1'
        elif provider == "ark_batch":
            self.client = AsyncArk(api_key=os.getenv("ARK_API_KEY"), timeout = 3600*24)
            self.model_name = os.getenv("ARK_BATCH_ENDPOINT")
            if self.model_name is None:
                raise ValueError("ARK_BATCH_ENDPOINT environment variable is not set.")
        elif provider == "siliconflow_batch":
            self.client = AsyncOpenAI(api_key=os.getenv("SILICONFLOW_API_KEY"), base_url="https://api.siliconflow.com/v1")
            self.model_name = "Pro/deepseek-ai/DeepSeek-R1" if model is None else model
        elif provider == "tencent_batch":
            self.client = AsyncOpenAI(api_key=os.getenv("TENCENT_API_KEY"), base_url="https://api.lkeap.cloud.tencent.com/v1")
            self.model_name = "deepseek-r1-0528" if model is None else model
            # Credentials and target bucket for the COS upload of batch input files.
            self.secret_id = os.environ.get("TENCENT_SECRET_ID")
            self.secret_key = os.environ.get("TENCENT_SECRET_KEY")
            self.region = os.environ.get("TENCENT_COS_REGION", "ap-beijing")
            self.bucket = os.environ.get("TENCENT_COS_BUCKET")

            config = CosConfig(Region=self.region, SecretId=self.secret_id, SecretKey=self.secret_key)
            self.cos_client = CosS3Client(config)
        elif provider == "ali_batch":
            self.client = AsyncOpenAI(api_key=os.getenv("ALI_API_KEY"), base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
            self.model_name = "deepseek-r1" if model is None else model
        elif provider == "gemini":
            self.client = AsyncOpenAI(api_key=os.getenv("GEMINI_API_KEY"), base_url="https://generativelanguage.googleapis.com/v1beta/openai/")
            os.environ['MAX_NUM_PARALLEL_RESPONSE'] = '8'
            self.model_name = "gemini-2.5-flash" if model is None else model
        elif provider == "gemini_batch":
            self.client = AsyncOpenAI(api_key=os.getenv("GEMINI_API_KEY"), base_url="https://generativelanguage.googleapis.com/v1beta/openai/")
            self.model_name = "gemini-2.5-flash" if model is None else model
            self.gemini_client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
        elif provider == "local_vllm":
            # Explicit exception instead of `assert`, which is stripped under `python -O`.
            if model is None:
                raise ValueError("For local_vllm provider, model must be specified.")
            # Honour VLLM_HOST instead of always talking to localhost.
            self.client = AsyncOpenAI(api_key=os.getenv("LOCAL_VLLM_API_KEY"), base_url=vllm_base_url)
            self.model_name = model
        elif provider == "dashscope":
            self.client = AsyncOpenAI(api_key=os.getenv("DASHSCOPE_API_KEY"), base_url="https://dashscope.aliyuncs.com/compatible-mode/v1")
            self.model_name = "qwen3-next-80b-a3b-thinking" if model is None else model

        logger.info(f"Successfully created client for provider: {self.provider}, model: {self.model_name}")

        if os.environ.get('MAX_NUM_PARALLEL_RESPONSE', None) is not None:
            logger.info(f"Using MAX_NUM_PARALLEL_RESPONSE={os.environ['MAX_NUM_PARALLEL_RESPONSE']}")
        self._init_token_usage()

    def _init_token_usage(self) -> None:
        self.total_usage = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "prompt_cache_hit_tokens": 0,
            "prompt_cache_miss_tokens": 0,
            "total_tokens": 0,
            # the following is only used by openai
            # "prompt_tokens_details": {
            #     "cached_tokens": 0
            # },
            "completion_tokens_details": {
                "reasoning_tokens": 0,
                "accepted_prediction_tokens": 0,
                "rejected_prediction_tokens": 0,
            },
        }

    def update_token_usage(self, usage: CompletionUsage | None) -> None:
        """
        Add the token counts of one API response to `self.total_usage`.

        Args:
            usage (CompletionUsage | None): The usage object of a response. None is ignored.

        Counter fields that are missing, None, or not integers are counted as 0, so a
        response with partially populated usage details cannot raise while accumulating.
        """
        if usage is None:
            return

        def _count(source: object, field: str) -> int:
            # Optional usage fields can be present but set to None; treat anything that
            # is not an int as "no tokens reported".
            value = getattr(source, field, 0)
            return value if isinstance(value, int) else 0

        totals = self.total_usage
        prompt_tokens = _count(usage, "prompt_tokens")
        totals["prompt_tokens"] += prompt_tokens
        totals["completion_tokens"] += _count(usage, "completion_tokens")

        if hasattr(usage, "prompt_tokens_details"):
            # OpenAI-style reporting: cache hits are nested, misses are derived.
            cached_tokens = _count(usage.prompt_tokens_details, "cached_tokens")
            totals["prompt_cache_hit_tokens"] += cached_tokens
            totals["prompt_cache_miss_tokens"] += prompt_tokens - cached_tokens
        else:
            # Providers that report cache hits/misses as top-level fields.
            totals["prompt_cache_hit_tokens"] += _count(usage, "prompt_cache_hit_tokens")
            totals["prompt_cache_miss_tokens"] += _count(usage, "prompt_cache_miss_tokens")

        totals["total_tokens"] += _count(usage, "total_tokens")
        if hasattr(usage, "completion_tokens_details"):
            completion_details = usage.completion_tokens_details
            completion_totals = totals["completion_tokens_details"]
            for field in (
                "reasoning_tokens",
                "accepted_prediction_tokens",
                "rejected_prediction_tokens",
            ):
                completion_totals[field] += _count(completion_details, field)
        logger.debug(f"Updated token usage:\n{json.dumps(self.total_usage, indent=2)}")

    def _resolve_response(
        self, response: ChatCompletion
    ) -> list[tuple[str, str] | None]:
        return_list: list[tuple[str, str] | None] = []
        for choice in response.choices:
            try:
                thinking = getattr(choice.message, "reasoning_content", "")
                if not isinstance(thinking, str):
                    thinking = ""
                answer = choice.message.content
                if not isinstance(answer, str):
                    answer = ""
                return_list.append((thinking, answer))
            except Exception as e:
                logger.error(f"Error in processing response: {e}")
                return_list.append(None)
        return return_list

    async def _single_turn_request_with_max_n(
        self,
        system_prompt: str,
        user_prompt: str,
        shared_semaphore: asyncio.Semaphore,
        num_returns: int = 1,
        max_tokens: int = 64_000,
        use_json: bool = False,
        extra_model_paras: dict[str, Any] | None = None,
    ) -> list[tuple[str, str] | None]:
        """
        Obtain `num_returns` responses while respecting a per-call limit on `n`.

        When the `MAX_NUM_PARALLEL_RESPONSE` environment variable is set, the request is
        split into several concurrent calls that each ask for at most that many choices;
        otherwise a single call asks for all of them. The per-call results are
        concatenated in call order.
        """
        env_limit = os.getenv('MAX_NUM_PARALLEL_RESPONSE', None)
        if env_limit is None:
            chunk_sizes = [num_returns]
        else:
            per_call_cap = int(env_limit)
            full_chunks, remainder = divmod(num_returns, per_call_cap)
            chunk_sizes = [per_call_cap] * full_chunks
            if remainder != 0:
                chunk_sizes.append(remainder)

        pending_calls = []
        for chunk_size in chunk_sizes:
            pending_calls.append(
                self._single_turn_request(
                    system_prompt=system_prompt,
                    user_prompt=user_prompt,
                    shared_semaphore=shared_semaphore,
                    num_returns=chunk_size,
                    max_tokens=max_tokens,
                    use_json=use_json,
                    extra_model_paras=extra_model_paras,
                )
            )

        chunk_results = await asyncio.gather(*pending_calls)
        merged = []
        for chunk_result in chunk_results:
            merged.extend(chunk_result)
        return merged

    async def _single_turn_request(
        self,
        system_prompt: str,
        user_prompt: str,
        shared_semaphore: asyncio.Semaphore,
        num_returns: int = 1,
        max_tokens: int = 64_000,
        use_json: bool = False,
        extra_model_paras: dict[str, Any] | None = None,
    ) -> list[tuple[str, str] | None]:
        """
        Issue one chat-completion call asking for `num_returns` choices.

        The call runs while holding `shared_semaphore`. An empty `system_prompt` is left
        out of the messages (a warning is logged unless `NO_EMPTY_SYSTEM_PROMPT_WARNING`
        is "1"). Token usage of the response is accumulated on the instance.

        Returns:
            The (thinking, answer) pairs of the response, or a list of `num_returns`
            None values if anything fails; the error is logged, not raised.
        """
        async with shared_semaphore:
            try:
                if self.client is None:
                    raise ValueError(
                        "You are tryting to call a dummy generator. Please use a real provider like 'openai' or 'deepseek'."
                    )

                has_system_prompt = system_prompt != ""
                if not has_system_prompt and os.getenv("NO_EMPTY_SYSTEM_PROMPT_WARNING", "0") != "1":
                    logger.warning("System prompt is empty. This may lead to unexpected results.")

                messages = [{"role": "user", "content": user_prompt}]
                if has_system_prompt:
                    messages.insert(0, {"role": "system", "content": system_prompt})

                if use_json:
                    response_format = {"type": "json_object"}
                else:
                    response_format = NOT_GIVEN
                extra_kwargs = {} if extra_model_paras is None else extra_model_paras

                # Request starts are serialised through the lock with a short pause
                # in between, to avoid being rate limited.
                async with self._batch_lock:
                    await asyncio.sleep(0.01)

                response = await self.client.chat.completions.create(
                    model=self.model_name,
                    messages=messages,
                    max_completion_tokens=max_tokens,
                    n=num_returns,
                    response_format=response_format,
                    **extra_kwargs
                ) # type: ignore
                self.update_token_usage(response.usage)
                return self._resolve_response(response=response)
            except Exception as e:
                logger.error(f"Error in single turn request: {e}")
                return [None] * num_returns

    async def _single_turn_request_ark_batch_single(
        self,
        system_prompt: str,
        user_prompt: str,
        shared_semaphore: asyncio.Semaphore,
        num_returns: int = 1,
        max_tokens: int = 64_000,
        use_json: bool = False,
    ) -> list[tuple[str, str] | None]:
        """
        Issue one call through the client's `batch_chat.completions` endpoint.

        The call runs while holding `shared_semaphore`; the system message is always
        sent, even when `system_prompt` is empty. Token usage is accumulated on the
        instance.

        Returns:
            The (thinking, answer) pairs of the response, or a list of `num_returns`
            None values if anything fails; the error is logged, not raised.
        """
        async with shared_semaphore:
            try:
                if self.client is None:
                    raise ValueError(
                        "You are tryting to call a dummy generator. Please use a real provider like 'openai' or 'deepseek'."
                    )

                conversation = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ]
                response_format = None
                if use_json:
                    response_format = {"type": "json_object"}

                response = await self.client.batch_chat.completions.create(
                    model=self.model_name,
                    messages=conversation,
                    max_tokens=max_tokens,
                    n=num_returns,
                    response_format=response_format,
                )
                self.update_token_usage(response.usage)
                return self._resolve_response(response=response)
            except Exception as e:
                logger.error(f"Error in single turn request: {e}")
                return [None] * num_returns

    async def _single_turn_request_ark_batch(
        self,
        system_prompt: str,
        user_prompt: str,
        shared_semaphore: asyncio.Semaphore,
        num_returns: int = 1,
        max_tokens: int = 64_000,
        use_json: bool = False,
        extra_model_paras: dict[str, Any] | None = None,
    ) -> list[tuple[str, str] | None]:
        """
        Obtain `num_returns` responses in ark_batch mode by issuing that many
        concurrent single-choice calls and concatenating their results.

        Raises:
            NotImplementedError: If `extra_model_paras` is given; it is not supported
                in ark_batch mode.
        """
        if extra_model_paras is not None:
            raise NotImplementedError("extra_model_paras is not supported in ark_batch mode yet.")

        single_calls = []
        for _ in range(num_returns):
            single_calls.append(
                self._single_turn_request_ark_batch_single(
                    system_prompt=system_prompt,
                    user_prompt=user_prompt,
                    shared_semaphore=shared_semaphore,
                    num_returns=1,
                    max_tokens=max_tokens,
                    use_json=use_json,
                )
            )

        per_call_results = await asyncio.gather(*single_calls)
        flattened = []
        for call_result in per_call_results:
            flattened.extend(call_result)
        return flattened


    async def single_turn_request(
        self,
        system_prompt: str,
        user_prompt: str,
        shared_semaphore: asyncio.Semaphore,
        num_returns: int = 1,
        max_tokens: int = 64_000,
        use_json: bool = False,
        extra_model_paras: dict[str, Any] | None = None,
    ) -> list[tuple[str, str] | None]:
        """
        Send one system/user prompt pair to the configured provider.

        The request is routed by `self.provider`: "ark_batch" uses the Ark batch-chat
        path, the file-based batch providers ("openai_batch", "siliconflow_batch",
        "tencent_batch", "ali_batch", "gemini_batch") queue the request for a batch
        upload, and every other provider issues direct chat-completion calls.

        Args:
            system_prompt (str): The system prompt.
            user_prompt (str): The user prompt.
            shared_semaphore (asyncio.Semaphore): Semaphore forwarded to the selected request path.
            num_returns (int): How many responses to generate. Default is 1.
            max_tokens (int): Maximum number of tokens the LLM may generate; thinking
                tokens are counted. Default is 64,000.
            use_json (bool): Whether to request the JSON response format. Default is False.
            extra_model_paras (dict[str, Any] | None): Extra model parameters to pass to the API.
        Returns:
            list[tuple[str, str] | None]: The (thinking, answer) pairs; an entry is None
            when that response could not be obtained.
        """
        file_batch_providers = (
            "openai_batch",
            "siliconflow_batch",
            "tencent_batch",
            "ali_batch",
            "gemini_batch",
        )
        if self.provider == "ark_batch":
            handler = self._single_turn_request_ark_batch
        elif self.provider in file_batch_providers:
            # Batch mode: the request is stored and resolved once the batch
            # processing is complete.
            handler = self._single_turn_request_openai_batch
        else:
            handler = self._single_turn_request_with_max_n

        return await handler(
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            shared_semaphore=shared_semaphore,
            num_returns=num_returns,
            max_tokens=max_tokens,
            use_json=use_json,
            extra_model_paras=extra_model_paras,
        )

    async def _single_turn_request_openai_batch(
        self,
        system_prompt: str,
        user_prompt: str,
        shared_semaphore: asyncio.Semaphore,
        num_returns: int = 1,
        max_tokens: int = 64_000,
        use_json: bool = False,
        extra_model_paras: dict[str, Any] | None = None,
    ) -> list[tuple[str, str] | None]:
        """
        Queue `num_returns` copies of a request for the next batch upload and wait
        for their results.

        Under `self._batch_lock` each copy is appended to `self.batch_requests`, a
        future per copy is registered in `self.pending_responses`, and the batch
        coordinator task is started if none is running. `shared_semaphore` is not
        used on this path.

        Returns:
            One entry per queued copy; None where a future resolved to an exception.
            If queuing or waiting fails, a list of `num_returns` None values.
        """
        # NOTE(review): the request id is derived from the prompts only, so two
        # concurrent calls with identical prompts produce identical custom_id values
        # - confirm that the result collection copes with that.
        try:
            async with self._batch_lock:
                prompt_key = f"{system_prompt}___{user_prompt}"
                digest = sha512(prompt_key.encode('utf-8')).hexdigest()
                request_id = f"req_{digest}"
                waiters = []

                for index in range(num_returns):
                    # A fresh body per copy, so queued requests never share objects.
                    body = {
                        "model": self.model_name,
                        "messages": [
                            {"role": "system", "content": system_prompt},
                            {"role": "user", "content": user_prompt},
                        ],
                        "max_completion_tokens": max_tokens,
                    }
                    if use_json:
                        body["response_format"] = {"type": "json_object"}
                    if extra_model_paras is not None:
                        body.update(extra_model_paras)

                    self.batch_requests.append(
                        {
                            "custom_id": f"{request_id}__{index}",
                            "method": "POST",
                            "url": "/v1/chat/completions",
                            "body": body,
                        }
                    )

                    waiter = asyncio.Future()
                    self.pending_responses[request_id].append(waiter)
                    waiters.append(waiter)

                # Restart the coordinator's quiet-period countdown.
                self.last_request_time = time.time()

                coordinator = self.batch_coordinator_task
                if coordinator is None or coordinator.done():
                    self.batch_coordinator_task = asyncio.create_task(self._batch_coordinator())

            logger.debug(f"Waiting for {len(waiters)} batch responses...")
            outcomes = await asyncio.gather(*waiters, return_exceptions=True)

            finalized = []
            for outcome in outcomes:
                if not isinstance(outcome, Exception):
                    finalized.append(outcome)
                    continue
                logger.error(f"Error in batch response: {outcome}")
                finalized.append(None)
            return finalized

        except Exception as e:
            logger.error(f"Error in batch request: {e}")
            return [None] * num_returns

    async def _batch_coordinator(self):
        """
        Upload the queued batch once no request has arrived for `self.batch_upload_delay`
        seconds, then hand the results to the waiting futures.

        The lock is held for the whole upload-and-wait step. Futures without a result
        are resolved to None, the batch state is reset, and the coordinator exits. On
        any error all unresolved futures are resolved to None and the pending map is
        cleared; `batch_requests` and `batch_id` are left untouched in that case.
        """
        try:
            while True:
                await asyncio.sleep(self.batch_upload_delay)

                async with self._batch_lock:
                    last_seen = self.last_request_time
                    if last_seen is not None and time.time() - last_seen < self.batch_upload_delay:
                        # A request arrived during the quiet period; keep waiting.
                        continue
                    if not self.batch_requests:
                        # Nothing queued yet; keep waiting.
                        continue

                    logger.info(f"Uploading batch with {len(self.batch_requests)} requests after {self.batch_upload_delay}s delay")
                    batch_results = await self._upload_and_process_batch()

                    # Hand every returned result to the future registered at the same position.
                    for request_id, results in batch_results.items():
                        if request_id not in self.pending_responses:
                            continue
                        for future, result in zip(self.pending_responses[request_id], results):
                            if not future.done():
                                future.set_result(result)

                    # Whatever is still unresolved got no result from the batch.
                    for request_id, futures in self.pending_responses.items():
                        for position, future in enumerate(futures):
                            if future.done():
                                continue
                            logger.warning(f"No result received for {request_id} - Request {position}")
                            future.set_result(None)

                    # Reset the batch state for the next round.
                    self.batch_requests = []
                    self.pending_responses = defaultdict(list)
                    self.batch_id = None
                    self.last_request_time = None

                    logger.info("Batch processing completed and results dispatched")
                    break

        except Exception as e:
            logger.error(f"Error in batch coordinator: {e}")
            # Release every waiter with None so that callers do not hang.
            async with self._batch_lock:
                unresolved = (
                    future
                    for futures in self.pending_responses.values()
                    for future in futures
                    if not future.done()
                )
                for future in unresolved:
                    future.set_result(None)
                self.pending_responses = defaultdict(list)

    async def _upload_batch_file_openai(self, batch_file_path: str) -> str:
        """
        Upload a batch file through the client's files endpoint with purpose "batch".

        Args:
            batch_file_path (str): Path to the batch file to upload

        Returns:
            str: The ID of the uploaded file

        Raises:
            Exception: Whatever opening the file or the upload call raises
        """
        logger.info(f"Uploading batch file: {batch_file_path}")

        with open(batch_file_path, 'rb') as batch_file:
            file_response = await self.client.files.create(file=batch_file, purpose='batch')

        uploaded_id = file_response.id
        logger.info(f"File uploaded successfully. File ID: {uploaded_id}")
        return uploaded_id
    
    def _upload_batch_file_tencent(self, batch_file_path: str) -> str:
        """
        Upload a batch file to the Tencent COS and return the file ID.

        Args:
            batch_file_path (str): Path to the batch file to upload

        Returns:
            str: The uploaded file ID

        Raises:
            Exception: If upload fails
        """
        
        if not self.cos_client:
            raise ValueError("Tencent COS client not available")

        # Generate a unique key using current timestamp
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
        key = f"input/request_{timestamp}.jsonl"
        
        logger.info(f"Uploading batch file to Tencent COS: {batch_file_path} with key {key}")
        
        response = self.cos_client.upload_file(Bucket=self.bucket,Key=key,LocalFilePath=batch_file_path)

        if response.get('ETag') is None:
            raise Exception(f"Failed to upload file to Tencent COS: {response}")

        file_id = f"https://{self.bucket}.cos.{self.region}.myqcloud.com/{key}"
        logger.info(f"File uploaded successfully to Tencent COS. File ID: {file_id}")
        return file_id
    
    def _upload_batch_file_gemini(self, batch_file_path: str) -> str:
        """
        Upload a batch file through `self.gemini_client` and return the uploaded file's name.

        The upload is given a timestamped `req_...` display name and the "jsonl" mime type.

        Args:
            batch_file_path (str): Path to the batch file to upload

        Returns:
            str: The `name` of the uploaded file, used as the file ID

        Raises:
            Exception: Whatever the upload call raises
        """
        logger.info(f"Uploading batch file: {batch_file_path}")

        display_name = f"req_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"
        upload_config = {"display_name": display_name, "mime_type": "jsonl"}
        uploaded_file = self.gemini_client.files.upload(
            file=batch_file_path,
            config=upload_config,
        )

        file_name = uploaded_file.name
        logger.info(f"File uploaded successfully. File ID: {file_name}")
        return file_name

    async def upload_batch_file(self, batch_file_path: str) -> str:
        """
        Upload a batch file to the API and return the file ID.

        Args:
            batch_file_path (str): Path to the batch file to upload

        Returns:
            str: The uploaded file ID

        Raises:
            Exception: If upload fails
        """
        if self.provider == "openai_batch" or self.provider == "siliconflow_batch" or self.provider == "ali_batch":
            return await self._upload_batch_file_openai(batch_file_path)
        elif self.provider == "tencent_batch":
            return self._upload_batch_file_tencent(batch_file_path)
        elif self.provider == "gemini_batch":
            return self._upload_batch_file_gemini(batch_file_path)
        else:
            raise ValueError(f"Unknown provider: {self.provider}")

    async def _upload_and_process_batch(self) -> dict[str, list[tuple[str, str]]]:
        """
        Upload all accumulated batch requests as one batch and wait for completion.

        The requests are written to a temporary JSONL file, uploaded with
        `upload_batch_file`, and submitted through the client's batches endpoint. The
        temporary file is removed afterwards, including when writing or uploading fails.
        If `self.batch_id` is already set, nothing is uploaded and the existing batch is
        awaited instead.

        Returns:
            A dictionary mapping request_id to a list of (thinking, answer) tuples;
            empty when there are no queued requests.

        Raises:
            ValueError: If the client is not initialized
            Exception: Any error raised while writing, uploading, creating or awaiting
                the batch; it is logged and re-raised
        """
        if not self.batch_requests:
            logger.info("No batch requests to process")
            return {}

        if self.client is None:
            raise ValueError("OpenAI client is not initialized for batch processing")

        logger.info(f"Processing {len(self.batch_requests)} batch requests")

        if self.batch_id is not None:
            logger.info(f"Batch already in progress with ID: {self.batch_id}, waiting for completion")
            return await self._wait_for_batch_completion()

        # Create the temporary file first and remember its path immediately, so the
        # `finally` below can remove it even if serialising a request fails.
        batch_file = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.jsonl', encoding='utf-8')
        batch_file_path = batch_file.name

        try:
            with batch_file:
                batch_file.writelines(
                    json.dumps(request, ensure_ascii=False) + '\n'
                    for request in self.batch_requests
                )

            # Upload the file
            file_id = await self.upload_batch_file(batch_file_path)

            # Create batch
            batch_response = await self.client.batches.create(
                input_file_id=file_id,
                endpoint="/v1/chat/completions",
                completion_window="24h",
                metadata={
                    "description": "Math olympiad batch processing",
                    "timestamp": datetime.now().isoformat(),
                    "request_count": str(len(self.batch_requests))
                }
            )

            self.batch_id = batch_response.id
            logger.info(f"Batch created successfully. Batch ID: {self.batch_id}")

            # Poll for completion
            return await self._wait_for_batch_completion()

        except Exception as e:
            logger.error(f"Error during batch upload/processing: {e}")
            raise

        finally:
            # Clean up temporary file; a failed cleanup must not mask the real outcome.
            try:
                if os.path.exists(batch_file_path):
                    os.unlink(batch_file_path)
                    logger.debug(f"Cleaned up temporary file: {batch_file_path}")
            except Exception as e:
                logger.warning(f"Failed to delete temporary file {batch_file_path}: {e}")

    async def _wait_for_batch_completion(self) -> dict[str, list[tuple[str, str]]]:
        """
        Poll the batch `self.batch_id` until it reaches a terminal status and fetch its results.

        The batch is retrieved through `self.client.batches.retrieve` once per minute.

        Returns:
            dict[str, list[tuple[str, str]]]: The parsed results from `_download_batch_results`
            when the batch is "completed", "cancelled" or "expired" and an output file ID is
            available. An empty dict when the batch "failed" or no output file ID is provided.

        Raises:
            ValueError: If `self.batch_id` or `self.client` is not set.
        """
        if not self.batch_id or not self.client:
            raise ValueError("Batch ID or client not available")

        logger.info(f"Waiting for batch {self.batch_id} to complete...")
        start_time = time.time()
        check_count = 0

        while True:
            try:
                check_count += 1
                batch_status = await self.client.batches.retrieve(self.batch_id)
                elapsed_time = time.time() - start_time

                logger.info(f"Batch status check #{check_count}: {batch_status.status} (elapsed: {elapsed_time/60:.1f} min)")

                logger.debug(f"Batch details: {batch_status}")

                if batch_status.status == "completed":
                    logger.info(f"Batch completed successfully after {elapsed_time/60:.1f} minutes")
                    if not batch_status.output_file_id:
                        logger.error("Batch completed but no output file ID provided")
                        return {}
                    return await self._download_batch_results(batch_status.output_file_id)

                elif batch_status.status == "failed":
                    error_msg = f"Batch failed: {batch_status.errors if hasattr(batch_status, 'errors') else 'Unknown error'}"
                    logger.error(error_msg)
                    return {}

                elif batch_status.status in ["cancelled", "expired"]:
                    # A cancelled/expired batch may still carry partial results,
                    # so the output file is downloaded when one is provided.
                    logger.error(f"Batch was {batch_status.status} after {elapsed_time/60:.1f} minutes")
                    if not batch_status.output_file_id:
                        logger.error(f"Batch was {batch_status.status} but no output file ID provided.")
                        return {}
                    return await self._download_batch_results(batch_status.output_file_id)

                elif batch_status.status in ["validating", "in_progress", "finalizing", "cancelling"]:
                    # Still processing (or still winding down a cancellation):
                    # log the number of completed requests, then wait and check again.
                    completed_requests = getattr(batch_status.request_counts, 'completed', 0)
                    total_requests = getattr(batch_status.request_counts, 'total', 0)
                    logger.info(f"Completed {completed_requests} / {total_requests} requests")
                    await asyncio.sleep(60)  # Wait 1 minute before checking again

                else:
                    logger.warning(f"Unknown batch status: {batch_status.status}")
                    await asyncio.sleep(60)

            except Exception as e:
                # Errors while polling are logged and the check is retried after a
                # minute; there is no upper bound on the number of retries.
                logger.error(f"Error checking batch status: {e}")
                await asyncio.sleep(60)  # Wait before retrying

    async def _download_output_file(self, output_file_id: str) -> str:
        """
        Fetch a batch output file through `self.client.files.content`.

        Args:
            output_file_id (str): The ID of the output file to download.

        Returns:
            str: The file content decoded as UTF-8.

        Raises:
            ValueError: If `self.client` is not set.
        """
        client = self.client
        if not client:
            raise ValueError("OpenAI client not available")

        logger.info(f"Downloading output file with ID: {output_file_id}")

        # The response exposes the raw bytes of the file as `.content`.
        raw_bytes = (await client.files.content(output_file_id)).content
        decoded = raw_bytes.decode('utf-8')

        logger.info(f"Successfully downloaded output file with ID: {output_file_id}")
        return decoded
    
    def _download_output_file_tencent(
        self,
        output_file_url: str,
        timeout: float | tuple[float, float] = (10.0, 300.0),
    ) -> str:
        """
        Download an output file from a direct url provided by Tencent.

        Args:
            output_file_url (str): The URL to fetch with an HTTP GET request.
            timeout (float | tuple[float, float]): Timeout in seconds forwarded to
                `requests.get`; a tuple is interpreted as (connect, read).
                Default is (10.0, 300.0).

        Returns:
            str: The response body decoded as UTF-8.

        Raises:
            ValueError: If the response status code is not 200.
            requests.exceptions.RequestException: If the request itself fails,
                including when the timeout elapses.
        """
        logger.info(f"Downloading output file from {output_file_url}")

        # Without a timeout `requests` can wait forever on a stalled connection.
        response = requests.get(output_file_url, timeout=timeout)
        # Decode as UTF-8 like the other download helpers instead of relying on
        # the encoding `requests` guesses from the headers/body.
        response.encoding = 'utf-8'

        if response.status_code != 200:
            logger.error(f"Failed to download output file from {output_file_url}: {response.status_code}")
            logger.debug(f"Response content: {response.text}")
            raise ValueError("Failed to download output file")

        logger.info(f"Successfully downloaded output file from {output_file_url}")
        return response.text
    
    def _download_output_file_gemini(self, output_file_id: str) -> str:
        """
        Fetch an output file through `self.gemini_client.files.download`.

        Args:
            output_file_id (str): Passed as the `file` argument of the download call.

        Returns:
            str: The downloaded payload decoded as UTF-8.
        """
        logger.info(f"Downloading output file from Gemini with ID: {output_file_id}")

        # The Gemini client hands back the payload, which is decoded to text here.
        payload = self.gemini_client.files.download(file=output_file_id)
        return payload.decode('utf-8')

    async def download_output_file(self, output_file_id: str) -> str:
        """
        Fetch the content of a batch output file using the helper that matches `self.provider`.

        Args:
            output_file_id (str): The ID of the output file to download. For the
                "tencent_batch" provider the value is handed to the URL-based helper.

        Returns:
            str: The content of the retrieved file.

        Raises:
            ValueError: If `self.provider` is not one of the supported batch providers.
        """
        provider = self.provider

        # Providers served by the awaitable client-based helper.
        if provider in ("openai_batch", "siliconflow_batch", "ali_batch"):
            return await self._download_output_file(output_file_id)

        # The remaining helpers are plain (non-awaitable) calls.
        if provider == "tencent_batch":
            return self._download_output_file_tencent(output_file_id)
        if provider == "gemini_batch":
            return self._download_output_file_gemini(output_file_id)

        raise ValueError(f"Provider {self.provider} does not support batch result downloading.")

    async def _download_batch_results(self, output_file_id: str) -> dict[str, list[tuple[str, str]]]:
        """
        Download the batch output file and parse it into per-request results.

        Every non-empty line of the file is parsed as one JSON object. The part of
        its `custom_id` before the first `__` is used as the request id, and the first
        choice of the response body contributes one (thinking, answer) tuple, taken
        from the message's `reasoning_content` and `content` fields. Token usage is
        summed over all parsed lines and passed to `self.update_token_usage`.

        Args:
            output_file_id (str): Identifier handed to `download_output_file`.

        Returns:
            dict[str, list[tuple[str, str]]]: Mapping from request id to the list of
            (thinking, answer) tuples. Lines that carry an error, have no choices or
            cannot be parsed are logged and skipped. An empty dict is returned when
            the download or the post-processing raises.
        """

        try:
            results_content = await self.download_output_file(output_file_id)

            # Parse results
            results = defaultdict(list)  # Maps request_id to list of (thinking, answer)
            total_prompt_tokens = 0
            total_completion_tokens = 0
            total_tokens = 0
            total_reasoning_tokens = 0
            total_cached_tokens = 0

            # Split on '\n' only: other Unicode line separators may legally appear
            # unescaped inside the JSON strings.
            for line in results_content.strip().split('\n'):
                if not line.strip():
                    continue

                try:
                    result = json.loads(line)
                    custom_id = result['custom_id']
                    request_id = custom_id.split('__')[0]

                    if result.get('error'):
                        logger.error(f"Error in batch result for {custom_id}: {result['error']}")
                        continue

                    # `or {}` also covers fields that are present but JSON null.
                    body = (result.get('response') or {}).get('body') or {}
                    choices = body.get('choices')
                    if not choices:
                        logger.error(f"No choices in batch result for {custom_id}")
                        continue

                    # The message is a plain dict decoded from JSON, so its fields
                    # must be read with `.get` (attribute access never finds them).
                    message = choices[0].get('message') or {}
                    thinking = message.get('reasoning_content', "")
                    if not isinstance(thinking, str):
                        thinking = ""
                    answer = message.get('content', "")
                    if not isinstance(answer, str):
                        answer = ""
                    results[request_id].append((thinking, answer))

                    # Accumulate usage stats
                    usage_stats = body.get('usage')
                    if usage_stats:
                        total_prompt_tokens += usage_stats.get('prompt_tokens') or 0
                        total_completion_tokens += usage_stats.get('completion_tokens') or 0
                        total_tokens += usage_stats.get('total_tokens') or 0

                        # Reasoning / cached token details are optional
                        completion_details = usage_stats.get('completion_tokens_details') or {}
                        total_reasoning_tokens += completion_details.get('reasoning_tokens') or 0

                        prompt_details = usage_stats.get('prompt_tokens_details') or {}
                        total_cached_tokens += prompt_details.get('cached_tokens') or 0

                except Exception as e:
                    logger.error(f"Error parsing batch result line: {line[:100]}... Error: {e}")
                    continue

            # Create usage object and update totals
            if total_tokens > 0:

                completion_tokens_details = CompletionTokensDetails(
                    reasoning_tokens=total_reasoning_tokens
                ) if total_reasoning_tokens > 0 else None

                prompt_tokens_details = PromptTokensDetails(
                    cached_tokens=total_cached_tokens
                ) if total_cached_tokens > 0 else None

                usage = CompletionUsage(
                    prompt_tokens=total_prompt_tokens,
                    completion_tokens=total_completion_tokens,
                    total_tokens=total_tokens,
                    completion_tokens_details=completion_tokens_details,
                    prompt_tokens_details=prompt_tokens_details,
                )

                self.update_token_usage(usage)

            logger.info(f"Successfully processed {sum(len(v) for v in results.values())} batch results")
            return results

        except Exception as e:
            logger.error(f"Error downloading batch results: {e}")
            return {}


class SystemPromptMixin(ABC):
    """
    Mixin that loads a system prompt from a file on first use.
    Concrete sub-classes must implement the `system_prompt_file` property, which returns the path of the file holding the prompt text.
    """

    @property
    @abstractmethod
    def system_prompt_file(self) -> str:
        """
        Path of the system prompt file; has to be supplied by the sub-class.
        """
        ...

    @cached_property
    def _system_prompt(self) -> str:
        """
        The text of the system prompt file, read with UTF-8 encoding on first access and cached on the instance afterwards.

        Raises:
            ValueError: If `system_prompt_file` is missing or None.
            FileNotFoundError: If the path does not exist.
        """
        prompt_file = getattr(self, "system_prompt_file", None)
        if prompt_file is None:
            raise ValueError(
                "system_prompt_file should be provided before call the system prompt."
            )

        prompt_path = Path(prompt_file)
        if prompt_path.exists():
            return prompt_path.read_text(encoding="UTF-8")

        raise FileNotFoundError(
            f"System prompt file not found in default path `{str(prompt_path)}`. Please provide a valid path."
        )


import argparse


def common_parse_args() -> argparse.ArgumentParser:
    """
    Build the argument parser holding the command line options shared by the generators.

    The parser is created with `add_help=False` (presumably so it can be passed via
    `parents=[...]` to a sub-class specific parser that defines its own args).

    Returns:
        argparse.ArgumentParser: The parser with the common arguments registered.
    """

    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument(
        "provider",
        type=str,
        choices=provider_list,
        help="The api provider.",
    )
    parser.add_argument(
        "--file",
        type=str,
        required=True,
        help="Path to the file containing problem-proof pairs. The file is expected as a jsonl file.",
    )
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default=None,
        help="The model name, if None, default is the newest `o3` version for openai and the newest `r1` version for deepseek.",
    )
    parser.add_argument(
        "-l",
        "--lines",
        type=int,
        default=None,
        help="Number of lines to read from the file. If None, read all lines.",
    )
    parser.add_argument(
        "--num_returns", type=int, default=1, help="Number of returns for each request."
    )
    parser.add_argument(
        "--num_worker",
        type=int,
        default=1,
        help="Number of workers to use for async api",
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=None,
        help="Max tokens for the LLM to generate. Default is None.",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default=None,
        help="Path to save the generated thinking steps and solutions. Default: `./save/xxx`",
    )
    parser.add_argument(
        "--resume",
        action="store_true",
        help="If True, pass the existing files. Otherwise, overwrite the existing files.",
    )
    # The flag *disables* async mode; the help text describes it accordingly.
    parser.add_argument(
        "--no_async",
        action="store_true",
        help="If True, use sync mode to process the requests. Otherwise, use async mode.",
    )
    parser.add_argument(
        "--only_solve_id_conflict",
        action="store_true",
        help="If True, only solve the id conflict problems. Otherwise, generate all problems.",
    )
    parser.add_argument(
        "--indexes",
        type=int,
        nargs="*",
    )
    parser.add_argument(
        "--batch_id",
        type=str,
        default=None,
        help="The batch ID for the requests."
    )
    # SECURITY NOTE(review): `type=eval` executes the argument text as arbitrary
    # Python code, so this option must only be used with trusted command lines.
    # If only literal dictionaries are needed, `ast.literal_eval` would be a safer
    # converter; it is left unchanged here to keep existing invocations working.
    parser.add_argument(
        "--extra_model_paras",
        type=eval,
        default=None,
        help="Extra model parameters to pass to the API in dictionary format. For example, the gpt-5 model support the `reasoning` parameter to control the reasoning strength. Example: `--extra_model_paras \"{'reasoning': 'high'}\"`",
    )
    return parser
