"""
OpenAI-backed language model wrapper built on the OpenAI Python library's chat completions API.
"""

import os

import openai
import tiktoken
import utils
from language_model import LM

# Module-level logger, obtained through the project's `utils.get_logger` helper.
logger = utils.get_logger(__name__)

# Registry of known OpenAI models, keyed by a short handle.
#
# Each entry holds:
#   "id":      the model identifier sent to the API
#   "context": the context length (approximate, in tokens) reported to the LM base class
#   "pricing": USD per 1M input / output tokens, or None when no price is recorded
ALL_MODELS = {
    "4o": {
        "id": "gpt-4o-2024-11-20",
        "context": 128_000,
        "pricing": {"input": 2.50, "output": 10.00},
    },
    "4o-mini": {
        "id": "gpt-4o-mini-2024-07-18",
        "context": 128_000,
        "pricing": {"input": 0.15, "output": 0.60},
    },
    "4t": {
        "id": "gpt-4-turbo-2024-04-09",
        "context": 128_000,
        "pricing": {"input": 10.00, "output": 30.00},
    },
    "35t": {
        "id": "gpt-3.5-turbo-1106",
        "context": 16_000,
        "pricing": {"input": 1.00, "output": 2.00},
    },
    "4": {
        "id": "gpt-4-0613",  # 2023
        "context": 8_000,
        "pricing": {"input": 30.00, "output": 60.00},
    },
    "4.1": {
        "id": "gpt-4.1-2025-04-14",
        "context": 1_000_000,
        "pricing": {"input": 2.00, "output": 8.00},
    },
    "4.1-mini": {
        "id": "gpt-4.1-mini-2025-04-14",
        "context": 1_000_000,
        "pricing": {"input": 0.40, "output": 1.60},
    },
    "4.1-nano": {
        "id": "gpt-4.1-nano-2025-04-14",
        "context": 1_000_000,
        "pricing": {"input": 0.10, "output": 0.40},
    },
    "o3": {
        "id": "o3-2025-04-16",
        "context": 200_000,
        "pricing": {"input": 2.00, "output": 8.00},
    },
    "o3-mini": {
        "id": "o3-mini-2025-01-31",
        "context": 200_000,
        "pricing": {"input": 1.10, "output": 4.40},
    },
    "o4-mini": {
        "id": "o4-mini-2025-04-16",
        "context": 200_000,
        "pricing": {"input": 1.10, "output": 4.40},
    },
    # ChatGPT 4o
    "c4o": {
        "id": "chatgpt-4o-latest",
        "context": 128_000,
        "pricing": None,
    },
    # o1 is deliberately skipped: o1-pro costs $600 per 1M output tokens!
    "5-chat": {
        "id": "gpt-5-chat-latest",
        "context": 128_000,
        "pricing": {"input": 1.25, "output": 10.00},
    },
    "5r": {
        "id": "gpt-5-2025-08-07",
        "context": 400_000,
        "pricing": {"input": 1.25, "output": 10.00},
    },
    "5-mini": {
        "id": "gpt-5-mini-2025-08-07",
        "context": 400_000,
        "pricing": {"input": 0.25, "output": 2.00},
    },
    "5-nano": {
        "id": "gpt-5-nano-2025-08-07",
        "context": 400_000,
        "pricing": {"input": 0.05, "output": 0.40},
    },
}

# "o3-min" was the original (misspelled) handle for o3-mini. It is kept as an
# alias of the same entry so existing callers that use it keep working.
ALL_MODELS["o3-min"] = ALL_MODELS["o3-mini"]


