import time
import os
from typing import Any, Dict, List

import openai
import logging

from openai import OpenAI

from llm_interface.large_language_model import LargeLanguageModel
from llm_interface.llm_response import LLMResponse

# Module logger. It uses the fixed name "global_logger" (not __name__), so it
# is the same Logger object returned by any other getLogger("global_logger").
logger = logging.getLogger(name="global_logger")

class ChatGPT(LargeLanguageModel):
    """LargeLanguageModel backed by the OpenAI Chat Completions API.

    The API key is read from the ``OPENAI_API_KEY`` environment variable when
    the instance is constructed.
    """

    # Total number of API attempts made by ``_sample_completions``.
    _MAX_ATTEMPTS = 6
    # Delay in seconds before the first retry; doubled after each failure.
    _INITIAL_BACKOFF_S = 1.0
    # Upper bound in seconds on any single retry delay.
    _MAX_BACKOFF_S = 30.0

    def __init__(self, model_name: str) -> None:
        """Create a wrapper around the chat model ``model_name``.

        Args:
            model_name: Model identifier forwarded as ``model=`` to the
                Chat Completions endpoint.
        """
        self._model_name = model_name
        # May be None if the variable is unset; the OpenAI client is then
        # constructed without an explicit key in ``_sample_completions``.
        self.OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
        # Module-level (legacy) key. The client below receives the key
        # explicitly; presumably kept for code that reads ``openai.api_key``.
        openai.api_key = self.OPENAI_API_KEY

    def get_id(self) -> str:
        """Return an identifier of the form ``chatgpt_<model_name>``."""
        return f"chatgpt_{self._model_name}"

    def _sample_completions(
            self,
            prompt: str,
            temperature: float,
            num_completions: int = 1) -> List[LLMResponse]:
        """Sample ``num_completions`` chat completions for ``prompt``.

        The whole prompt is sent verbatim as a single ``user`` message; it is
        NOT split into separate system and user messages.

        API failures (``openai.APIError`` and its subclasses, which include
        rate-limit and connection errors) are retried up to ``_MAX_ATTEMPTS``
        times in total, with exponential backoff between attempts.

        Args:
            prompt: Text sent as the user message.
            temperature: Sampling temperature forwarded to the API.
            num_completions: Number of choices requested (``n``).

        Returns:
            One ``LLMResponse`` per returned choice, in API order.

        Raises:
            RuntimeError: If every attempt failed. The last API error is
                chained as ``__cause__``.
        """
        # Built once and reused across retries.
        client = OpenAI(api_key=self.OPENAI_API_KEY)
        messages = [{"role": "user", "content": prompt}]

        response = None
        last_error = None
        for attempt in range(self._MAX_ATTEMPTS):
            try:
                response = client.chat.completions.create(
                    model=self._model_name,
                    messages=messages,
                    temperature=temperature,
                    n=num_completions)
                # Successfully queried, so stop retrying.
                break
            except (openai.RateLimitError,
                    openai.APIConnectionError, openai.APIError) as err:
                last_error = err
                # Only wait if another attempt will actually follow.
                if attempt + 1 < self._MAX_ATTEMPTS:
                    delay = min(self._INITIAL_BACKOFF_S * 2 ** attempt,
                                self._MAX_BACKOFF_S)
                    logger.warning(
                        "OpenAI API call failed (attempt %d/%d): %s; "
                        "retrying in %.1fs.",
                        attempt + 1, self._MAX_ATTEMPTS, err, delay)
                    time.sleep(delay)

        if response is None:
            raise RuntimeError("Failed to query OpenAI API.") from last_error

        assert len(response.choices) == num_completions
        return [
            self._raw_to_llm_response(r, prompt, temperature, num_completions)
            for r in response.choices
        ]

    @staticmethod
    def _raw_to_llm_response(raw_response: Any,
                             prompt: str,
                             temperature: float,
                             num_completions: int) -> LLMResponse:
        """Convert one API choice object into an ``LLMResponse``.

        Args:
            raw_response: One element of ``response.choices``. It is accessed
                via attributes (``.message.content``), not as a dict.
            prompt: The prompt that produced this choice.
            temperature: Sampling temperature used for the request.
            num_completions: Number of choices requested in the same call.

        Returns:
            An ``LLMResponse`` holding the completion text, the sampling
            parameters in ``prompt_info``, and a copy of the raw choice in
            ``other_info``.
        """
        # The SDK types ``message.content`` as optional; normalize a missing
        # value to "" so the response text is always a str.
        text = raw_response.message.content or ""
        prompt_info = {
            "temperature": temperature,
            "num_completions": num_completions
        }
        # NOTE(review): ``.copy()`` is the pydantic v1 API (deprecated in v2 in
        # favour of ``model_copy()``); kept because either may be installed.
        return LLMResponse(prompt,
                           text,
                           prompt_info=prompt_info,
                           other_info=raw_response.copy())