from typing import List, Dict
import os

from tqdm.asyncio import tqdm_asyncio
from tqdm.auto import tqdm
from aiolimiter import AsyncLimiter
from openai import AsyncOpenAI

from .base_model import BaseModel


class AsyncOpenAILLM:
    """Rate-limited, batched async chat-completion client for a fixed set of models.

    The provider (OpenRouter or xAI) is chosen from substrings of ``model_path``.
    Every completion is returned as a dict whose keys depend on the model
    family (see ``_result_fields``); on API errors a dict with the same keys,
    filled with ``""`` / ``0``, is returned instead.
    """

    # OpenAI-compatible endpoints used by the supported providers.
    _OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"
    _XAI_BASE_URL = "https://api.x.ai/v1"
    # Substrings of model names sent to OpenRouter under their full (unstripped) name.
    _OPENROUTER_MODEL_KEYS = ("deepseek-r1", "qwq-32b", "qwen3-32b")
    # Bare model names whose results include usage.completion_tokens_details.reasoning_tokens.
    _OPENAI_REASONING_MODELS = ("o1-mini", "o1", "o3-mini", "o3")
    # Bare model names whose reasoning trace is read from message.reasoning_content
    # (all other reasoning models use message.reasoning).
    _GROK_REASONING_MODELS = ("grok-3-mini-beta",)

    def __init__(
        self,
        model_path,
        batch_size: int = 100,
        requests_per_minute: int = 100,
        api_base: str = None,
    ):
        """Configure the client, batching and rate limiting for ``model_path``.

        Args:
            model_path: Model identifier. Contains ``"openai"`` -> OpenRouter
                with the ``"openai/"`` prefix removed; contains ``"xai"`` -> xAI
                with the ``"xai/"`` prefix removed; contains one of
                ``_OPENROUTER_MODEL_KEYS`` -> OpenRouter, name kept as-is.
            batch_size: Number of prompts dispatched concurrently per batch.
            requests_per_minute: Maximum API calls allowed per 60-second window.
            api_base: Currently ignored; kept for interface compatibility.

        Raises:
            ValueError: If ``model_path`` matches no supported provider.
        """
        if "openai" in model_path:
            # The api.openai.com route was replaced by OpenRouter; the bare
            # name is what the o-series checks in _result_fields compare against.
            api_key_env, base_url = "OPENROUTER_API_KEY", self._OPENROUTER_BASE_URL
            self.model_path = model_path.replace("openai/", "")
        elif "xai" in model_path:
            api_key_env, base_url = "XAI_API_KEY", self._XAI_BASE_URL
            self.model_path = model_path.replace("xai/", "")
        elif any(key in model_path for key in self._OPENROUTER_MODEL_KEYS):
            api_key_env, base_url = "OPENROUTER_API_KEY", self._OPENROUTER_BASE_URL
            self.model_path = model_path
        else:
            # Fail fast: without a client every later request would crash
            # with an AttributeError far from the real cause.
            raise ValueError(f"Unsupported model: {model_path!r}")

        self.client = AsyncOpenAI(
            api_key=os.environ.get(api_key_env),
            base_url=base_url,
        )
        self.batch_size = batch_size  # prompts sent concurrently per batch
        self.requests_per_minute = requests_per_minute
        # At most `requests_per_minute` acquisitions per 60 s window.
        self.limiter = AsyncLimiter(self.requests_per_minute, 60)

    def validate_openai(self):
        """Always report the configuration as valid (no check is performed)."""
        return True

    def _result_fields(self):
        """Return the ordered result keys produced for ``self.model_path``.

        The success path (``_extract_result``) and the error fallback
        (``_empty_result``) both derive their keys from here, so the two
        always agree.
        """
        if self.model_path in self._GROK_REASONING_MODELS:
            return (
                "output",
                "prompt_tokens",
                "completion_tokens",
                "reasoning_tokens",
                "reasoning_output",
            )
        if any(key in self.model_path for key in self._OPENROUTER_MODEL_KEYS):
            return ("output", "prompt_tokens", "completion_tokens", "reasoning_output")
        if self.model_path in self._OPENAI_REASONING_MODELS:
            return ("output", "prompt_tokens", "completion_tokens", "reasoning_tokens")
        # Any other model reachable through a configured provider.
        return ("output", "prompt_tokens", "completion_tokens")

    def _empty_result(self):
        """Return the placeholder result used when a request fails."""
        return {
            field: "" if field in ("output", "reasoning_output") else 0
            for field in self._result_fields()
        }

    def _extract_result(self, resp):
        """Build the result dict for this model from a chat-completion response.

        Missing usage / reasoning attributes degrade to ``0`` / ``None`` instead
        of raising, so a successful completion is never discarded just because
        the provider omitted an optional field.
        """
        message = resp.choices[0].message
        usage = getattr(resp, "usage", None)
        details = getattr(usage, "completion_tokens_details", None)
        reasoning_attr = (
            "reasoning_content"
            if self.model_path in self._GROK_REASONING_MODELS
            else "reasoning"
        )
        available = {
            "output": message.content,
            "prompt_tokens": getattr(usage, "prompt_tokens", None) or 0,
            "completion_tokens": getattr(usage, "completion_tokens", None) or 0,
            "reasoning_tokens": getattr(details, "reasoning_tokens", None) or 0,
            "reasoning_output": getattr(message, reasoning_attr, None),
        }
        return {field: available[field] for field in self._result_fields()}

    async def _get_completion_text_async(
        self, messages: List[Dict[str, str]], **kwargs
    ) -> Dict[str, object]:
        """Send one rate-limited chat request and return its result dict.

        Args:
            messages: Chat messages (``{"role": ..., "content": ...}`` dicts).
            **kwargs: Forwarded to ``client.chat.completions.create``.

        Returns:
            The dict from ``_extract_result``; on any error the message is
            printed and ``_empty_result()`` is returned so one failed prompt
            does not abort the whole batch.
        """
        async with self.limiter:
            try:
                resp = await self.client.chat.completions.create(
                    model=self.model_path,
                    messages=messages,
                    **kwargs,
                )
                return self._extract_result(resp)
            except Exception as e:  # deliberate best-effort boundary per prompt
                print(f"Error during OpenAI API call: {e}")
                return self._empty_result()

    async def completions(self, messages, **kwargs):
        """Generate completions for a list of conversations in concurrent batches.

        Args:
            messages: List of conversations; each conversation is a list of
                dicts with exactly the keys ``"role"`` and ``"content"``.
            **kwargs: Forwarded to every ``chat.completions.create`` call.

        Returns:
            One result dict per conversation, in input order.

        Raises:
            AssertionError: If ``messages`` is not in the expected format.
        """
        assert isinstance(messages, list)  # Ensure messages are provided as a list
        assert all(
            isinstance(message, list) for message in messages
        ), "Message format error."
        assert all(
            isinstance(msg, dict) and set(msg.keys()) == {"role", "content"}
            for message in messages
            for msg in message
        ), "Message format error."

        result_responses = []
        # Batches run sequentially; prompts within a batch run concurrently
        # (still throttled by self.limiter).
        for start_idx in tqdm(
            range(0, len(messages), self.batch_size), desc="Processing batches"
        ):
            batch_prompts = messages[start_idx:start_idx + self.batch_size]
            batch_responses = await tqdm_asyncio.gather(
                *[
                    self._get_completion_text_async(prompt, **kwargs)
                    for prompt in batch_prompts
                ]
            )
            result_responses.extend(batch_responses)

        return result_responses