import os
from typing import List

from dotenv import load_dotenv
import anthropic
from tenacity import (
    retry,
    retry_if_not_exception_type,
    wait_random_exponential,
)

from ..message import Role, Message
from .base import BaseModel

# Anthropic model identifiers known to this module.
ANTHROPIC_MODELS = [
    "claude-instant-v1.0",
    "claude-instant-v1.1",
    "claude-instant-v1.2",
    "claude-v1.0",
    "claude-v1.1",
    "claude-v1.2",
    "claude-v1.3",
    "claude-2",
]
# Default model identifier; not referenced by the code visible in this file,
# presumably consumed by callers elsewhere in the package (verify before renaming).
ANTHROPIC_DEFAULT = "claude-2"

# Module-wide Anthropic client. Starts unset and is (re)created every time an
# `AnthropicModel` is constructed.
_CLIENT = None


class Response:
    """Response wrapper class for Anthropic API response object.

    Implements the iterator interface to enable simple iteration over streamed response chunks, such that
    `"".join(response)` returns the full completion.

    A single leading space is removed from the completion text if (and only if)
    one is present, so completions that do not start with a space are returned
    unmodified.

    Args:
        response: For `stream=False`, an object exposing a `completion` string
            attribute. For `stream=True`, an iterable of such objects, each
            carrying one chunk of text.
        stream: Whether `response` is a stream of chunks.

    Attributes set on the instance:
        response: The stripped completion text (non-streaming) or the original
            stream object (streaming).
        complete: True once the iterator has been exhausted.
        first_chunk: True until the first non-empty streamed chunk is seen.
    """

    def __init__(self, response, stream=False):
        self.response = response
        self.stream = stream
        self.complete = False
        self.first_chunk = True
        if self.stream:
            self.response_iter = iter(self.response)
        else:
            self.response = self._strip_leading_space(self.response.completion)

    @staticmethod
    def _strip_leading_space(text):
        """Return `text` without its first character when that character is a space."""
        return text[1:] if text.startswith(" ") else text

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next piece of completion text.

        Non-streaming responses yield the whole completion exactly once.
        Streaming responses yield one chunk's text per call.

        Raises:
            StopIteration: When there is no more text to yield.
        """
        if self.complete:
            raise StopIteration

        if not self.stream:
            self.complete = True
            return self.response

        try:
            chunk = next(self.response_iter)
        except StopIteration:
            # Remember exhaustion so later calls never touch the stream again.
            self.complete = True
            raise

        delta = chunk.completion
        # Only the first chunk that actually carries text can hold the leading
        # space; empty chunks before it must not consume the flag.
        if self.first_chunk and delta:
            self.first_chunk = False
            delta = self._strip_leading_space(delta)
        return delta


class AnthropicModel(BaseModel):
    """Interface for interacting with the Anthropic API.

    Call with a list of `Message` objects to generate a response.

    Constructing an instance loads environment variables via `load_dotenv()`
    and (re)creates the module-wide client from `ANTHROPIC_API_KEY`.
    """

    supports_system_message = False

    def __init__(
        self,
        model: str,
        stream: bool = False,
        temperature: float = 0.0,
        top_k: int = -1,
        top_p: float = -1,
        max_tokens: int = 512,
        **kwargs,
    ):
        """Store sampling settings and initialise the shared API client.

        Args:
            model: Model identifier sent with every request.
            stream: Whether to request a streamed response.
            temperature: Sampling temperature forwarded to the API.
            top_k: Forwarded unchanged to the API.
                NOTE(review): -1 presumably means "disabled" -- confirm against the API docs.
            top_p: Forwarded unchanged to the API (same note as `top_k`).
            max_tokens: Sent as `max_tokens_to_sample`.
            **kwargs: Accepted and ignored.
        """
        self.model = model
        self.stream = stream
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.max_tokens = max_tokens
        load_dotenv()
        global _CLIENT
        _CLIENT = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY", ""))

    def _make_anthropic_prompt(self, messages: List[Message]):
        """Render `messages` as a single prompt string.

        User and assistant messages are prefixed with `anthropic.HUMAN_PROMPT`
        and `anthropic.AI_PROMPT` respectively; messages with any other role
        are skipped. A trailing `anthropic.AI_PROMPT` is always appended so the
        model continues as the assistant.
        """
        texts = []
        for m in messages:
            if m.role == Role.USER:
                texts.append(f"{anthropic.HUMAN_PROMPT} {m.content}")
            elif m.role == Role.ASSISTANT:
                texts.append(f"{anthropic.AI_PROMPT} {m.content}")
        return "".join(texts) + anthropic.AI_PROMPT

    def __call__(self, messages: List[Message], api_key: str = None):
        """Send `messages` to the completions endpoint and wrap the result.

        Args:
            messages: Conversation to render into the prompt.
            api_key: Optional key used for this call only. When omitted, the
                module-wide client created in `__init__` is used.

        Returns:
            A `Response` wrapping the API result (streamed or not, per `self.stream`).
        """
        # A per-call key gets its own short-lived client. The shared client is
        # never swapped out, so a failing request cannot leave the override key
        # installed for later calls, and concurrent callers do not see each
        # other's keys.
        if api_key is not None:
            client = anthropic.Anthropic(api_key=api_key)
        else:
            client = _CLIENT

        request = {
            "prompt": self._make_anthropic_prompt(messages),
            "model": self.model,
            "max_tokens_to_sample": self.max_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "stream": self.stream,
        }

        response = client.completions.create(**request)
        return Response(response, stream=self.stream)


@retry(
    wait=wait_random_exponential(min=1, max=10),
    retry=retry_if_not_exception_type(
        (
            anthropic.BadRequestError,
            anthropic.AuthenticationError,
            anthropic.PermissionDeniedError,
            anthropic.NotFoundError,
        )
    ),
)
def anthropic_call_with_retries(model, messages, api_key=None):
    """Invoke `model(messages, api_key)`, retrying on failure.

    Any exception other than the Anthropic bad-request, authentication,
    permission-denied and not-found errors triggers another attempt after a
    randomized exponential wait (configured with min=1, max=10). No `stop`
    condition is passed to the decorator, so the number of attempts is not
    capped here.

    Args:
        model: Callable accepting `(messages, api_key)`.
        messages: Conversation passed through to `model`.
        api_key: Optional API key passed through to `model`.

    Returns:
        Whatever `model` returns.
    """
    result = model(messages, api_key)
    return result
