import httpx
import time
from typing import Union, List
from abc import ABC, abstractmethod
import asyncio

from rich import print
class MultipleGenerationError(Exception):
    """Exception type for multiple-generation (``n > 1``) failures.

    NOTE(review): not raised anywhere in this module's visible code; presumably
    raised by callers or subclasses — confirm before relying on it.
    """


class ModelResponseBase(ABC):
    """Base class for models queried through a chat-completions style client.

    Subclasses supply the ``client`` property (getter and setter) and
    ``_extract_content``. This class adds request retries with exponential
    back-off, re-querying on blank replies, a fallback loop for providers that
    ignore ``n > 1``, and token/cost accounting.

    NOTE(review): the client is assumed to expose an async
    ``client.chat.completions.create(...)`` (OpenAI-compatible) — confirm for
    each subclass.
    """

    # Upper bound applied to any ``max_tokens`` kwarg before it reaches the API.
    max_tokens_cap: int = 4096
    # Exponential back-off bounds, in milliseconds, shared by both retry loops.
    initial_backoff_ms: int = 1_000
    max_backoff_ms: int = 32_000
    # Number of re-queries made when the model keeps returning blank text.
    empty_reply_retries: int = 100

    def __init__(
        self,
        name: str,
        in_token_costs: float,
        out_token_costs: float,
        api_config: dict,
        max_retries: int = 100,
    ):
        """
        Args:
            name: Model identifier sent as ``model=`` with every request.
            in_token_costs: Price per one million prompt tokens.
            out_token_costs: Price per one million completion tokens.
            api_config: Provider settings; only ``"extra_headers"`` is read here.
            max_retries: Attempts per API call in ``_query_single_response``.
        """
        self.name = name
        self.in_token_costs = in_token_costs
        self.out_token_costs = out_token_costs
        self.max_retries = max_retries
        # Forwarded verbatim as ``extra_headers=`` on every request (None if absent).
        self.extra_headers = api_config.get("extra_headers")

    @property
    @abstractmethod
    def client(self):
        """Subclasses must define their own client."""
        pass

    @client.setter
    @abstractmethod
    def client(self, value):
        """Subclasses must allow the client to be replaced."""
        pass

    async def create_response(self, messages, **kwargs):
        """Send one chat-completion request through ``self.client``.

        The ``role`` kwarg (bookkeeping used by callers, not an API parameter)
        is dropped, and ``max_tokens`` is capped at ``self.max_tokens_cap``.
        Exceptions from the client propagate unchanged.
        """
        kwargs.pop("role", None)
        if "max_tokens" in kwargs:
            kwargs["max_tokens"] = min(kwargs["max_tokens"], self.max_tokens_cap)
        return await self.client.chat.completions.create(
            model=self.name,
            messages=messages,  # messages should be a list of dicts
            extra_headers=self.extra_headers,
            **kwargs,
        )

    @staticmethod
    def build_prompt(role, messages, timestep, template_func):
        """Return ``template_func(messages)``; ``role`` and ``timestep`` are unused here."""
        return template_func(messages)

    async def _sleep_backoff(self, sleep_ms: int) -> int:
        """Sleep ``sleep_ms`` milliseconds; return the next (doubled, capped) delay."""
        await asyncio.sleep(sleep_ms / 1000)  # convert ms -> seconds
        return min(sleep_ms * 2, self.max_backoff_ms)

    async def _query_single_response(self, messages, **kwargs):
        """Call the API, retrying on exceptions and on responses without choices.

        Makes up to ``self.max_retries`` attempts with exponential back-off
        between them.

        Returns:
            ``(response, (num_in_tokens, num_out_tokens, cost, provider))``;
            ``cost`` uses the per-million-token prices given to ``__init__``.

        Raises:
            TimeoutError: if no attempt produced a response with at least one choice.
        """
        response = None
        sleep_ms = self.initial_backoff_ms

        for attempt in range(self.max_retries):
            if attempt > 0:
                print(f"Retrying... {attempt}")
                sleep_ms = await self._sleep_backoff(sleep_ms)

            try:
                response = await self.create_response(messages, **kwargs)
            except Exception as e:
                # Deliberately broad: any transport/API failure is retried.
                print("API response error: ", e)
                print("Check if you are using the correct api profile for this model")
                continue

            # Both ``None`` and an empty list count as a failed attempt.
            if response.choices:
                break
            print("Error in response\n", response)

        if response is None or not response.choices:
            raise TimeoutError("No valid response", response)

        num_in_tokens = num_out_tokens = 0
        usage = getattr(response, "usage", None)
        if usage is not None:
            # Some providers report ``None`` token counts; treat them as zero.
            num_in_tokens = usage.prompt_tokens or 0
            num_out_tokens = usage.completion_tokens or 0
        cost = (
            self.in_token_costs * num_in_tokens + self.out_token_costs * num_out_tokens
        ) / 1e6
        provider = getattr(response, "provider", None)

        return response, (num_in_tokens, num_out_tokens, cost, provider)

    @abstractmethod
    def _extract_content(self, response):
        """Return a dict with at least a ``"text"`` key for one response choice."""
        pass

    @staticmethod
    def _is_blank(text) -> bool:
        """Return True when ``text`` is None, empty, or whitespace-only."""
        return not text or text.isspace()

    async def _query_nonempty(self, messages, role, continue_with_empty_reply, kwargs):
        """Query once and re-query while the first choice's text is blank.

        With ``continue_with_empty_reply`` a blank first reply is accepted (its
        message content is normalised to ``""``) instead of being retried.

        Raises:
            RuntimeError: if all ``self.empty_reply_retries`` re-queries stay blank.
        """
        response, stats = await self._query_single_response(
            messages, role=role, **kwargs
        )
        text = self._extract_content(response.choices[0])["text"]
        if not self._is_blank(text):
            return response, stats

        if continue_with_empty_reply:
            # NOTE: for cloud reasoning models, or local models using prefill logic, an
            #       empty reply is possible. For reasoning models this happens when
            #       max_completion_tokens does not exceed the tokens spent on the
            #       reasoning trace; with prefilling, the thinker may already have
            #       finished and emitted an EOS token.
            print(".... didn't get any text in response. Continue anyways")
            response.choices[0].message.content = ""
            return response, stats

        sleep_ms = self.initial_backoff_ms
        for remaining in range(self.empty_reply_retries - 1, -1, -1):
            response, stats = await self._query_single_response(
                messages, role=role, **kwargs
            )
            text = self._extract_content(response.choices[0])["text"]
            if not self._is_blank(text):
                # A success on the final attempt is still a success.
                return response, stats
            print(f"====={remaining}: sleeps {sleep_ms} ms =====\n{response}")
            if remaining:
                sleep_ms = await self._sleep_backoff(sleep_ms)

        raise RuntimeError(f"Didn't get any response from the model {self.name}")

    async def query_response(
        self, messages: list, return_extras: bool = True, **kwargs
    ) -> Union[list, tuple]:
        """Query the model for one conversation or a batch of conversations.

        Args:
            messages: One conversation (list of message dicts) or a batch
                (list of such lists), which is queried concurrently.
            return_extras: Whether to also return usage/cost/timing info.
            **kwargs: Forwarded to the API. Special keys: ``role`` (a single
                value, or one per conversation in batch mode),
                ``continue_with_empty_reply`` (accept blank replies), and
                ``n`` (completions per conversation).

        Returns:
            Batch mode: ``contents`` or ``(contents, (in_tokens, out_tokens,
            cost, time_taken))`` summed over the batch.
            Single mode: ``([content], extras_dict_or_None)``, where
            ``content`` is a list of extracted choices when ``n > 1``.
        """
        start_time = time.perf_counter()
        continue_with_empty_reply = kwargs.pop("continue_with_empty_reply", False)

        if isinstance(messages, list) and messages and isinstance(messages[0], list):
            roles = kwargs.pop("role", [None] * len(messages))
            results = await asyncio.gather(
                *(
                    self._query_nonempty(m, role, continue_with_empty_reply, kwargs)
                    for m, role in zip(messages, roles)
                )
            )
            responses, stats = zip(*results)
            num_completions = kwargs.get("n", 1)
            if num_completions == 1:
                response_content = [
                    self._extract_content(r.choices[0]) for r in responses
                ]
            else:
                response_content = [
                    [self._extract_content(r.choices[j]) for j in range(num_completions)]
                    for r in responses
                ]
            time_taken = time.perf_counter() - start_time
            if not return_extras:
                return response_content
            in_tokens = sum(s[0] for s in stats)
            out_tokens = sum(s[1] for s in stats)
            cost = sum(s[2] for s in stats)
            return response_content, (in_tokens, out_tokens, cost, time_taken)

        role = kwargs.pop("role", None)
        response, (in_tokens, out_tokens, cost, provider) = await self._query_nonempty(
            messages, role, continue_with_empty_reply, kwargs
        )

        num_completions = kwargs.get("n", 1)
        if num_completions == 1:
            response_content = self._extract_content(response.choices[0])
        elif num_completions <= len(response.choices):
            response_content = [
                self._extract_content(choice)
                for choice in response.choices[:num_completions]
            ]
        else:
            print(
                f"Warning: {self.name} doesn't support multi-generation, running loop for n>1"
            )
            # Keep the first call's usage; each extra call adds its own on top.
            response_content = [self._extract_content(response.choices[0])]
            for _ in range(1, num_completions):
                extra, (extra_in, extra_out, extra_cost, _provider) = (
                    await self._query_single_response(messages, role=role, **kwargs)
                )
                response_content.append(self._extract_content(extra.choices[0]))
                in_tokens += extra_in
                out_tokens += extra_out
                cost += extra_cost

        time_taken = time.perf_counter() - start_time

        if not return_extras:
            return [response_content], None
        return [response_content], {
            "in_tokens": in_tokens,
            "out_tokens": out_tokens,
            "cost": cost,
            "time_taken": time_taken,
            "provider": provider,
        }