class OpenAILM(LM):
    """Language model that uses the OpenAI Python library's chat completion API.

    Handles found in ``ALL_MODELS`` are translated to the concrete API model
    id; any other model name is sent to the API unchanged. Token counts
    reported by the API are accumulated in ``total_tokens`` so that
    ``get_total_cost`` can price them.
    """

    # Model-id fragments whose tokenizer is tiktoken's ``o200k_base``.
    # Everything else falls back to ``cl100k_base``.
    _O200K_SUBSTRINGS = ("gpt-4o", "gpt-4.1", "gpt-5")
    _O200K_PREFIXES = ("chatgpt-4o", "o1", "o3", "o4")

    @staticmethod
    def _resolve_meta_by_handle(model_handle: str) -> dict | None:
        """Return the ``ALL_MODELS`` entry for a short handle, or None."""
        return ALL_MODELS.get(model_handle)

    @staticmethod
    def _resolve_meta_by_id(model_id: str) -> dict | None:
        """Return the first ``ALL_MODELS`` entry whose ``"id"`` equals ``model_id``, or None."""
        for meta in ALL_MODELS.values():
            if meta.get("id") == model_id:
                return meta
        return None

    def __init__(self, model_name: str, cache: bool = True):
        """
        Args:
            model_name: A handle from ``ALL_MODELS`` (e.g. "4o") or a raw API
                model id (e.g. "gpt-4"); unknown names are used as-is.
            cache: Forwarded to the ``LM`` base class.

        Raises:
            ValueError: If the ``OPENAI_API_KEY`` environment variable is unset or empty.
        """
        # Resolve the underlying model identifier used by the API
        meta = self._resolve_meta_by_handle(model_name)
        context_length = meta.get("context") if meta else None

        # Pass along discovered context length to the base class for metadata
        super().__init__(
            model_name=model_name,
            cache=cache,
            context_length=context_length,
        )
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable not set.")

        # The client below carries the key itself; the module-level assignment
        # is kept so any code relying on ``openai.api_key`` keeps working.
        openai.api_key = api_key
        self.client = openai.OpenAI(api_key=api_key)
        self.model = model_name
        # Also store for convenience when logging or external inspection
        self.context_length = context_length
        self._resolved_model_id = meta["id"] if meta else model_name
        # Cumulative token counts over every tracked API call of this instance
        self.total_tokens = {"input": 0, "output": 0}

    # ------------------------------------------------------------------
    # Private helpers shared by the generation methods
    # ------------------------------------------------------------------

    def _resolve_api_model(self, warn_unknown: bool = False) -> tuple[dict | None, str]:
        """Return ``(metadata, api_model_id)`` for ``self.model``.

        ``metadata`` is None when ``self.model`` is not a handle in
        ``ALL_MODELS``; the name is then used verbatim as the API model id.
        """
        if warn_unknown and self.model not in ALL_MODELS:
            logger.warning(f"Model {self.model} not in ALL_MODELS, using as is")
        meta = self._resolve_meta_by_handle(self.model)
        return meta, (meta["id"] if meta else self.model)

    def _pricing_table(self) -> dict | None:
        """Return the pricing dict for the current model, or None if unknown.

        Looks the model up by handle first and then by API id, so pricing is
        also found when the instance was created with a full model id.
        """
        meta, model_id = self._resolve_api_model()
        meta = meta or self._resolve_meta_by_id(model_id)
        return meta.get("pricing") if meta else None

    @staticmethod
    def _chat_messages(prompt: str, system_message: str | None) -> list[dict]:
        """Build the chat ``messages`` list: optional system message, then the user prompt."""
        messages = []
        if system_message:
            messages.append({"role": "system", "content": system_message})
        messages.append({"role": "user", "content": prompt})
        return messages

    def _track_usage(self, usage) -> tuple[int, int]:
        """Add the token counts of an API ``usage`` object to ``total_tokens``.

        Missing or None counts are treated as 0.

        Returns:
            ``(prompt_tokens, completion_tokens)`` that were added.
        """
        if not usage:
            return 0, 0
        prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
        completion_tokens = getattr(usage, "completion_tokens", 0) or 0
        self.total_tokens["input"] += prompt_tokens
        self.total_tokens["output"] += completion_tokens
        return prompt_tokens, completion_tokens

    @staticmethod
    def _dump_model(obj) -> dict:
        """Convert an SDK response model to a plain dict.

        Prefers ``model_dump()``; ``.dict()`` is deprecated in pydantic v2 and
        is only used when ``model_dump`` is unavailable.
        """
        dump = getattr(obj, "model_dump", None)
        return dump() if callable(dump) else obj.dict()

    @classmethod
    def _top_logprobs_as_dicts(cls, choice) -> list[list[dict]]:
        """Return, per generated token, its top-logprob candidates as plain dicts.

        In v1 of the OpenAI SDK the data lives under ``choice.logprobs.content``.
        Raises whatever attribute/type error occurs if that structure is missing;
        callers decide how to degrade.
        """
        return [
            [cls._dump_model(candidate) for candidate in token.top_logprobs]
            for token in choice.logprobs.content
        ]

    def _log_running_cost(self) -> None:
        """Log the cumulative token counts and cost, if pricing is known and tokens were used."""
        pricing = self._pricing_table()
        used = self.total_tokens
        if not pricing or not (used["input"] > 0 or used["output"] > 0):
            return
        # The model name is passed as an argument (not concatenated into the
        # format string) so a "%" in the name cannot break log formatting.
        logger.info(
            "%s API Cost - Input tokens: %s, Output tokens: %s, Total cost: $%.4f",
            self.model,
            f"{used['input']:,}",
            f"{used['output']:,}",
            self.get_total_cost(),
        )

    @staticmethod
    def _completion_details(completion) -> dict:
        """Best-effort conversion of the full provider response to a dict.

        Tries ``to_dict()``, then ``model_dump()``, then ``__dict__``; returns
        an empty dict if none of them works. Failures are swallowed on purpose:
        the details are auxiliary and must not fail the generation call.
        """
        for method_name in ("to_dict", "model_dump"):
            try:
                return getattr(completion, method_name)()
            except Exception:
                continue
        try:
            return completion.__dict__
        except Exception:
            return {}

    # ------------------------------------------------------------------
    # Generation
    # ------------------------------------------------------------------

    def _generate(
        self,
        prompt: str,
        system_message: str | None,
        seed_str: str,
        logprobs: int = 0,
        **kwargs,
    ) -> str | dict:
        """
        Generate a response for a single prompt.

        Args:
            prompt: The user prompt.
            system_message: Optional system-level message.
            seed_str: Not used by this method.
            logprobs: When non-zero, request log probabilities with this many
                top candidates per token.
            **kwargs: Extra arguments forwarded to the chat completion call.

        Returns:
            The response text, or — when ``logprobs`` is non-zero — a dict
            ``{"response": <text>, "logprobs": <per-token list or None>}``.

        Raises:
            Exception: Any error raised by the API call is logged and re-raised.
        """
        # Log detailed model information
        meta, actual_model = self._resolve_api_model(warn_unknown=True)
        logger.debug("=== OpenAI Model Call ===")
        logger.debug(f"Model: {self.model} -> {actual_model}")
        logger.debug(f"Additional parameters: {kwargs}")
        logger.debug(f"System prompt: {system_message}")
        logger.debug(f"Full prompt: {prompt}")

        messages = self._chat_messages(prompt, system_message)
        logger.debug(f"Messages to send: {messages}")

        try:
            completion_kwargs = dict(kwargs)
            if logprobs:
                # The API takes a boolean switch plus the number of candidates.
                completion_kwargs.update({"logprobs": True, "top_logprobs": logprobs})

            logger.debug(
                f"Calling OpenAI API with model {actual_model} {completion_kwargs=} {messages=!r}"
            )

            completion = self.client.chat.completions.create(
                model=actual_model,
                messages=messages,
                **completion_kwargs,
            )

            # Track token usage
            if completion.usage:
                prompt_tokens, completion_tokens = self._track_usage(completion.usage)
                logger.debug(f"Token usage - Input: {prompt_tokens}, Output: {completion_tokens}")

            # Extract response and (optionally) logprobs
            choice = completion.choices[0]
            response_text = choice.message.content

            logprobs_data = None
            if logprobs:
                try:
                    logprobs_data = self._top_logprobs_as_dicts(choice)
                except Exception as e:
                    # Defensive: if structure changes or data missing, still return the text
                    logger.warning(f"Failed to parse logprobs data: {e}")

            self._log_running_cost()

            if logprobs:
                return {"response": response_text, "logprobs": logprobs_data}
            return response_text

        except Exception as e:
            logger.error(f"Error from OpenAI API for model {actual_model}: {e}")
            raise

    def _generate_with_details(
        self,
        prompt: str,
        system_message: str | None,
        seed_str: str,
        **kwargs,
    ) -> tuple[object, dict]:
        """
        Generate a response and also return the full provider response.

        Args:
            prompt: The user prompt.
            system_message: Optional system-level message.
            seed_str: Not used by this method.
            **kwargs: Forwarded unchanged to the chat completion call. If it
                contains a truthy ``logprobs``, the result is a dict.

        Returns:
            ``(result, details)`` where ``result`` is the response text (or
            ``{"response": ..., "logprobs": ...}`` when logprobs were requested)
            and ``details`` is a best-effort dict of the whole API response.
        """
        # Resolve actual model id for API call
        _, actual_model = self._resolve_api_model(warn_unknown=True)
        messages = self._chat_messages(prompt, system_message)

        completion = self.client.chat.completions.create(
            model=actual_model,
            messages=messages,
            **kwargs,
        )

        # Extract primary text (empty string when the API returned no content)
        choice = completion.choices[0]
        response_text = choice.message.content if choice.message and choice.message.content else ""

        # Track token usage (cumulative, for cost accounting). Best-effort:
        # a malformed usage object is reported but never fails the call.
        try:
            self._track_usage(completion.usage)
        except (AttributeError, TypeError) as e:
            logger.warning(f"Could not record token usage: {e}")

        # Build legacy-compatible result
        result: object
        if kwargs.get("logprobs"):
            try:
                logprobs_data = self._top_logprobs_as_dicts(choice)
            except Exception as e:
                logger.warning(f"Failed to parse logprobs data: {e}")
                logprobs_data = None
            result = {"response": response_text, "logprobs": logprobs_data}
        else:
            result = response_text

        return (result, self._completion_details(completion))

    def truncate_to_token_len(self, text: str, max_tokens: int) -> str:
        """Truncate text to at most ``max_tokens`` tokens using tiktoken.

        The encoding is chosen from the resolved model id. If no tiktoken
        encoding can be loaded, whitespace-separated words are used as an
        approximation of tokens.

        Args:
            text: The text to truncate.
            max_tokens: Maximum number of tokens to keep. Values <= 0 yield "".

        Returns:
            ``text`` unchanged if it already fits, otherwise its truncated prefix.
        """
        # Guard first: a negative bound would otherwise slice from the END of
        # the token list (tokens[:-k]) and return almost the whole text.
        if max_tokens <= 0:
            return ""

        _, resolved_model = self._resolve_api_model()

        # Heuristic mapping for OpenAI model families to encodings
        uses_o200k = any(
            fragment in resolved_model for fragment in self._O200K_SUBSTRINGS
        ) or resolved_model.startswith(self._O200K_PREFIXES)
        preferred_encoding = "o200k_base" if uses_o200k else "cl100k_base"

        # Try progressively more generic ways of obtaining an encoding.
        # Loading can fail for many reasons (unknown name, missing BPE data),
        # so every failure just moves on to the next candidate.
        loaders = (
            lambda: tiktoken.get_encoding(preferred_encoding),
            lambda: tiktoken.encoding_for_model(resolved_model),
            lambda: tiktoken.get_encoding("cl100k_base"),
        )
        encoding = None
        for load in loaders:
            try:
                encoding = load()
                break
            except Exception:
                continue

        if encoding is None:
            # Last-resort approximation: whitespace tokenization
            words = text.split()
            if len(words) <= max_tokens:
                return text
            return " ".join(words[:max_tokens])

        token_ids = encoding.encode(text)
        if len(token_ids) <= max_tokens:
            return text
        return encoding.decode(token_ids[:max_tokens])

    def get_total_cost(self) -> float:
        """Get the total accumulated cost for all tracked API calls made by this instance.

        Returns:
            The cost computed from ``total_tokens`` and the model's per-1M-token
            pricing, or 0.0 when no pricing is known for the model.
        """
        pricing = self._pricing_table()
        if not pricing:
            return 0.0

        input_cost = (self.total_tokens["input"] / 1_000_000) * pricing["input"]
        output_cost = (self.total_tokens["output"] / 1_000_000) * pricing["output"]
        return input_cost + output_cost

    def generate_with_logprobs(
        self, prompts: list[str], top_logprobs: int = 5, system_prompt: str | None = None, **kwargs
    ) -> list[dict]:
        """
        Generate responses for a list of prompts, returning both the response text and
        top log probabilities of the tokens.

        Args:
            prompts: List of prompt strings
            top_logprobs: How many of the most likely tokens to return
            system_prompt: Optional system-level prompt
            **kwargs: Extra arguments forwarded to the chat completion call

        Returns:
            A list of dictionaries, one per prompt, in order. Each dictionary contains:
                {
                    "response": <generated_text>,
                    "logprobs": <log probability data returned by the API>
                }
            If a prompt fails for any reason, its entry is
            ``{"response": "-1", "logprobs": None}`` and the error is logged.
        """
        if not prompts:
            return []

        # The model id does not depend on the prompt, so resolve it once.
        _, actual_model = self._resolve_api_model()

        results = []
        for prompt in prompts:
            messages = self._chat_messages(prompt, system_prompt)

            try:
                # We must enable logprobs and specify top_logprobs
                completion = self.client.chat.completions.create(
                    model=actual_model,
                    messages=messages,
                    logprobs=True,
                    top_logprobs=top_logprobs,
                    **kwargs,
                )
                # Count the tokens before parsing: they were consumed even if
                # the logprobs below turn out to be unparseable.
                self._track_usage(getattr(completion, "usage", None))

                choice = completion.choices[0]
                response_text = choice.message.content
                logprobs_data = self._top_logprobs_as_dicts(choice)

                results.append({"response": response_text, "logprobs": logprobs_data})

            except Exception as e:
                logger.error(f"Error from OpenAI for prompt {prompt[:100]}...: {e}")
                results.append({"response": "-1", "logprobs": None})

        return results
