import os
import traceback

from typing import (
    Any,
    Dict,
    Optional,
)

from dotenv import load_dotenv
from litellm import completion
from loguru import logger

from src.generator.openai_generator import BaseOpenAIGenerator


# Populate the environment from a local .env file (if any) first, so that
# values defined there are visible to the default below.
load_dotenv()

# Default LiteLLM to verbose logging, but respect a LITELLM_LOG value that was
# already provided (shell or .env) instead of unconditionally clobbering it.
# NOTE(review): litellm is imported above this line; confirm it reads
# LITELLM_LOG lazily, otherwise this setting may not take effect.
os.environ.setdefault('LITELLM_LOG', 'DEBUG')


class OllamaGenerator(BaseOpenAIGenerator):
    """Generator that talks to an Ollama server through LiteLLM.

    The server address is read from the ``OLLAMA_API_BASE`` environment
    variable; a placeholder ``"ollama"`` API key is passed to the base class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, base_url=os.getenv("OLLAMA_API_BASE"), api_key="ollama")

    def generate(self, prompt: str, json_mode: bool = False, json_schema: Optional[Dict[str, Any]] = None) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Args:
            prompt: The user prompt.
            json_mode: If True, ask Ollama to constrain the reply to JSON.
            json_schema: Optional JSON schema to constrain the reply to. When
                given it takes precedence over ``json_mode``.

        Returns:
            The model's reply text, or ``""`` if the request fails or the
            reply has no content.
        """
        # Ollama's native ``format`` field accepts either the string "json" or
        # a JSON schema object. The OpenAI-style {"type": "json_object"} is not
        # a valid value for it.
        ollama_format: Any = None
        if json_schema is not None:
            ollama_format = json_schema
        elif json_mode:
            ollama_format = "json"

        try:
            response = completion(
                model=f"ollama/{self.model_name}",
                messages=[
                    {"content": prompt, "role": "user"},
                ],
                api_base=os.getenv("OLLAMA_API_BASE"),
                temperature=self.temperature,
                max_tokens=1000,
                stream=False,
                timeout=120,
                format=ollama_format,
            )
            # ``content`` may be None; keep the declared ``str`` return type.
            return response.choices[0].message.content or ""
        except Exception:
            # Best-effort: log and fall back to empty text instead of aborting
            # the caller. Catching ``Exception`` (not a bare except) lets
            # KeyboardInterrupt / SystemExit propagate.
            logger.error(
                f"Request errored. Returning empty text for this prompt.\nPrompt: {prompt}\nTraceback: {traceback.format_exc()}"
            )
            return ""


class FireworksGenerator(BaseOpenAIGenerator):
    """OpenAI-compatible generator pointed at the Fireworks AI inference API.

    The API key is taken from the ``FIREWORKS_API_KEY`` environment variable.
    """

    def __init__(self, *args, **kwargs):
        fireworks_endpoint = "https://api.fireworks.ai/inference/v1"
        fireworks_key = os.getenv("FIREWORKS_API_KEY")
        super().__init__(
            *args,
            base_url=fireworks_endpoint,
            api_key=fireworks_key,
            **kwargs,
        )


class GeminiGenerator(BaseOpenAIGenerator):
    """Generator for Google Gemini via its OpenAI-compatible endpoint.

    The API key is taken from the ``GEMINI_API_KEY`` environment variable.
    """

    def __init__(self, *args, **kwargs):
        gemini_endpoint = "https://generativelanguage.googleapis.com/v1beta/openai/"
        gemini_key = os.getenv("GEMINI_API_KEY")
        super().__init__(*args, base_url=gemini_endpoint, api_key=gemini_key, **kwargs)
