from dataclasses import MISSING, dataclass, field, fields
from typing import Any, Dict, List, Optional, Union

import numpy as np
from PIL import Image

from utils.image_utils import any_to_bytes

# NOTE: plain dataclasses are used instead of pydantic or similar libraries to keep overhead low, since prompt functions are called very often.

# Every value accepted as `ContentItem.type`; anything else is rejected by
# `ContentItem.validate_type` with a ValueError.
SUPPORTED_ATOMIC_TYPES = set(
    (
        # Media / plain inputs.
        "text",
        "image",
        "video",
        # Calls and reasoning items.
        "function_call",
        "computer_call",
        "reasoning",
        # Results of calls.
        "function_output",
        "computer_output",
    )
)

# Speaker roles recognized for a `Message`.
SUPPORTED_ROLES = set(("system", "user", "assistant"))


@dataclass
class ContentItem:
    """
    A content item is the most atomic level of input to a model.
    It contains the `type` of the data and the `data` itself.

    Besides the dataclass fields below, `__init__` sets two extra attributes:
      - `raw_model_output`: the raw model output for model-generated items (or None).
      - `payload_size`: computed once by `get_payload_size()` at construction time; it is
        NOT refreshed if `data` or `raw_model_output` is mutated afterwards.
    """

    type: str  # The type of the input; must be one of SUPPORTED_ATOMIC_TYPES.
    data: Any  # The data of the input.
    meta_data: Dict[str, Any] = field(default_factory=dict)  # Additional metadata about the content item.
    id: Optional[str] = None  # The id of the content item.

    def __init__(
        self,
        type: str,
        data: Any,
        meta_data: Optional[Dict[str, Any]] = None,
        id: Optional[str] = None,
        raw_model_output: Optional[Any] = None,
    ):
        """
        Args:
            type: Content type; validated against SUPPORTED_ATOMIC_TYPES.
            data: The payload itself.
            meta_data: Optional metadata; a fresh empty dict is created when omitted.
            id: Optional identifier for the item.
            raw_model_output: Raw model output for model-generated items; used to size
                "function_call" and "reasoning" items.

        Raises:
            ValueError: If `type` is not a supported content type.
        """
        self.type = type
        self._validate()
        self.data = data
        # Fresh dict per instance: a shared `{}` default would leak metadata across items.
        self.meta_data = {} if meta_data is None else meta_data
        self.id = id
        # Must be assigned BEFORE `get_payload_size()`, which reads it for
        # "function_call" / "reasoning" items.
        # Beyond sizing it is not used yet. When unifying formats for function calls,
        # computer outputs, etc., it may be useful to keep the raw output here to re-send
        # to the model, and keep only the data relevant for end use in `data`.
        self.raw_model_output = raw_model_output
        self.payload_size = self.get_payload_size()

    @staticmethod
    def validate_type(type: str) -> None:
        """Raise ValueError if `type` is not in SUPPORTED_ATOMIC_TYPES."""
        if type not in SUPPORTED_ATOMIC_TYPES:
            raise ValueError(f"Invalid content type: {type}")

    def _validate(self) -> None:
        """Validate this item's fields (currently only `type`)."""
        ContentItem.validate_type(self.type)
        # TODO: add more

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the dataclass fields (`raw_model_output` and `payload_size` are not included)."""
        return {
            "type": self.type,
            "data": self.data,
            "meta_data": self.meta_data,
            "id": self.id,
        }

    def __bool__(self) -> bool:
        # An item is truthy iff its data is truthy.
        return bool(self.data)

    def get_payload_size(self) -> Optional[int]:
        """
        Estimate the size of this item's payload.

        Returns:
            - "text": number of UTF-8 encoded bytes of `data`.
            - "image": number of bytes of `data` encoded as PNG via `any_to_bytes`.
            - "function_call" / "reasoning": number of characters in `str(raw_model_output)`,
              or 0 when `raw_model_output` is None.
            - "function_output": number of characters in `str(data)`.
            - None for any other type, or if measuring raises.
        """
        try:
            if self.type == "text":
                return len(self.data.encode("utf-8"))
            if self.type == "image":
                # Convert image data to bytes using PNG format.
                return len(any_to_bytes(self.data, format="PNG"))
            if self.type in ("function_call", "reasoning"):
                raw = self.raw_model_output
                return 0 if raw is None else len(str(raw))
            if self.type == "function_output":
                return len(str(self.data))
            return None
        except Exception:
            # Best effort: the size is informational, so an unmeasurable payload
            # (e.g. non-str "text" data) yields None instead of failing construction.
            return None


# `Contents` is the list of atomic content items sent to a model within a single message.
Contents = List[ContentItem]


@dataclass
class Message:
    """
    The data exchanged between entities for a SINGLE round of a conversation.

    Holds the speaker `role` (e.g. "user", "assistant"), an optional `name`
    (defaults to the empty string), the list of `contents`, and free-form `meta_data`.
    `payload_size` is always recomputed in `__post_init__` as the sum of the
    contents' known sizes; any value passed to the constructor is overwritten.
    """

    role: str
    contents: List[ContentItem]
    name: Optional[str] = field(default_factory=str)
    meta_data: Dict[str, Any] = field(default_factory=dict)
    payload_size: Optional[int] = None

    def __post_init__(self):
        # Runs after the generated __init__; items whose size is unknown (None) are skipped.
        total = 0
        for item in self.contents:
            size = item.payload_size
            if size is not None:
                total += size
        self.payload_size = total

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the message, including each content item's own dict form."""
        serialized_contents = [item.to_dict() for item in self.contents]
        return dict(
            contents=serialized_contents,
            role=self.role,
            name=self.name,
            meta_data=self.meta_data,
        )

    def text(self) -> str:
        """Join the data of all "text" items, newline-separated."""
        text_parts = (item.data for item in self.contents if item.type == "text")
        return "\n".join(text_parts)

    def images(self) -> List[Image.Image]:
        """Return the data of all "image" items, in order."""
        return [item.data for item in self.contents if item.type == "image"]

    def raw_data(self) -> List[Any]:
        """Return the data of every content item, in order."""
        return [item.data for item in self.contents]

    def __bool__(self) -> bool:
        # A message is truthy iff it has any contents.
        return True if self.contents else False


# The final API input: the conversation as an ordered list of `Message` rounds.
APIInput = List[Message]


@dataclass
class Cache:
    """
    Mutable container of per-interaction state that can be cleared with `reset()`.

    NOTE(review): field meanings are inferred from names only (messages sent to the
    provider, generation config, raw API responses, parsed model messages) — confirm with callers.
    """

    messages_to_provider: List[Any] = field(default_factory=list)
    gen_config: Optional[Any] = None
    api_responses: List[Any] = field(default_factory=list)
    model_messages: List[Message] = field(default_factory=list)

    def reset(self) -> None:
        """Restore every field to its declared default, creating fresh containers via default factories."""
        for dc_field in fields(self):
            factory = dc_field.default_factory
            # A field without a factory falls back to its plain default value.
            value = dc_field.default if factory is MISSING else factory()
            setattr(self, dc_field.name, value)


class NumRetriesExceeded(Exception):
    """
    Exception raised when the number of retries exceeds the maximum allowed.
    It can include information about the last observed exception to help with debugging.
    """

    def __init__(self, last_exception: Exception | None = None):
        self.last_exception = last_exception
        if last_exception:
            message = f"Maximum retry attempts exceeded. Last error: {last_exception}"
        else:
            message = "Maximum retry attempts exceeded."
        super().__init__(message)
