import base64
import openai
import anthropic
from anthropic import Anthropic, HUMAN_PROMPT, AI_PROMPT
from tqdm import tqdm
from typing import List
import google.generativeai as genai
import time
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

import re

def api_models_map(model_name_or_path=None, token=None, **kwargs):
    """Build the API wrapper that matches a hosted model name.

    The wrapper class is selected from the text of ``model_name_or_path``:

    * contains ``gpt-``: ``GPTV`` when the name also contains ``vision``,
      otherwise ``GPT``;
    * contains ``claude-``: ``Claude``;
    * contains ``gemini-``: ``Gemini``;
    * matches ``mistral-(tiny|small|medium|large)`` exactly: ``Mistral``.

    Args:
        model_name_or_path: Model identifier to dispatch on.
        token: API key handed to the selected wrapper's constructor.
        **kwargs: Accepted for call-site compatibility; not used here.

    Returns:
        A wrapper instance, or ``None`` when the name is missing/empty or
        does not belong to any of the known API model families.
    """
    # The parameter defaults to None, and a substring test against None
    # raises TypeError. Treat a missing name like any other unknown model.
    if not model_name_or_path:
        return None

    if 'gpt-' in model_name_or_path:
        if 'vision' in model_name_or_path:
            return GPTV(model_name_or_path, token)
        return GPT(model_name_or_path, token)
    if 'claude-' in model_name_or_path:
        return Claude(model_name_or_path, token)
    if 'gemini-' in model_name_or_path:
        return Gemini(model_name_or_path, token)
    if re.match(r'mistral-(tiny|small|medium|large)$', model_name_or_path):
        return Mistral(model_name_or_path, token)
    return None

class GPT():
    """Text-only chat wrapper around ``openai.ChatCompletion``."""

    # Seconds to sleep after a failed API call before trying again.
    API_RETRY_SLEEP = 10
    # Sentinel returned when no attempt produced a completion.
    API_ERROR_OUTPUT = "$ERROR$"
    # Extra pause (seconds) taken after a failed attempt, before the next one.
    API_QUERY_SLEEP = 0.5
    # Maximum number of attempts per prompt.
    API_MAX_RETRY = 5
    # Passed to the API call as ``request_timeout``.
    API_TIMEOUT = 60

    def __init__(self, model_name, api_key):
        # The key is stored on the ``openai`` module itself, not on this object.
        openai.api_key = api_key
        self.model_name = model_name

    def _generate(self, prompt: str,
                max_new_tokens: int,
                temperature: float,
                top_p: float):
        """Request one completion for ``prompt``, retrying on OpenAI errors.

        Returns the text of the first choice, or ``API_ERROR_OUTPUT`` when
        every attempt raised ``openai.error.OpenAIError``.
        """
        request = dict(
            model=self.model_name,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            request_timeout=self.API_TIMEOUT,
        )
        for _attempt in range(self.API_MAX_RETRY):
            try:
                reply = openai.ChatCompletion.create(**request)
                # Success: hand the text back immediately (no trailing sleep).
                return reply["choices"][0]["message"]["content"]
            except openai.error.OpenAIError as err:
                print(type(err), err)
                time.sleep(self.API_RETRY_SLEEP)

            time.sleep(self.API_QUERY_SLEEP)
        return self.API_ERROR_OUTPUT

    def generate(self,
                prompts: List[str],
                max_new_tokens: int,
                temperature: float,
                top_p: float = 1.0,
                use_tqdm: bool=False,
                **kwargs):
        """Run ``_generate`` on each prompt in order and collect the outputs.

        ``use_tqdm`` wraps the prompts in a progress bar; extra keyword
        arguments are accepted and ignored.
        """
        iterable = tqdm(prompts) if use_tqdm else prompts
        outputs = []
        for prompt in iterable:
            outputs.append(self._generate(prompt, max_new_tokens, temperature, top_p))
        return outputs

