import os
import asyncio
import warnings
from typing import Literal
from copy import deepcopy
from core.policy.schema import Policy
from utils.io_utils import dump_file

from safetytooling.apis import InferenceAPI
from safetytooling.data_models import (
    ChatMessage,
    MessageRole,
    Prompt,
    LLMResponse,
)


class APIModel(Policy):
    """A policy backed by a model hosted behind a remote inference API.

    All instances share one ``InferenceAPI`` client (``shared_api``) and the
    default sampling parameters read from the environment when the class is
    defined (``TEMPERATURE``, ``PRESENCE_PENALTY``). Every query history and
    every returned completion is appended to ``inference_record_path`` via
    ``dump_file``.
    """

    model_provider: str
    model_name: str
    system_prompt: str | None

    # Single API client reused by every instance of this class.
    shared_api = InferenceAPI(together_num_threads=5)

    # Default sampling parameters shared by all instances. They can be
    # overridden per instance (see __init__) or per call via keyword arguments.
    temperature = float(os.getenv("TEMPERATURE", 0.1))
    presence_penalty = float(os.getenv("PRESENCE_PENALTY", 0.0))

    # File that each query history and each completion is appended to.
    inference_record_path = "___inference_record.jsonl"

    def __init__(
        self,
        model_name: str,
        colloquial_name: str | None = None,
        system_prompt: str | None = None,
        model_provider: Literal[
            "auto", "openai", "anthropic", "google", "together", "deepseek"
        ] = "auto",
    ):
        """
        Instantiate an API model.

        :param model_name: The name of the model, as it appears in the provider's API.
        :type model_name: str
        :param colloquial_name: The colloquial name of the model, defaults to `model_name`.
        :type colloquial_name: str | None
        :param system_prompt: Optional system prompt prepended to every query
            unless `disable_system_prompt=True` is passed at inference time.
        :type system_prompt: str | None
        :param model_provider: The provider to force for every request. With
            `"auto"`, no `force_provider` argument is sent and provider
            selection is left to `InferenceAPI`.
        :type model_provider: Literal["auto", "openai", "anthropic", "google", "together", "deepseek"]
        """
        super().__init__(colloquial_name or model_name)
        self.model_provider = model_provider
        self.model_name = model_name
        self.system_prompt = system_prompt

        if "gemini" in self.model_name or self.model_provider == "google":
            # Dial up temperature to avoid RECITATION errors. This instance
            # attribute shadows the class-wide default.
            self.temperature = 1.0

    async def infer_single_async(
        self,
        history: list[dict[str, str]],
        disable_system_prompt: bool = False,
        **kwargs,
    ) -> str:
        """
        Given a dialogue history, return a single response.
        By implementing this method, the `infer_single` and `infer_batch` methods will automatically be available for use.
        You may pass additional arguments to the API call, for example `infer_single(..., temperature=0.05, is_valid=lambda s: len(s) > 5)`.

        :param history: The dialogue history, in OpenAI format (dicts with
            "role" and "content"). It is not modified.
        :type history: list[dict[str, str]]
        :param disable_system_prompt: If True, do not prepend this model's
            system prompt to the history.
        :type disable_system_prompt: bool
        :param kwargs: Extra arguments forwarded to the API call. If
            `temperature` and `presence_penalty` are absent, the model
            defaults are used.
        :return: The single response.
        :rtype: str
        :raises ValueError: If the dialogue has a non-user/assistant message
            after the first entry, or two consecutive messages with the same role.
        """
        # Work on a copy so the caller's history is never mutated.
        history = deepcopy(history)

        if not disable_system_prompt and self.system_prompt:
            history.insert(0, {"role": "system", "content": self.system_prompt})

        # Record the query before validation, so malformed queries are logged too.
        dump_file(
            self.inference_record_path,
            history,
            write_mode="a",
            indent=None,
        )

        prompt = Prompt(
            messages=[
                ChatMessage(
                    content=message["content"], role=MessageRole(message["role"])
                )
                for message in history
            ]
        )

        api_kwargs = self._prepare_api_kwargs(kwargs)
        self._validate_dialogue(prompt)

        response: list[LLMResponse] = await self.shared_api(
            model_id=self.model_name,
            prompt=prompt,
            print_prompt_and_response=False,
            **api_kwargs,
        )

        res = self._extract_completion(response)

        dump_file(
            self.inference_record_path,
            res,
            write_mode="a",
            indent=None,
        )
        return res

    def _prepare_api_kwargs(self, kwargs: dict) -> dict:
        """Return a copy of `kwargs` with sampling defaults and provider routing applied."""
        api_kwargs = dict(kwargs)
        api_kwargs.setdefault("temperature", self.temperature)
        api_kwargs.setdefault("presence_penalty", self.presence_penalty)

        # Claude / the Anthropic API has no presence_penalty parameter, so it
        # is always dropped for those models. A warning is issued only when a
        # non-zero penalty is being discarded.
        if "claude" in self.model_name or self.model_provider == "anthropic":
            penalty = api_kwargs.pop("presence_penalty")
            if penalty is not None and float(penalty) != 0:
                warnings.warn("Claude API doesn't support presence penalty. Ignoring the penalty.")

        if self.model_provider != "auto":
            api_kwargs["force_provider"] = self.model_provider
        return api_kwargs

    @staticmethod
    def _validate_dialogue(prompt: Prompt) -> None:
        """Raise ValueError unless turns after the first alternate between user and assistant."""
        messages = prompt.messages
        for last_msg, this_msg in zip(messages, messages[1:]):
            is_valid = (
                this_msg.role.value in ("user", "assistant")  # No system prompt after the first entry
                and this_msg.role.value != last_msg.role.value  # No two consecutive speeches from same person
            )
            if not is_valid:
                print(f"Malstructured query: {prompt}")
                raise ValueError("Malstructured LLM query.")

    @staticmethod
    def _extract_completion(response: list) -> str:
        """Return the completion text of the first response.

        Both a flat list of responses and a nested list (list of lists) are
        accepted. NOTE(review): presumably some InferenceAPI call paths return
        the nested shape; confirm which.
        """
        first = response[0]
        try:
            return first.completion
        except AttributeError:
            # Nested shape: take the first response of the first group.
            return first[0].completion
