from enum import Enum
import os
from typing import Optional

import xml.etree.ElementTree as ET


class Role(Enum):
    """Author role of a chat message ("system", "assistant" or "user")."""

    SYSTEM = "system"
    AI = "assistant"
    USER = "user"

    @classmethod
    def from_value(cls, value):
        """Return the member whose ``value`` equals *value*.

        Members are checked in definition order using ``==``.

        Raises:
            ValueError: if no member has a matching value.
        """
        found = next((candidate for candidate in cls if candidate.value == value), None)
        if found is None:
            raise ValueError(f"No matching role for value: {value}")
        return found

class Message:
    content: str
    role: str
    tags: set
    short_content: Optional[str]

    def __init__(self, role: str, content: str, tags: set = None, short_content: str | None = None):
        self.role = role
        self.content = content
        self.tags = set(tags or {})
        self.short_content = short_content or ""

    def copy(self):
        return Message(self.role, self.content, self.tags, self.short_content)

    def short_version(self):
        return Message(self.role, self.short_content, self.tags, self.short_content)

    @classmethod
    def from_xml_element(cls, element):
        role = Role.from_value(element.get("role"))
        content = element.text.strip()
        return Message(role, content)

    def _header(self):
        name = self.__class__.__name__
        return f"{name.upper()} with tags ({', '.join(self.tags)}):\n"

    def __str__(self):
        role = self.role
        tags = "{" + ", ".join(str(t) for t in self.tags) + "}"
        return f"Message({role=}, tags={tags})\n------------\n{self.content}"

    def __repr__(self):
        return self.__str__()

    def _n_tokens(self, content: str, model: str):
        return num_tokens_from_string(content, model) + 4

    def n_tokens(self, model: str):
        return self._n_tokens(self.content or "", model)

    def dump(self):
        return {
            "role": self.role.value,
            "content": self.content or "",
        }

    def to_xml(self, parent: ET.Element = None):
        if parent is None:
            msg_element = ET.Element("message")
        else:
            msg_element = ET.SubElement(parent, "message")
        msg_element.set("role", self.role.value)
        msg_element.text = self.content

        return msg_element

    def to_dict(self):
        return {
            "role": self.role.value,
            "content": self.content or "",
            "tags": [tag.value for tag in self.tags],
            "short_content": self.short_content,
        }


def merge_messages(messages: list[Message]) -> list[Message]:
    """Merge consecutive messages of the same role into one message.

    Each run of adjacent messages that share a role becomes a single copy of
    the run's first message. Its content is the contents of the run joined
    with two newlines. Tags and short_content are taken from the first
    message of the run only. The input messages are not modified: each run
    starts from ``Message.copy()`` and only that copy is mutated.

    Args:
        messages: messages in conversation order.

    Returns:
        A new list of merged messages (empty if *messages* is empty).
    """
    merged_messages: list[Message] = []
    current: Message | None = None
    for message in messages:
        if current is not None and message.role == current.role:
            # None content is treated as "" (as Message.dump/to_dict do)
            # rather than raising TypeError on concatenation.
            current.content = (current.content or "") + "\n\n" + (message.content or "")
            continue
        if current is not None:
            merged_messages.append(current)
        current = message.copy()
    if current is not None:
        merged_messages.append(current)

    return merged_messages