class GPTV():
    """Chat wrapper around ``openai.ChatCompletion`` for text + image prompts."""

    # Seconds to sleep after a failed API call before trying again.
    API_RETRY_SLEEP = 10
    # Sentinel returned when no attempt produced a completion.
    API_ERROR_OUTPUT = "$ERROR$"
    # Extra pause (seconds) taken after a failed attempt, before the next one.
    API_QUERY_SLEEP = 0.5
    # Maximum number of attempts per prompt.
    API_MAX_RETRY = 5
    # Passed to the API call as ``request_timeout``.
    API_TIMEOUT = 20

    def __init__(self, model_name, api_key):
        self.model_name = model_name
        # The key is stored on the ``openai`` module itself, not on this object.
        openai.api_key = api_key

    def _generate(self, prompt: str,
                image_path: str,
                max_new_tokens: int,
                temperature: float,
                top_p: float):
        """Request one completion for ``prompt`` plus the image at ``image_path``.

        The image file is read, base64-encoded and sent inline as a
        ``data:image/jpeg`` URL (the JPEG media type is used regardless of
        the file's extension).

        Returns:
            The text of the first choice; a fixed refusal sentence when the
            API rejects the image through its safety system; or
            ``API_ERROR_OUTPUT`` when every attempt failed.

        Raises:
            OSError: If ``image_path`` cannot be opened or read.
        """
        output = self.API_ERROR_OUTPUT

        with open(image_path, "rb") as image_file:
            image_s = base64.b64encode(image_file.read()).decode('utf-8')
        image_url = {"url": f"data:image/jpeg;base64,{image_s}"}

        for _ in range(self.API_MAX_RETRY):
            try:
                response = openai.ChatCompletion.create(
                            model = self.model_name,
                            messages=[
                                        {"role": "user", "content":
                                            [
                                                {"type": "text", "text": prompt},
                                                {"type": "image_url","image_url": image_url}
                                            ]
                                        }
                                    ],
                            max_tokens = max_new_tokens,
                            temperature = temperature,
                            top_p = top_p,
                            request_timeout = self.API_TIMEOUT,
                            )
                output = response["choices"][0]["message"]["content"]
                break
            except openai.error.InvalidRequestError as e:
                # A safety-system rejection is reported as a refusal, not an error.
                if "Your input image may contain content that is not allowed by our safety system" in str(e):
                    output = "I'm sorry, I can't assist with that request."
                    break
                # Any other invalid request used to be swallowed silently;
                # surface it so a "$ERROR$" result can be diagnosed.
                print(type(e), e)
            except openai.error.OpenAIError as e:
                print(type(e), e)
                time.sleep(self.API_RETRY_SLEEP)

            time.sleep(self.API_QUERY_SLEEP)
        return output

    def generate(self,
                prompts: List[str],
                images: List[str],
                max_new_tokens: int,
                temperature: float,
                top_p: float = 1.0,
                use_tqdm: bool=False,
                **kwargs):
        """Run ``_generate`` on each (prompt, image path) pair, in order.

        Prompts and images are paired with ``zip``, so the result has as many
        entries as the shorter of the two lists. ``use_tqdm`` wraps the
        prompts in a progress bar; extra keyword arguments are ignored.
        """
        if use_tqdm:
            prompts = tqdm(prompts)

        return [self._generate(prompt, img, max_new_tokens, temperature, top_p) for prompt, img in zip(prompts, images)]

class Claude():
    """Wrapper around the Anthropic text-completions endpoint."""

    # Seconds to sleep after a failed API call before trying again.
    API_RETRY_SLEEP = 10
    # Sentinel returned when no attempt produced a completion.
    API_ERROR_OUTPUT = "$ERROR$"
    # Extra pause (seconds) taken after a failed attempt, before the next one.
    API_QUERY_SLEEP = 1
    # Maximum number of attempts per prompt.
    API_MAX_RETRY = 5
    # Not referenced anywhere in this class.
    API_TIMEOUT = 20
    # Returned in place of a completion when the API blocks the output.
    default_output = "I'm sorry, but I cannot assist with that request."

    def __init__(self, model_name, api_key) -> None:
        self.model_name = model_name
        self.API_KEY = api_key

        self.model= Anthropic(
            api_key=self.API_KEY,
            )

    def _generate(self, prompt: str,
                max_new_tokens: int,
                temperature: float,
                top_p: float):
        """Request one completion for ``prompt``, retrying on API errors.

        The prompt is wrapped as ``HUMAN_PROMPT + " " + prompt + AI_PROMPT``.

        Returns:
            The completion text; ``default_output`` when the API reports the
            output was blocked by its content filtering policy; or
            ``API_ERROR_OUTPUT`` when every attempt failed.
        """
        output = self.API_ERROR_OUTPUT
        for _ in range(self.API_MAX_RETRY):
            try:
                completion = self.model.completions.create(
                    model=self.model_name,
                    max_tokens_to_sample=max_new_tokens,
                    prompt=f"{HUMAN_PROMPT} {prompt}{AI_PROMPT}",
                    temperature=temperature,
                    top_p=top_p
                )
                output = completion.completion
                break
            except anthropic.BadRequestError as e:
                # As of Jan 2023, this message shows that the output has been blocked.
                if "Output blocked by content filtering policy" in str(e):
                    output = self.default_output
                    break
                # Any other bad request used to be swallowed silently;
                # surface it so a "$ERROR$" result can be diagnosed.
                print(type(e), e)
            except anthropic.APIError as e:
                print(type(e), e)
                time.sleep(self.API_RETRY_SLEEP)
            time.sleep(self.API_QUERY_SLEEP)
        return output

    def generate(self,
                prompts: List[str],
                max_new_tokens: int,
                temperature: float,
                top_p: float = 1.0,
                use_tqdm: bool=False,
                **kwargs):
        """Run ``_generate`` on each prompt in order and return the outputs.

        ``use_tqdm`` wraps the prompts in a progress bar; extra keyword
        arguments are accepted and ignored.
        """
        if use_tqdm:
            prompts = tqdm(prompts)
        return [self._generate(prompt, max_new_tokens, temperature, top_p) for prompt in prompts]
        
