import os
from typing import List

from dotenv import load_dotenv
import google.generativeai as palm
from google.api_core import exceptions
from tenacity import (
    retry,
    retry_if_not_exception_type,
    wait_random_exponential,
)

from ..message import Role, Message
from .base import BaseModel, build_prompt


# Model names accepted by the classes in this module. They are given without
# the "models/" prefix, which the model classes prepend when calling the API.
GOOGLE_MODELS = ["chat-bison-001", "text-bison-001"]

# Model name selected by default (presumably by callers that are not given an
# explicit model -- confirm against the call sites).
GOOGLE_DEFAULT = "text-bison-001"


# PALM API rejects non-English messages
class GoogleTextModel(BaseModel):
    """Interface for interacting with the PaLM text-generation API.

    Call with a list of `Message` objects to generate a response. Streaming is not yet implemented.
    """

    supports_system_message = True

    def __init__(
        self,
        model: str,
        temperature: float = 0.0,
        top_k: int = 40,
        top_p: float = 0.95,
        max_tokens: int = 512,
        **kwargs,
    ):
        """Store the sampling parameters and configure the PaLM client.

        Args:
            model: Model name without the "models/" prefix (it is prepended
                when the API is called).
            temperature: Forwarded to `palm.generate_text` as `temperature`.
            top_k: Forwarded to `palm.generate_text` as `top_k`.
            top_p: Forwarded to `palm.generate_text` as `top_p`.
            max_tokens: Forwarded to `palm.generate_text` as
                `max_output_tokens`.
            **kwargs: Accepted and ignored.
        """
        self.model = model
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.max_tokens = max_tokens

        # Load a .env file (if any) so PALM_API_KEY can be read from the
        # environment, then configure the module-global PaLM client with it.
        load_dotenv()
        palm.configure(api_key=os.getenv("PALM_API_KEY", ""))

        # Request BLOCK_NONE for every harm category listed below.
        harm = palm.types.HarmCategory
        self.safety_settings = [
            {
                "category": category,
                "threshold": palm.types.HarmBlockThreshold.BLOCK_NONE,
            }
            for category in (
                harm.HARM_CATEGORY_DANGEROUS,
                harm.HARM_CATEGORY_DEROGATORY,
                harm.HARM_CATEGORY_MEDICAL,
                harm.HARM_CATEGORY_SEXUAL,
                harm.HARM_CATEGORY_TOXICITY,
                harm.HARM_CATEGORY_UNSPECIFIED,
                harm.HARM_CATEGORY_VIOLENCE,
            )
        ]

    def __call__(self, messages: List[Message], api_key: str = None):
        """Generate a completion for `messages`.

        Args:
            messages: Conversation to turn into a single text prompt via
                `build_prompt`.
            api_key: Optional key used for this call only. Afterwards the
                client is re-configured with the PALM_API_KEY environment
                value, even if the request raises.

        Returns:
            A one-element list holding the generated text. The element is an
            empty string when `messages` is empty, when the API returns no
            result, or when the API raises `InvalidArgument`.
        """
        # Nothing to send: answer with an empty completion instead of failing
        # on an index lookup into an empty list.
        if not messages:
            return [""]

        prompt = build_prompt(messages)

        if api_key is not None:
            palm.configure(api_key=api_key)

        try:
            completion = palm.generate_text(
                model="models/" + self.model,
                prompt=prompt,
                temperature=self.temperature,
                max_output_tokens=self.max_tokens,
                top_k=self.top_k,
                top_p=self.top_p,
                safety_settings=self.safety_settings,
            )
            result = completion.result or ""
        except exceptions.InvalidArgument:
            # A request the API rejects as invalid is treated as an empty
            # completion rather than an error.
            result = ""
        finally:
            # The client configuration is module-global, so always undo the
            # per-call key override -- including when the request fails.
            if api_key is not None:
                palm.configure(api_key=os.getenv("PALM_API_KEY", ""))

        return [result]  # wrap in trivial iterator


# Chat model is currently unusable because of overly strict safety filtering
class GoogleChatModel(BaseModel):
    """Interface for interacting with the PALM chat API.

    Call with a list of `Message` objects to generate a response. Streaming is not yet implemented.
    """

    supports_system_message = False

    def __init__(
        self,
        model: str,
        temperature: float = 0.0,
        top_k: int = 40,
        top_p: float = 0.95,
        **kwargs,
    ):
        """Store the sampling parameters and configure the PaLM client.

        Args:
            model: Model name without the "models/" prefix (it is prepended
                when the API is called).
            temperature: Forwarded to `palm.chat` as `temperature`.
            top_k: Forwarded to `palm.chat` as `top_k`.
            top_p: Forwarded to `palm.chat` as `top_p`.
            **kwargs: Accepted and ignored.
        """
        self.model = model
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p

        # Load a .env file (if any) so PALM_API_KEY can be read from the
        # environment, then configure the module-global PaLM client with it.
        load_dotenv()
        palm.configure(api_key=os.getenv("PALM_API_KEY", ""))

    def __call__(self, messages: List[Message], api_key: str = None):
        """Generate a chat reply for `messages`.

        Args:
            messages: Conversation history. A leading message with role
                `Role.SYSTEM` is removed from the history; its content, when
                non-empty, is passed to the API as `context`.
            api_key: Optional key used for this call only. Afterwards the
                client is re-configured with the PALM_API_KEY environment
                value, even if the request raises.

        Returns:
            A one-element list holding the reply text. The element is an
            empty string when `messages` is empty or the API reply has no
            `last` message.
        """
        # Nothing to send: answer with an empty reply instead of failing on
        # an index lookup into an empty list.
        if not messages:
            return [""]

        # Split an optional leading system message off into the chat context.
        context = None
        if messages[0].role == Role.SYSTEM:
            if messages[0].content:
                context = messages[0].content
            messages = messages[1:]
        contents = [m.content for m in messages]

        if api_key is not None:
            palm.configure(api_key=api_key)

        try:
            reply = palm.chat(
                model="models/" + self.model,
                context=context,
                messages=contents,
                temperature=self.temperature,
                top_k=self.top_k,
                top_p=self.top_p,
            )
            result = reply.last or ""
        finally:
            # The client configuration is module-global, so always undo the
            # per-call key override -- including when the request fails.
            if api_key is not None:
                palm.configure(api_key=os.getenv("PALM_API_KEY", ""))

        return [result]  # wrap in trivial iterator


# API errors that are raised straight to the caller instead of being retried.
_NON_RETRYABLE_ERRORS = (
    exceptions.BadRequest,
    exceptions.Unauthorized,
    exceptions.Forbidden,
    exceptions.NotFound,
)


@retry(
    retry=retry_if_not_exception_type(_NON_RETRYABLE_ERRORS),
    wait=wait_random_exponential(min=1, max=10),
)
def google_call_with_retries(model, messages, api_key=None):
    """Invoke `model(messages, api_key)` and retry when the call raises.

    Exceptions of type BadRequest, Unauthorized, Forbidden or NotFound are
    not retried; any other exception triggers another attempt after a
    randomized exponential wait (`wait_random_exponential(min=1, max=10)`).

    NOTE(review): no `stop` condition is passed to `retry`, so retrying is
    presumably unbounded under tenacity's defaults -- confirm this is intended.

    Args:
        model: Callable taking `(messages, api_key)`.
        messages: Passed through to `model` as the first positional argument.
        api_key: Passed through to `model` as the second positional argument.

    Returns:
        Whatever `model` returns.
    """
    response = model(messages, api_key)
    return response
