import abc
import base64
import io
import os
import time
from typing import Any, Optional

import google.generativeai as genai
from google.generativeai import types
from google.generativeai.types import answer_types
from google.generativeai.types import generation_types
import numpy as np
from openai import OpenAI
from PIL import Image

# Sentinel text that OpenAIWrapper.predict_mm returns (paired with a `None`
# raw response) once all of its attempts have failed. Callers can compare the
# text output against it to detect failure.
ERROR_CALLING_LLM = 'Error calling LLM'


def _array_to_jpeg_bytes(image: np.ndarray) -> bytes:
  """Encodes a numpy image array as JPEG bytes.

  Args:
    image: Pixel array accepted by `PIL.Image.fromarray`. Presumably an HxW
      grayscale or HxWx3 / HxWx4 uint8 array; TODO confirm against callers.

  Returns:
    The JPEG-encoded image as a byte string.
  """
  pil_image = Image.fromarray(image)
  # JPEG cannot store an alpha channel and PIL refuses to write RGBA/LA images
  # as JPEG, so drop the alpha plane first (e.g. HxWx4 screenshots).
  if pil_image.mode == 'RGBA':
    pil_image = pil_image.convert('RGB')
  elif pil_image.mode == 'LA':
    pil_image = pil_image.convert('L')
  with io.BytesIO() as buffer:
    pil_image.save(buffer, format='JPEG')
    return buffer.getvalue()


class LlmWrapper(abc.ABC):
  """Abstract interface for a text-only LLM."""

  @abc.abstractmethod
  def predict(
    self,
    text_prompt: str,
  ) -> tuple[str, Any]:
    """Calls the LLM with a text prompt.

    Args:
      text_prompt: Text prompt.

    Returns:
      A tuple of the text output and the raw model output.
    """


class MultimodalLlmWrapper(abc.ABC):
  """Abstract interface for a multimodal (text + image) LLM."""

  @abc.abstractmethod
  def predict_mm(
    self,
    text_prompt: Optional[str] = None,
    images: Optional[list[np.ndarray]] = None,
  ) -> tuple[str, Any]:
    """Calls the multimodal LLM with a prompt and a list of images.

    Args:
      text_prompt: Text prompt.
      images: Images as numpy arrays; `None` means no images. The default is
        `None` rather than `[]` so that a single list object is never shared
        across calls.

    Returns:
      A tuple of the text output and the raw model output.
    """


# Maps each of the four harm categories listed below to
# `HarmBlockThreshold.BLOCK_NONE`, i.e. a safety-settings dict whose every
# threshold is BLOCK_NONE. Key order matches the tuple order.
SAFETY_SETTINGS_BLOCK_NONE = {
  harm_category: types.HarmBlockThreshold.BLOCK_NONE
  for harm_category in (
    types.HarmCategory.HARM_CATEGORY_HARASSMENT,
    types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
    types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
    types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
  )
}

class OpenAIWrapper(LlmWrapper, MultimodalLlmWrapper):
  """OpenAI chat-completions wrapper supporting text and image prompts.

  The API key is read from the `OPENAI_API_KEY` environment variable, and the
  client's `base_url` from the optional `OPENAI_URL` environment variable.

  Attributes:
    client: The `OpenAI` client used for requests.
    max_retry: Max number of attempts made for one request before giving up.
    temperature: Sampling temperature passed to the model.
    model: Name of the model to query.
  """

  # Seconds to wait between consecutive attempts.
  RETRY_WAITING_SECONDS = 1

  def __init__(
    self,
    model_name: str,
    max_retry: int = 100,
    temperature: float = 0.0,
  ):
    """Creates the wrapper.

    Args:
      model_name: Name of the model to query.
      max_retry: Max number of attempts per request.
      temperature: Sampling temperature.

    Raises:
      RuntimeError: If `OPENAI_API_KEY` is not set in the environment.
    """
    if 'OPENAI_API_KEY' not in os.environ:
      raise RuntimeError('OpenAI API key not set.')
    self.client = OpenAI(
      base_url=os.getenv('OPENAI_URL'),
      api_key=os.environ['OPENAI_API_KEY'],
    )
    self.max_retry = max_retry
    self.temperature = temperature
    self.model = model_name

  @classmethod
  def encode_image(cls, image: np.ndarray) -> str:
    """Returns `image` encoded as a base64 string of JPEG bytes."""
    return base64.b64encode(_array_to_jpeg_bytes(image)).decode('utf-8')

  def predict(
    self,
    text_prompt: Optional[str] = None,
    messages: Optional[list[dict]] = None,
  ) -> tuple[str, Any]:
    """Queries the model without images; see `predict_mm` for details."""
    return self.predict_mm(text_prompt, [], messages=messages)

  def predict_mm(
    self,
    text_prompt: Optional[str] = None,
    images: Optional[list[np.ndarray]] = None,
    messages: Optional[list[dict]] = None,
  ) -> tuple[str, Any]:
    """Queries the model with a prompt and optional images, with retries.

    Args:
      text_prompt: Text prompt; used only when `messages` is empty or None.
      images: Images to attach, as numpy arrays. They are appended to the last
        user message, or sent as a new user message if there is none.
      messages: Optional chat history in OpenAI chat format. It is never
        modified; a copy is sent.

    Returns:
      A `(text, response)` tuple holding the first choice's message content
      and the raw response, or `(ERROR_CALLING_LLM, None)` once all
      `max_retry` attempts have failed.
    """
    final_messages = self._build_messages(text_prompt, images, messages)

    for attempt in range(self.max_retry):
      try:
        response = self.client.chat.completions.create(
          model=self.model,
          temperature=self.temperature,
          messages=final_messages,
        )
        if response:
          return response.choices[0].message.content, response
      except Exception as e:  # pylint: disable=broad-exception-caught
        # Want to catch all exceptions happened during LLM calls.
        print('Error calling LLM, will retry soon...')
        print(e)
      # Every failed attempt (exception or empty response) consumes one retry;
      # only wait when another attempt will actually follow.
      if attempt + 1 < self.max_retry:
        time.sleep(self.RETRY_WAITING_SECONDS)
    return ERROR_CALLING_LLM, None

  def _build_messages(
    self,
    text_prompt: Optional[str],
    images: Optional[list[np.ndarray]],
    messages: Optional[list[dict]],
  ) -> list[dict]:
    """Builds the message list to send without mutating `messages`."""
    if messages:
      # Copy each message dict so that attaching images below never writes
      # into the caller's history (a shallow list copy would share the dicts).
      final_messages = [dict(msg) for msg in messages]
    else:
      final_messages = [{
        'role': 'user',
        'content': [{'type': 'text', 'text': text_prompt}],
      }]
    if not images:
      return final_messages

    image_blocks = [{
      'type': 'image_url',
      'image_url': {
        'url': f'data:image/jpeg;base64,{self.encode_image(image)}',
      },
    } for image in images]
    for msg in reversed(final_messages):
      if msg.get('role') == 'user':
        content = msg.get('content')
        if isinstance(content, list):
          # Build a new list rather than extend() the caller's content list.
          msg['content'] = content + image_blocks
        else:
          msg['content'] = [{'type': 'text', 'text': content}] + image_blocks
        break
    else:
      # No user turn to attach to: send the images as their own user message
      # instead of silently dropping them.
      final_messages.append({'role': 'user', 'content': image_blocks})
    return final_messages