class Gemini():
    """Wrapper around a ``google.generativeai`` ``GenerativeModel`` chat."""

    # Seconds to sleep after a failed API call before trying again.
    API_RETRY_SLEEP = 10
    # Sentinel returned when no attempt produced a completion.
    API_ERROR_OUTPUT = "$ERROR$"
    # Extra pause (seconds) taken after a failed attempt, before the next one.
    API_QUERY_SLEEP = 1
    # Maximum number of attempts per prompt.
    API_MAX_RETRY = 5
    # Not referenced anywhere in this class.
    API_TIMEOUT = 20
    # Returned in place of a completion when the prompt or output is blocked.
    default_output = "I'm sorry, but I cannot assist with that request."

    def __init__(self, model_name, token) -> None:
        self.model_name = model_name
        # Configures the ``genai`` module itself with the API key.
        genai.configure(api_key=token)

        self.model = genai.GenerativeModel(model_name)

    def _generate(self, prompt: str,
                max_n_tokens: int,
                temperature: float,
                top_p: float):
        """Send ``prompt`` as the first message of a new chat, with retries.

        Returns:
            The response text; ``default_output`` when the prompt or the
            candidate is blocked (or the API reports empty content parts);
            or ``API_ERROR_OUTPUT`` when every attempt failed.
        """
        output = self.API_ERROR_OUTPUT
        generation_config=genai.types.GenerationConfig(
                max_output_tokens=max_n_tokens,
                temperature=temperature,
                top_p=top_p)
        chat = self.model.start_chat(history=[])

        for _ in range(self.API_MAX_RETRY):
            try:
                completion = chat.send_message(prompt, generation_config=generation_config)
                output = completion.text
                break
            except (genai.types.BlockedPromptException, genai.types.StopCandidateException):
                # Prompt was blocked for safety reasons
                output = self.default_output
                break
            except Exception as e:
                print(type(e), e)
                # As of Jan 2023, this message shows that the output has been filtered by the API.
                if "contents.parts must not be empty." in str(e):
                    output = self.default_output
                    break
                time.sleep(self.API_RETRY_SLEEP)
            # Use the class constant rather than a hard-coded 1 so the pause
            # can be tuned the same way as in the sibling wrappers.
            time.sleep(self.API_QUERY_SLEEP)
        return output

    def generate(self,
                prompts: List[str],
                max_new_tokens: int,
                temperature: float,
                top_p: float = 1.0,
                use_tqdm: bool=False,
                **kwargs):
        """Run ``_generate`` on each prompt in order and return the outputs.

        ``use_tqdm`` wraps the prompts in a progress bar; extra keyword
        arguments are accepted and ignored.
        """
        if use_tqdm:
            prompts = tqdm(prompts)
        return [self._generate(prompt, max_new_tokens, temperature, top_p) for prompt in prompts]
    

class Mistral():
    """Wrapper around ``MistralClient.chat`` for single-turn prompts."""

    # Seconds to sleep after a failed API call before trying again.
    API_RETRY_SLEEP = 10
    # Sentinel returned when no attempt produced a completion.
    API_ERROR_OUTPUT = "$ERROR$"
    # Extra pause (seconds) taken after a failed attempt, before the next one.
    API_QUERY_SLEEP = 0.5
    # Maximum number of attempts per prompt.
    API_MAX_RETRY = 5

    def __init__(self, model_name, token):
        self.model_name = model_name
        self.client = MistralClient(api_key=token)

    def _generate(self, prompt: str,
                max_new_tokens: int,
                temperature: float,
                top_p: float):
        """Request one completion for ``prompt``, retrying on any exception.

        Returns the content of the first choice, or ``API_ERROR_OUTPUT``
        when every attempt raised.
        """
        output = self.API_ERROR_OUTPUT
        messages = [
            ChatMessage(role="user", content=prompt)
        ]
        for _ in range(self.API_MAX_RETRY):
            try:
                chat_response = self.client.chat(
                    # The constructor stores the name as ``model_name``; reading
                    # a non-existent ``self.model`` made every call fail.
                    model=self.model_name,
                    temperature=temperature,
                    max_tokens=max_new_tokens,
                    # Forward top_p so the sampling setting the caller asked
                    # for is honoured, as in the other API wrappers.
                    top_p=top_p,
                    messages=messages,
                )
                output = chat_response.choices[0].message.content
                break
            except Exception as e:
                print(type(e), e)
                time.sleep(self.API_RETRY_SLEEP)
            time.sleep(self.API_QUERY_SLEEP)
        return output

    def generate(self,
                prompts: List[str],
                max_new_tokens: int,
                temperature: float,
                top_p: float = 1.0,
                use_tqdm: bool=False,
                **kwargs):
        """Run ``_generate`` on each prompt in order and return the outputs.

        ``use_tqdm`` wraps the prompts in a progress bar; extra keyword
        arguments are accepted and ignored.
        """
        if use_tqdm:
            prompts = tqdm(prompts)
        return [self._generate(prompt, max_new_tokens, temperature, top_p) for prompt in prompts]