import base64
import mimetypes
from typing import Iterator, List, Optional, Union

from together import Together


class TogetherAIWrapper:
    """
    Thin wrapper around the Together AI chat-completions API for vision prompts.

    Builds a single user message that combines optional context passages, one or
    more images, and a text prompt. The model's reply is returned either as a
    complete string or, in streaming mode, as an iterator of text chunks.
    """

    # Image references with these prefixes (compared case-insensitively) are
    # forwarded to the API unchanged instead of being read from disk.
    _URL_PREFIXES = ("http://", "https://", "data:")
    # MIME type used for local files whose type cannot be guessed from the
    # extension; this was the previously hard-coded value for every file.
    _DEFAULT_IMAGE_MIME = "image/jpeg"

    def __init__(
        self,
        model_name: str = "meta-llama/Llama-Vision-Free",
        stream: bool = False,
        client: Optional["Together"] = None,
    ):
        """
        Initialize the TogetherAIWrapper with the Together client and model name.

        Parameters:
            model_name (str): The model to be used by Together AI.
            stream (bool): Whether to use streaming mode for responses.
            client (optional): A pre-configured client exposing
                `chat.completions.create`. When omitted, a default `Together()`
                client is created.
        """
        self.model_name = model_name
        self.stream = stream
        # Injecting a client allows a custom API key or base URL, or a test
        # double; otherwise the default client is built exactly as before.
        self.client = client if client is not None else Together()

    def encode_image(self, image_path: str) -> str:
        """
        Encode a local image file to a base64 string.

        Parameters:
            image_path (str): The local path to the image file.

        Returns:
            str: The base64-encoded string of the image.
        """
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    def _prepare_image_payload(self, img: str) -> dict:
        """
        Prepare an image content part for the message.

        If `img` is an http(s) URL or a `data:` URL, it is used directly.
        Otherwise, it is treated as a local file path. The file is base64-encoded
        into a data URL whose MIME type is guessed from the file extension, and
        `image/jpeg` is used when no image type can be guessed.

        Parameters:
            img (str): Local file path, http(s) URL, or data URL.

        Returns:
            dict: The image message payload.
        """
        # Check the full scheme rather than just "http", so that a local file
        # such as "http_cat.jpg" is not mistaken for a URL.
        if img.lower().startswith(self._URL_PREFIXES):
            image_url = img
        else:
            mime_type, _ = mimetypes.guess_type(img)
            if not mime_type or not mime_type.startswith("image/"):
                mime_type = self._DEFAULT_IMAGE_MIME
            image_url = f"data:{mime_type};base64,{self.encode_image(img)}"
        return {"type": "image_url", "image_url": {"url": image_url}}

    def _passage_parts(self, passage: Union[str, dict, tuple]) -> List[dict]:
        """
        Convert one context passage into a list of message content parts.

        Parameters:
            passage: One of the following:
                - dict: an optional "image_path" key (adds an image part),
                  then an optional "caption" key (adds a text part).
                - 2-tuple (image, text): a falsy image is skipped, and a None
                  text is skipped.
                - str: adds a single text part.
                Any other value contributes no parts.

        Returns:
            List[dict]: Content parts, in message order.
        """
        parts: List[dict] = []
        if isinstance(passage, dict):
            if "image_path" in passage:
                parts.append(self._prepare_image_payload(passage["image_path"]))
            if "caption" in passage:
                parts.append({"type": "text", "text": passage["caption"]})
        elif isinstance(passage, tuple) and len(passage) == 2:
            passage_image, passage_text = passage
            if passage_image:
                parts.append(self._prepare_image_payload(passage_image))
            # A None text would produce an invalid {"text": None} part.
            if passage_text is not None:
                parts.append({"type": "text", "text": passage_text})
        elif isinstance(passage, str):
            parts.append({"type": "text", "text": passage})
        return parts

    @staticmethod
    def _iter_stream(response) -> Iterator[str]:
        """
        Yield the text of each streamed chunk, skipping chunks that carry no text.

        Parameters:
            response: The iterable of chunks returned by a streaming completion call.

        Yields:
            str: Non-empty text fragments, in arrival order.
        """
        for chunk in response:
            if not chunk.choices:
                continue
            # The delta, or its content, can be missing or None on some chunks.
            text = getattr(chunk.choices[0].delta, "content", None)
            if text:
                yield text

    def get_prediction(
        self,
        image_path: Union[str, List[str]],
        prompt: str,
        passages: Optional[Union[str, List[Union[str, dict, tuple]]]] = None,
        passage_prompt: Optional[str] = None,
    ) -> Union[str, Iterator[str]]:
        """
        Get a prediction from the model given image(s), a prompt, and optional passages.

        The message content is ordered as follows: passages, then the passage
        prompt, then the image(s), then the main prompt.

        Parameters:
            image_path (str or List[str]): Local file path(s) or URL(s) for the image(s).
            prompt (str): The main prompt text.
            passages (optional): Additional context. Can be a single string, dict,
                or (image, text) tuple, or a list of these.
            passage_prompt (optional): An extra prompt appended after the passages.

        Returns:
            If streaming is disabled, returns the complete output string. This is
            "" when the API returns no choices or empty content.
            If streaming is enabled, returns an iterator yielding text chunks.
            The API request itself is sent before this method returns.
        """
        content: List[dict] = []

        if passages is not None:
            if not isinstance(passages, list):
                passages = [passages]
            for passage in passages:
                content.extend(self._passage_parts(passage))

        if passage_prompt:
            content.append({"type": "text", "text": passage_prompt})

        images = image_path if isinstance(image_path, list) else [image_path]
        content.extend(self._prepare_image_payload(img) for img in images)

        content.append({"type": "text", "text": prompt})

        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[{"role": "user", "content": content}],
            stream=self.stream,
        )

        if self.stream:
            return self._iter_stream(response)
        if not response.choices:
            return ""
        # The content may be None; normalise it so the return value is always a str.
        return response.choices[0].message.content or ""


# Example usage:
if __name__ == "__main__":
    # Imported here so that importing this module stays free of CLI-only dependencies.
    import argparse

    # The defaults reproduce the original hard-coded demo; pass --image or
    # --prompt to run it on another machine or another question.
    parser = argparse.ArgumentParser(
        description="Ask a Together AI vision model a question about an image."
    )
    parser.add_argument(
        "--image",
        default="/drl_nas1/ckddls1321/data/coco/val2014/COCO_val2014_000000192168.jpg",
        help="Local image path or http(s) URL.",
    )
    parser.add_argument(
        "--prompt",
        default="What is the outfit this man is wearing called?",
        help="Question to ask about the image.",
    )
    parser.add_argument(
        "--stream",
        action="store_true",
        help="Print the response incrementally as it is streamed.",
    )
    args = parser.parse_args()

    # Initialize the TogetherAIWrapper.
    together_wrapper = TogetherAIWrapper(stream=args.stream)

    # Get and print the prediction.
    result = together_wrapper.get_prediction(args.image, args.prompt)
    if args.stream:
        # In streaming mode the result is an iterator of text chunks.
        print("Prediction: ", end="", flush=True)
        for piece in result:
            print(piece, end="", flush=True)
        print()
    else:
        print("Prediction:", result)
