import openai
import anthropic
import google.generativeai as genai
import PIL.Image
from constants import OPENAI_API_KEY, ANTHROPIC_API_KEY, GOOGLE_API_KEY, MAX_OUTPUT_TOKENS
import base64
import os
import io
from anthropic import AnthropicVertex

# Function to encode the image
def encode_image(image_path):
    """Return the raw bytes of the file at *image_path* as a base64 string.

    The file is read in binary mode and is not decoded or re-encoded as an
    image; the result is the base64 text of the bytes exactly as stored.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')

# New function to resize image
def resize_image(image_path, max_width=500):
    """Return the image at *image_path* as a base64-encoded PNG string.

    Images wider than *max_width* are scaled down to exactly *max_width*
    pixels wide with the height scaled by the same ratio; narrower images
    keep their size. The output is always re-encoded as PNG, whatever the
    format of the input file.

    Args:
        image_path: Path of the image file to load.
        max_width: Maximum width, in pixels, of the returned image.

    Returns:
        The PNG bytes encoded as a base64 ``str``.
    """
    with PIL.Image.open(image_path) as img:
        original_width, original_height = img.size
        if original_width > max_width:
            ratio = max_width / original_width
            # Truncation can produce 0 for very wide, very short images,
            # which is not a valid resize target; keep at least one pixel row.
            new_height = max(1, int(original_height * ratio))
            img = img.resize((max_width, new_height), PIL.Image.LANCZOS)

        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')


class BaseLLM:
    """Common state shared by the provider-specific LLM wrappers.

    Attributes set here:
        model_name: Model identifier passed to the constructor.
        api_key: Empty string; subclasses overwrite it with a real key.
        force_json: Flag stored as given; nothing in this class reads it.
        prompt_tokens / completion_tokens / total_tokens: Counters that
            start at 0.
        token_usage: List that starts empty.
    """

    def __init__(self, model_name, force_json=False):
        self.model_name = model_name
        self.force_json = force_json
        self.api_key = ''
        # All three counters start from zero.
        self.prompt_tokens = self.completion_tokens = self.total_tokens = 0
        self.token_usage = []

    def generate(self, prompt, image_path=None):
        """Placeholder that subclasses override; always raises here.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Subclasses must implement this method")

class GPT4OLLM(BaseLLM):
    """Wrapper around the OpenAI chat-completions client."""

    def __init__(self, model_name):
        """Create an OpenAI client authenticated with ``OPENAI_API_KEY``."""
        super().__init__(model_name)
        self.api_key = OPENAI_API_KEY
        self.client = openai.OpenAI(api_key=self.api_key)

    def generate(self, messages):
        """Send *messages* to the model and return its reply with token counts.

        Args:
            messages: Message list passed unchanged as the ``messages``
                argument of ``chat.completions.create``.

        Returns:
            ``(content, token_usage)`` where ``content`` is the message
            content of the first choice and ``token_usage`` is
            ``[prompt_tokens, completion_tokens, total_tokens]`` as reported
            in the response's ``usage``. The same numbers are also stored on
            the instance.
        """
        completion = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
        )
        # Usage fields: https://platform.openai.com/docs/api-reference/chat/create?lang=python
        usage = completion.usage
        self.prompt_tokens = usage.prompt_tokens
        self.completion_tokens = usage.completion_tokens
        self.total_tokens = usage.total_tokens
        self.token_usage = [
            self.prompt_tokens,
            self.completion_tokens,
            self.total_tokens,
        ]
        reply = completion.choices[0].message.content
        return reply, self.token_usage

class ClaudeLLM(BaseLLM):
    """Wrapper around the Anthropic messages client."""

    def __init__(self, model_name):
        """Create an Anthropic client authenticated with ``ANTHROPIC_API_KEY``."""
        super().__init__(model_name)
        self.api_key = ANTHROPIC_API_KEY
        self.client = anthropic.Anthropic(api_key=self.api_key)

    def generate(self, messages):
        """Send *messages* to the model and return its reply with token counts.

        Args:
            messages: Message list passed unchanged as the ``messages``
                argument of ``messages.create``.

        Returns:
            ``(text, token_usage)`` where ``text`` is the ``text`` of the
            first content block and ``token_usage`` is
            ``[input_tokens, output_tokens, input_tokens + output_tokens]``.
            The same numbers are also stored on the instance.
        """
        reply = self.client.messages.create(
            model=self.model_name,
            max_tokens=MAX_OUTPUT_TOKENS,
            messages=messages,
        )
        # Usage fields: https://docs.anthropic.com/en/api/messages
        self.prompt_tokens = reply.usage.input_tokens
        self.completion_tokens = reply.usage.output_tokens
        # The total is computed here by summing input and output counts.
        self.total_tokens = reply.usage.input_tokens + reply.usage.output_tokens
        self.token_usage = [
            self.prompt_tokens,
            self.completion_tokens,
            self.total_tokens,
        ]
        text = reply.content[0].text
        return text, self.token_usage

