# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

# Enables postponed evaluation of annotations (for string-based type hints)
from __future__ import annotations

import base64
import os
import uuid
from io import BytesIO
from typing import List, Optional

from openai import OpenAI
from PIL import Image

from camel.logger import get_logger
from camel.toolkits import FunctionTool
from camel.toolkits.base import BaseToolkit
from camel.utils import MCPServer

logger = get_logger(__name__)


@MCPServer()
class DalleToolkit(BaseToolkit):
    r"""A class representing a toolkit for image generation using OpenAI's
    DALL-E model.
    """

    def __init__(
        self,
        timeout: Optional[float] = None,
    ):
        r"""Initializes a new instance of the DalleToolkit class.

        Args:
            timeout (Optional[float]): The timeout value for API requests
                in seconds. If None, no timeout is applied.
                (default: :obj:`None`)
        """
        super().__init__(timeout=timeout)

    def base64_to_image(self, base64_string: str) -> Optional[Image.Image]:
        r"""Converts a base64 encoded string into a PIL Image object.

        Args:
            base64_string (str): The base64 encoded string of the image.

        Returns:
            Optional[Image.Image]: The PIL Image object or None if conversion
                fails.
        """
        try:
            # Decode the base64 string to get the image data
            image_data = base64.b64decode(base64_string)
            # Create a memory buffer for the image data
            image_buffer = BytesIO(image_data)
            # Open the image using the PIL library
            image = Image.open(image_buffer)
            return image
        except Exception as e:
            error_msg = (
                f"An error occurred while converting base64 to image: {e}"
            )
            logger.error(error_msg)
            return None

    def image_path_to_base64(self, image_path: str) -> str:
        r"""Converts the file path of an image to a Base64 encoded string.

        Args:
            image_path (str): The path to the image file.

        Returns:
            str: A Base64 encoded string representing the content of the image
                file.
        """
        try:
            with open(image_path, "rb") as image_file:
                return base64.b64encode(image_file.read()).decode('utf-8')
        except Exception as e:
            error_msg = (
                f"An error occurred while converting image path to base64: {e}"
            )
            logger.error(error_msg)
            return error_msg

    def image_to_base64(self, image: Image.Image) -> str:
        r"""Converts an image into a base64-encoded string.

        This function takes an image object as input, encodes the image into a
        PNG format base64 string, and returns it.
        If the encoding process encounters an error, it prints the error
        message and returns None.

        Args:
            image: The image object to be encoded, supports any image format
                that can be saved in PNG format.

        Returns:
            str: A base64-encoded string of the image.
        """
        try:
            with BytesIO() as buffered_image:
                image.save(buffered_image, format="PNG")
                buffered_image.seek(0)
                image_bytes = buffered_image.read()
                base64_str = base64.b64encode(image_bytes).decode('utf-8')
                return base64_str
        except Exception as e:
            error_msg = f"An error occurred: {e}"
            logger.error(error_msg)
            return error_msg

    def get_dalle_img(self, prompt: str, image_dir: str = "img") -> str:
        r"""Generate an image using OpenAI's DALL-E model.
            The generated image is saved to the specified directory.

        Args:
            prompt (str): The text prompt based on which the image is
                generated.
            image_dir (str): The directory to save the generated image.
                Defaults to 'img'.

        Returns:
            str: The path to the saved image.
        """

        dalle_client = OpenAI()
        response = dalle_client.images.generate(
            model="dall-e-3",
            prompt=prompt,
            size="1024x1792",
            quality="standard",
            n=1,  # NOTE: now dall-e-3 only supports n=1
            response_format="b64_json",
        )
        if response.data is None or len(response.data) == 0:
            error_msg = "No image data returned from DALL-E API."
            logger.error(error_msg)
            return error_msg
        image_b64 = response.data[0].b64_json
        image = self.base64_to_image(image_b64)  # type: ignore[arg-type]

        if image is None:
            error_msg = "Failed to convert base64 string to image."
            logger.error(error_msg)
            return error_msg

        os.makedirs(image_dir, exist_ok=True)
        image_path = os.path.join(image_dir, f"{uuid.uuid4()}.png")
        image.save(image_path)

        return image_path

    def get_tools(self) -> List[FunctionTool]:
        r"""Returns a list of FunctionTool objects representing the
        functions in the toolkit.

        Returns:
            List[FunctionTool]: A list of FunctionTool objects
                representing the functions in the toolkit.
        """
        return [FunctionTool(self.get_dalle_img)]
