import base64
import logging
import mimetypes
import warnings

import cv2
from fastchat.conversation import get_conv_template
from openai import OpenAI
from openai import BadRequestError

from .model_base import BlackBoxModelBase


class OpenaiModel(BlackBoxModelBase):
    """
    Wrapper around the OpenAI chat-completions API that sends text prompts paired with images.
    """

    # Module-level logger; used instead of print() for request errors.
    _logger = logging.getLogger(__name__)
    # generation_config keys that configure the OpenAI client, not an individual request.
    _CLIENT_ONLY_KEYS = frozenset({'organization'})
    # Number of times a request is attempted before generate() gives up and returns None.
    _MAX_ATTEMPTS = 5
    # Image references starting with one of these (case-insensitive) are sent to the API unchanged.
    _URL_PREFIXES = ('http://', 'https://', 'data:')
    # MIME type used in data URLs when none can be guessed from the file extension.
    _DEFAULT_IMAGE_MIME = 'image/png'

    def __init__(self, model_name: str, api_keys: str, generation_config=None):
        """
        Initializes the OpenAI model with necessary parameters.
        :param str model_name: The name of the model to use.
        :param str api_keys: API key for accessing the OpenAI service.
        :param dict generation_config: Optional settings. An ``'organization'`` entry is passed to the
            OpenAI client; every other entry is forwarded to each ``chat.completions.create`` call.
            Defaults to an empty dictionary.
        """
        self.generation_config = generation_config if generation_config is not None else {}
        # .get() instead of ['organization']: a config without that key no longer raises KeyError.
        # None is the client's own default for ``organization``, matching the no-config case.
        self.client = OpenAI(api_key=api_keys, organization=self.generation_config.get('organization'))
        self.model_name = model_name
        self.conversation = get_conv_template('chatgpt')
        self.seed = 42
        # Only a system message set explicitly through set_system_message() is sent with requests.
        self._system_message = None

    def set_system_message(self, system_message: str):
        """
        Sets a system message for the conversation.

        The message is stored on the fastchat conversation and, when non-empty, is sent as a
        ``system`` turn ahead of the user turns in every subsequent ``generate`` call.
        :param str system_message: The system message to set.
        """
        self.conversation.system_message = system_message
        self._system_message = system_message

    def generate(self, messages, images, clear_old_history=True, **kwargs):
        """
        Sends one chat-completions request built from the given messages and images.

        Each message becomes a ``user`` turn holding the text and, unless its image is None,
        the paired image. Messages and images are paired with ``zip``, so surplus entries of
        the longer sequence are ignored.
        :param list[str]|str messages: A list of messages or a single message string.
        :param list[str]|str|None images: A list of images or a single image, each an http(s)/data
            URL or a local file path. None sends text only.
        :param bool clear_old_history: If True, resets ``self.input_list``. Turns from earlier calls
            are never included in the request either way.
        :param kwargs: Extra arguments for ``chat.completions.create``. They override both
            ``generation_config`` and the default ``seed``/``temperature``/``top_p``.
        :return str|None: The model's reply, or None if every attempt failed.
        """
        if clear_old_history:
            self.input_list = []
        if isinstance(messages, str):
            messages = [messages]
        if images is None:
            # Text-only request: pair every message with "no image".
            images = [None] * len(messages)
        elif isinstance(images, str):
            images = [images]

        self.inputs = [self._build_user_turn(text, image) for text, image in zip(messages, images)]
        system_message = getattr(self, '_system_message', None)
        if system_message:
            self.inputs.insert(0, {'role': 'system', 'content': system_message})
        request_params = self._build_request_params(kwargs)

        for attempt in range(1, self._MAX_ATTEMPTS + 1):
            try:
                response = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=self.inputs,
                    **request_params,
                )
            except BadRequestError as err:
                self._logger.warning('OpenAI rejected the request (attempt %d/%d): %s',
                                     attempt, self._MAX_ATTEMPTS, err)
            except Exception as err:
                self._logger.warning('OpenAI server offers this error (attempt %d/%d): %s',
                                     attempt, self._MAX_ATTEMPTS, err)
            else:
                return response.choices[0].message.content
        self._logger.error('Giving up after %d failed attempts; returning None.', self._MAX_ATTEMPTS)
        return None

    def _build_request_params(self, overrides):
        """
        Merges request settings, where later sources win: default sampling settings, then the
        request-level entries of ``generation_config``, then the per-call ``overrides``.
        """
        params = {'seed': self.seed, 'temperature': 0, 'top_p': 1}
        params.update(
            (key, value) for key, value in self.generation_config.items()
            if key not in self._CLIENT_ONLY_KEYS
        )
        params.update(overrides)
        return params

    def _build_user_turn(self, message, image):
        """Returns one ``user`` chat message holding ``message`` and, if not None, ``image``."""
        content = [{"type": "text", "text": message}]
        if image is not None:
            content.append({"type": "image_url", "image_url": {"url": self._image_url(image)}})
        return {'role': 'user', 'content': content}

    def _image_url(self, image):
        """Returns ``image`` as-is if it is already a URL, else a base64 data URL of the local file."""
        reference = str(image)
        if reference.lower().startswith(self._URL_PREFIXES):
            return reference
        mime_type, _ = mimetypes.guess_type(reference)
        if not mime_type or not mime_type.startswith('image/'):
            mime_type = self._DEFAULT_IMAGE_MIME
        return f"data:{mime_type};base64,{self.encode_image(reference)}"

    def batch_generate(self, conversations, batches, **kwargs):
        """
        Generates responses for multiple conversations, one ``generate`` call per conversation.
        :param list[list[str]]|list[str] conversations: A list of conversations, each as a list of messages.
        :param list[list[str]]|list[str] batches: A list of batches, each as a list of images.
        :return list[str|None]: One response per (conversation, batch) pair; None where a request failed.
        """
        responses = []
        for conversation, image in zip(conversations, batches):
            if isinstance(conversation, str):
                warnings.warn('For batch generation based on several conversations, provide a list[str] for each conversation. '
                              'Using list[list[str]] will avoid this warning.')
            responses.append(self.generate(conversation, image, **kwargs))
        return responses

    def encode_image(self, image_path):
        """
        Reads the file at ``image_path`` and returns its bytes base64-encoded as a UTF-8 string.
        :param str image_path: Path to a local image file.
        :return str: The base64 text of the file contents.
        """
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')