# Disabled alternative: ClaudeLLM served through Vertex AI (AnthropicVertex).
# class ClaudeLLM(BaseLLM):
#     def __init__(self, model_name):
#         super().__init__(model_name)
#         self.location = "us-central1"
#         self.client = AnthropicVertex(region=self.location, project_id="vivid-kite-432310-b2")

#     def generate(self, messages):

#         response = self.client.messages.create(
#             model="claude-3-5-sonnet@20240620",
#             max_tokens=MAX_OUTPUT_TOKENS,
#             messages=messages,
#         )
#         # self.token_usage = response.usage.input_tokens + response.usage.output_tokens  # https://docs.anthropic.com/en/api/messages
#         self.prompt_tokens = response.usage.input_tokens
#         self.completion_tokens = response.usage.output_tokens
#         self.total_tokens = response.usage.input_tokens + response.usage.output_tokens
#         # self.token_usage = self.total_tokens
#         self.token_usage = [self.prompt_tokens,self.completion_tokens,self.total_tokens]
#         return response.content[0].text, self.token_usage

class GeminiLLM(BaseLLM):
    """
    Wrapper around a ``google.generativeai`` ``GenerativeModel``.

    Model names noted by the original author:
    Latest: gemini-1.5-pro-latest  # April 9, 2024
    Latest stable: gemini-1.5-pro  # Referenced stable version : gemini-1.5-pro-001
    Stable: gemini-1.5-pro-001   # May 24, 2024
    Experimental:
    gemini-1.5-pro-exp-0801
    gemini-1.5-pro-exp-0827
    """
    def __init__(self, model_name):
        """Configure the Gemini SDK with ``GOOGLE_API_KEY`` and build the model.

        Side effect: sets process-wide environment variables (an HTTP proxy
        pointing at ``127.0.0.1:7890`` and two log-level settings), but only
        when they are not already defined, so values configured by the user
        are left untouched.
        """
        super().__init__(model_name)
        self.api_key = GOOGLE_API_KEY
        genai.configure(api_key=self.api_key)
        # genai.configure(transport='grpc')
        # Default local proxy; setdefault keeps any proxy the user configured.
        os.environ.setdefault('HTTP_PROXY', 'http://127.0.0.1:7890')
        # Suppress logging warnings unless the user asked for another level.
        os.environ.setdefault("GRPC_VERBOSITY", "ERROR")
        os.environ.setdefault("GLOG_minloglevel", "2")
        self.model = genai.GenerativeModel(self.model_name)

    def generate(self, messages):
        """Send *messages* to the model and return its reply with token counts.

        Args:
            messages: Content passed unchanged to ``generate_content``.

        Returns:
            ``(text, token_usage)`` where ``text`` is ``response.text`` and
            ``token_usage`` is ``[prompt_token_count, candidates_token_count,
            total_token_count]`` from the response's ``usage_metadata``. The
            same numbers are also stored on the instance.
        """
        response = self.model.generate_content(messages)
        usage = response.usage_metadata
        self.prompt_tokens = usage.prompt_token_count
        self.completion_tokens = usage.candidates_token_count
        self.total_tokens = usage.total_token_count
        self.token_usage = [self.prompt_tokens, self.completion_tokens, self.total_tokens]
        return response.text, self.token_usage

def make_api_request(model, prompt, image_path):
    """Send one single-turn prompt, optionally with an image, to *model*.

    Args:
        model: Display name of the backend: ``"GPT-4o"``,
            ``"Claude-3.5 Sonnet"`` or ``"Gemini-1.5 Pro"``.
        prompt: Text of the user turn.
        image_path: Path of an image file to attach, or a falsy value to
            send text only.

    Returns:
        The result of the selected wrapper's ``generate(messages)`` call.

    Raises:
        ValueError: If *model* is not one of the supported names.
    """
    # Function-local import so this block needs no new file-level import.
    import mimetypes

    if model == "GPT-4o":
        llm = GPT4OLLM("gpt-4o")
        if image_path:
            # The raw file bytes are sent as-is, so the data URL must carry
            # the file's real media type rather than a fixed one. Fall back
            # to the previous hard-coded PNG when the name gives no hint.
            guessed, _ = mimetypes.guess_type(str(image_path))
            if guessed and guessed.startswith("image/"):
                media_type = guessed
            else:
                media_type = "image/png"
            base64_image = encode_image(image_path)
            content = [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {
                    "url": f"data:{media_type};base64,{base64_image}",
                    "detail": "low",
                }},
            ]
        else:
            content = prompt
        message = {"role": "user", "content": content}
    elif model == "Claude-3.5 Sonnet":
        llm = ClaudeLLM("claude-3-5-sonnet-20240620")
        if image_path:
            # resize_image always re-encodes to PNG, so the media type is fixed.
            base64_image = resize_image(image_path)
            content = [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/png",
                        "data": base64_image,
                    },
                },
                {"type": "text", "text": prompt},
            ]
        else:
            content = prompt
        message = {"role": "user", "content": content}
    elif model == "Gemini-1.5 Pro":
        llm = GeminiLLM("gemini-1.5-pro")
        if image_path:
            image = PIL.Image.open(image_path)
            content = [prompt, image]
        else:
            content = prompt
        # Gemini messages carry their payload under "parts", not "content".
        message = {"role": "user", "parts": content}
    else:
        raise ValueError(f"Unknown model: {model}")

    return llm.generate([message])

