import os
import traceback

from typing import (
    Any,
    Dict,
    Optional,
)

from dotenv import load_dotenv
from litellm import completion
from loguru import logger

from src.generator.openai_generator import OpenAIGeneratorBase
from src.schema import ReasoningGraph


# Populate os.environ from a .env file first (existing variables are not
# overridden), so that a LITELLM_LOG set there or in the shell is respected.
load_dotenv()

# Default LiteLLM to DEBUG logging only when the caller has not chosen a level.
# NOTE(review): litellm is imported at the top of this module, before this line
# runs; confirm litellm reads LITELLM_LOG lazily, otherwise this has no effect.
os.environ.setdefault("LITELLM_LOG", "DEBUG")


class OllamaGenerator(OpenAIGeneratorBase):
    """Text generator backed by an Ollama server, called through LiteLLM.

    The server address is taken from the ``OLLAMA_API_BASE`` environment
    variable, both for the base class and for each ``completion`` call.
    """

    def __init__(self, *args, **kwargs):
        # A fixed placeholder key is passed to the OpenAI-style base class,
        # presumably because the Ollama server does not require one.
        super().__init__(*args, **kwargs, base_url=os.getenv("OLLAMA_API_BASE"), api_key="ollama")

    def generate(self, prompt: str, json_mode: bool = False, json_schema: Optional[Dict[str, Any]] = None) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Args:
            prompt: Content of the user message.
            json_mode: If True, request JSON output via Ollama's
                ``format="json"``.
            json_schema: If not None, request schema-constrained output; takes
                precedence over ``json_mode``. NOTE: the value passed here is
                currently ignored and the ``ReasoningGraph`` schema is sent
                instead (see the TODO in the body).

        Returns:
            The model's reply text, or ``""`` if the request raised or the
            reply had no content. Failures are logged, never raised.
        """
        response_format: Any = None
        if json_mode:
            # Ollama's native ``format`` field takes the string "json" for JSON
            # mode; a dict is interpreted as a JSON schema, so the OpenAI-style
            # {"type": "json_object"} would be sent as a bogus schema.
            response_format = "json"
        if json_schema is not None:
            # TODO: honor the caller's ``json_schema``. Ollama's ``format``
            # expects a bare JSON schema, and the shape callers pass here has
            # not been reconciled with that, so ReasoningGraph's schema is
            # hard-coded for now.
            response_format = ReasoningGraph.model_json_schema()

        try:
            response = completion(
                model=f"ollama/{self.model_name}",
                messages=[{"content": prompt, "role": "user"}],
                api_base=os.getenv("OLLAMA_API_BASE"),
                temperature=self.temperature,
                stream=False,
                timeout=120,
                format=response_format,
            )
            content = response.choices[0].message.content
        except Exception:
            # Best-effort: a failed request yields empty text instead of
            # aborting the caller. No retry/backoff wraps this call.
            logger.error(
                f"Request errored. Returning empty text for this prompt.\nPrompt: {prompt}\nTraceback: {traceback.format_exc()}"
            )
            return ""

        if content is None:
            # Keep the ``-> str`` contract even when the model returns no text.
            logger.error(
                f"Response content is None. Returning empty text for this prompt.\nPrompt: {prompt}"
            )
            return ""
        return content


class FireworksGenerator(OpenAIGeneratorBase):
    """Text generator for Fireworks AI's OpenAI-compatible inference endpoint.

    The API key is read from the ``FIREWORKS_API_KEY`` environment variable.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(
            *args, **kwargs, base_url="https://api.fireworks.ai/inference/v1", api_key=os.getenv("FIREWORKS_API_KEY")
        )

    def generate(self, prompt: str, json_mode: bool = False, json_schema: Optional[Dict[str, Any]] = None) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Args:
            prompt: Content of the user message.
            json_mode: If True, request ``{"type": "json_object"}`` output.
            json_schema: If not None, forwarded verbatim as ``response_format``;
                takes precedence over ``json_mode``.

        Returns:
            The model's reply text, or ``""`` if the request raised or the
            reply had no content. Failures are logged, never raised.
        """
        kwargs = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": self.temperature,
            "stream": False,
            "timeout": 120,
        }
        if json_mode:
            kwargs["response_format"] = {"type": "json_object"}
        if json_schema is not None:
            kwargs["response_format"] = json_schema

        try:
            # Helper from the base class; presumably retries with backoff
            # before giving up (not visible here — confirm in the base class).
            response = self._chat_completion_with_backoff(**kwargs)
            response_content = response.choices[0].message.content
        except Exception:
            logger.error(
                f"Request errored even after backoff. Returning empty text for this prompt.\nPrompt: {prompt}\nTraceback: {traceback.format_exc()}"
            )
            return ""

        if response_content is None:
            # Keep the ``-> str`` contract even when the model returns no text.
            logger.error(
                f"Response content is None. Returning empty text for this prompt.\nPrompt: {prompt}"
            )
            return ""
        return response_content


class GeminiGenerator(OpenAIGeneratorBase):
    """Text generator for Google Gemini via its OpenAI-compatible endpoint.

    The API key is read from the ``GEMINI_API_KEY`` environment variable.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(
            *args,
            **kwargs,
            base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
            api_key=os.getenv("GEMINI_API_KEY")
        )

    def generate(self, prompt: str, json_mode: bool = False, json_schema: Optional[Dict[str, Any]] = None) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Args:
            prompt: Content of the user message.
            json_mode: If True, request ``{"type": "json_object"}`` output.
            json_schema: If not None, forwarded verbatim as ``response_format``;
                takes precedence over ``json_mode``.

        Returns:
            The model's reply text, or ``""`` if the request raised or the
            reply had no content. Failures are logged, never raised.
        """
        kwargs = {
            "model": self.model_name,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": self.temperature,
            "stream": False,
            "timeout": 120,
        }
        if json_mode:
            kwargs["response_format"] = {"type": "json_object"}
        if json_schema is not None:
            kwargs["response_format"] = json_schema

        try:
            # Helper from the base class; presumably retries with backoff
            # before giving up (not visible here — confirm in the base class).
            response = self._chat_completion_with_backoff(**kwargs)
            response_content = response.choices[0].message.content
        except Exception:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
            # still propagate.
            logger.error(
                f"Request errored even after backoff. Returning empty text for this prompt.\nPrompt: {prompt}\nTraceback: {traceback.format_exc()}"
            )
            return ""

        if response_content is None:
            # Keep the ``-> str`` contract even when the model returns no text.
            logger.error(
                f"Response content is None. Returning empty text for this prompt.\nPrompt: {prompt}"
            )
            return ""
        return response_content