import os
import traceback

from typing import (
    Any,
    Dict,
    Generator,
    Optional,
)

from dotenv import load_dotenv
from loguru import logger
from openai import (
    APIConnectionError,
    InternalServerError,
    OpenAI,
    RateLimitError,
)
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_random_exponential,
)

from src.generator.base_generator import BaseGenerator
from src.utils.log import log_exception_in_retry


# Import-time side effect: load variables from a .env file (found by
# python-dotenv's default search) into os.environ so that OPENAI_API_KEY,
# read in OpenAIGenerator.__init__, can be supplied via .env. Variables already
# set in the environment are not overridden (python-dotenv's default).
load_dotenv()


class OpenAIGeneratorBase(BaseGenerator):
    """Text generator backed by an OpenAI-compatible chat-completions endpoint.

    Each prompt is sent as a single user message. Transient API failures are
    retried with randomized exponential backoff; a request that still fails is
    logged and reported to the caller as an empty string instead of raising.
    """

    def __init__(self, model_name: str, temperature: float, api_key: str, base_url: Optional[str] = None) -> None:
        """Create the API client and store the per-request sampling settings.

        Args:
            model_name: Model identifier sent as ``model`` on every request.
            temperature: Sampling temperature sent on every request.
            api_key: API key handed to the ``OpenAI`` client.
            base_url: Optional endpoint override (e.g. an OpenAI-compatible
                server); ``None`` lets the client use its default endpoint.
        """
        self.client = OpenAI(base_url=base_url, api_key=api_key)
        self.model_name = model_name
        self.temperature = temperature

    @retry(
        retry=retry_if_exception_type((RateLimitError, APIConnectionError, InternalServerError)),
        wait=wait_random_exponential(min=1, max=120),
        stop=stop_after_attempt(3),
        before_sleep=log_exception_in_retry,
    )
    def _chat_completion_with_backoff(self, **kwargs):
        """Call ``chat.completions.create(**kwargs)``, retrying transient errors.

        Rate-limit, connection and 5xx server errors are retried for at most
        3 attempts in total, sleeping a randomized, exponentially growing delay
        (bounded to 1-120 s) between attempts. Once the attempts are exhausted
        tenacity raises ``RetryError``; any other exception propagates at once.
        """
        return self.client.chat.completions.create(**kwargs)

    def generate(self, prompt: str, json_mode: bool = False, json_schema: Optional[Dict[str, Any]] = None) -> str:
        """Return the model's reply to ``prompt``, or ``""`` on failure.

        Args:
            prompt: Text sent as the single user message.
            json_mode: If True, request ``{"type": "json_object"}`` output.
            json_schema: If given, passed verbatim as ``response_format`` and
                takes precedence over ``json_mode``. Presumably callers pass the
                full ``{"type": "json_schema", ...}`` payload - verify.

        Returns:
            The content of the first choice, or ``""`` when the request fails
            after retries or the reply carries no text content.
        """
        messages = [{"role": "user", "content": prompt}]
        kwargs: Dict[str, Any] = {
            "model": self.model_name,
            "messages": messages,
            "temperature": self.temperature,
            "stream": False,
            "timeout": 120,
        }
        # json_schema is applied last so it wins when both options are set.
        if json_mode:
            kwargs["response_format"] = {"type": "json_object"}
        if json_schema is not None:
            kwargs["response_format"] = json_schema
        try:
            response = self._chat_completion_with_backoff(**kwargs)
            content = response.choices[0].message.content
        except Exception:
            # Best-effort: log and return "" instead of raising. ``Exception``
            # (not a bare ``except:``) so Ctrl-C / SystemExit still propagate.
            logger.error(
                f"Request errored even after backoff. Returning empty text for this prompt.\nPrompt: {prompt}\nTraceback: {traceback.format_exc()}"
            )
            return ""
        # The SDK types ``message.content`` as Optional (e.g. refusals or tool
        # calls); normalize to "" to honour the ``-> str`` contract.
        return content if content is not None else ""

    def generate_async(self, prompt: str, json_mode: bool = False) -> Generator[str, None, None]:
        """Not supported by this generator; always raises ``NotImplementedError``."""
        raise NotImplementedError()


class OpenAIGenerator(OpenAIGeneratorBase):
    """``OpenAIGeneratorBase`` whose API key defaults to ``$OPENAI_API_KEY``.

    Accepts the same arguments as the base class. ``api_key`` may be omitted,
    in which case it is read from the ``OPENAI_API_KEY`` environment variable
    (``None`` if unset).
    """

    def __init__(self, *args, **kwargs):
        # The key used to be injected unconditionally as a keyword, so passing
        # ``api_key`` explicitly (by keyword, or positionally as the 3rd
        # argument) raised "got multiple values for keyword argument".
        # Fall back to the environment only when no key was supplied.
        if "api_key" not in kwargs and len(args) < 3:
            kwargs["api_key"] = os.getenv("OPENAI_API_KEY")
        super().__init__(*args, **kwargs)