class ChatSession:
    """Multi-turn conversation with one of the supported models.

    The running history is kept in ``self.messages`` and re-sent in full on
    every turn. GPT-4o and Claude entries are ``{"role", "content"}`` dicts;
    Gemini entries are ``{"role", "parts"}`` dicts.
    """

    def __init__(self, model):
        """Create the backend wrapper for *model*.

        Args:
            model: ``"GPT-4o"``, ``"Claude-3.5 Sonnet"`` or
                ``"Gemini-1.5 Pro"``.

        Raises:
            ValueError: If *model* is not one of the supported names.
        """
        self.model = model
        self.messages = []
        if model == "GPT-4o":
            self.llm = GPT4OLLM("gpt-4o")
        elif model == "Claude-3.5 Sonnet":
            self.llm = ClaudeLLM("claude-3-5-sonnet-20240620")
        elif model == "Gemini-1.5 Pro":
            self.llm = GeminiLLM("gemini-1.5-pro")
        else:
            raise ValueError(f"Unknown model: {model}")

    def add_message(self, role, content):
        """Append a ``{"role", "content"}`` entry to the history."""
        self.messages.append({"role": role, "content": content})

    def add_part(self, role, content):
        """Append a ``{"role", "parts"}`` entry (Gemini format) to the history."""
        self.messages.append({"role": role, "parts": content})

    def remove_last_message(self):
        """Drop the most recent history entry; do nothing if it is empty."""
        if self.messages:
            self.messages.pop()

    @staticmethod
    def _guess_image_media_type(image_path, default):
        """Return the image media type implied by the file name, else *default*."""
        # Function-local import so this class needs no new file-level import.
        import mimetypes

        guessed, _ = mimetypes.guess_type(str(image_path))
        if guessed and guessed.startswith("image/"):
            return guessed
        return default

    def _build_user_content(self, prompt, image_url):
        """Build the provider-specific payload for one user turn."""
        if self.model == "GPT-4o":
            if not image_url:
                return prompt
            # Raw file bytes are sent unchanged, so label them with the
            # file's real media type; JPEG stays the fallback as before.
            media_type = self._guess_image_media_type(image_url, "image/jpeg")
            base64_image = encode_image(image_url)
            return [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {
                    "url": f"data:{media_type};base64,{base64_image}",
                    "detail": "low",
                }},
            ]
        if self.model == "Claude-3.5 Sonnet":
            if not image_url:
                return prompt
            # resize_image always re-encodes to PNG, so the media type is fixed.
            base64_image = resize_image(image_url)
            return [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/png",
                        "data": base64_image,
                    },
                },
                {"type": "text", "text": prompt},
            ]
        if self.model == "Gemini-1.5 Pro":
            if not image_url:
                return prompt
            image = PIL.Image.open(image_url)
            return [prompt, image]
        # Only reachable if self.model was changed after construction.
        raise ValueError("Failed to generate a response")

    def get_response(self, prompt, image_url=None):
        """Send one user turn and record the model's reply in the history.

        Args:
            prompt: Text of the user turn.
            image_url: Despite the name, a local file path of an image to
                attach (it is opened from disk), or a falsy value for text
                only.

        Returns:
            ``(response, tokens)`` exactly as returned by the wrapper's
            ``generate``.

        Raises:
            ValueError: For GPT-4o and Claude, when the request fails (the
                underlying error is attached as ``__cause__``); for any
                model, when the reply is ``None``.
            Exception: For Gemini, the underlying error is re-raised as-is.

        On any failure the user turn is removed again, so the history only
        ever contains completed exchanges.
        """
        is_gemini = self.model == "Gemini-1.5 Pro"
        content = self._build_user_content(prompt, image_url)
        append = self.add_part if is_gemini else self.add_message

        append("user", content)
        try:
            response, tokens = self.llm.generate(self.messages)
        except Exception as e:
            print(f"Error: {str(e)}")
            self.remove_last_message()
            if is_gemini:
                raise  # Gemini callers receive the original exception.
            # Keep the ValueError contract but preserve the root cause.
            raise ValueError("Failed to generate a response") from e

        if response is None:
            # Do not leave a dangling user turn or a None reply in the history.
            self.remove_last_message()
            raise ValueError("Failed to generate a response")

        append("model" if is_gemini else "assistant", response)
        return response, tokens
