from __future__ import annotations

from enum import Enum
from typing import TypedDict, Union

import tiktoken

class HarmonyMessage(TypedDict):
    """A single role+content message, as sent to the /chat/completions endpoint."""

    role: str
    content: str


class PromptFormat(str, Enum):
    """
    The supported ways of turning the components of a prompt into the format
    accepted by the relevant API server endpoint. Each member is also a plain
    string equal to its value.
    """

    NONE = "none"
    """For models that don't use special tokens for instructions."""
    INSTRUCTION_FOLLOWING = "instruction_following"
    """For instruction-following (IF) models that use <|endofprompt|>."""
    HARMONY_V4 = "harmony_v4"
    """
    For Harmony models that use a structured turn-taking role+content format.
    Produces a list of HarmonyMessage dicts that can be sent to the
    /chat/completions endpoint.
    """

    @classmethod
    def from_string(cls, s: str) -> PromptFormat:
        """Returns the member whose value equals `s`; raises ValueError if none does."""
        found = next((member for member in cls if member.value == s), None)
        if found is None:
            raise ValueError(f"{s} is not a valid PromptFormat")
        return found


class Role(str, Enum):
    """
    The roles a prompt message can be attributed to. Each member is also a plain
    string equal to its value.

    See https://platform.openai.com/docs/guides/chat
    """

    # PromptBuilder.build expects the first message to come from this role.
    SYSTEM = "system"
    # PromptBuilder.build expects user and assistant messages to alternate,
    # starting with a user message.
    USER = "user"
    ASSISTANT = "assistant"


class PromptBuilder:
    """Class for accumulating the messages of a prompt and then formatting
    them into the output expected by a given PromptFormat."""

    def __init__(self) -> None:
        # Messages in the order they were added. `build` works on copies, so this
        # list is only ever changed by `add_message`.
        self._messages: list[HarmonyMessage] = []

    def add_message(self, role: Role, message: str) -> None:
        """Appends a message to the prompt. Role ordering is only checked by `build`."""
        self._messages.append(HarmonyMessage(role=role, content=message))

    def prompt_length_in_tokens(
        self, prompt_format: PromptFormat, *, allow_extra_system_messages: bool = False
    ) -> int:
        """
        Returns the length in tokens (using the cl100k_base encoding) of the prompt
        for the given `prompt_format`.

        For HARMONY_V4 the result is an approximation: the tokens of each message's
        content plus a fixed per-message overhead, plus the tokens that prime the
        reply. The messages are not validated on this path. For the string formats
        the prompt is built with `build` (and therefore validated) and then encoded.

        The `allow_extra_system_messages` parameter is forwarded to `build` for the
        string formats, so that prompts which can only be built with that flag set
        can also be measured. It has no effect for HARMONY_V4.
        """
        # TODO(sbills): Make the model/encoding configurable. This implementation
        #  assumes GPT-4.
        encoding = tiktoken.get_encoding("cl100k_base")
        if prompt_format == PromptFormat.HARMONY_V4:
            # Approximately-correct implementation adapted from this documentation:
            # https://platform.openai.com/docs/guides/chat/introduction
            num_tokens = 0
            for message in self._messages:
                # Every message follows
                # <|im_start|>{role/name}\n{content}<|im_end|>\n, which is counted
                # as 4 tokens on top of the content itself.
                num_tokens += 4
                num_tokens += len(
                    encoding.encode(message["content"], allowed_special="all")
                )
            num_tokens += 2  # every reply is primed with <|im_start|>assistant
            return num_tokens

        # Forward the flag: without it, a prompt with extra system messages would
        # fail validation here even though `build` can produce it.
        prompt_str = self.build(
            prompt_format, allow_extra_system_messages=allow_extra_system_messages
        )
        assert isinstance(prompt_str, str)
        return len(encoding.encode(prompt_str, allowed_special="all"))

    def build(
        self, prompt_format: PromptFormat, *, allow_extra_system_messages: bool = False
    ) -> Union[str, list[HarmonyMessage]]:
        """
        Validates the messages added so far (reasonable alternation of assistant
        vs. user, etc.) and returns either a regular string (maybe with
        <|endofprompt|> tokens) or a list of
        HarmonyMessages suitable for use with the /chat/completions endpoint.

        The expected order is: a system message first, then a user message, then
        assistant and user messages alternating. An empty prompt passes validation.

        The `allow_extra_system_messages` parameter allows the caller to specify
        that the prompt should be allowed to contain system messages after the very
        first one. After any system message, a user message is expected next.

        For INSTRUCTION_FOLLOWING, "<|endofprompt|>" is appended to the content of
        the last user message. For NONE and INSTRUCTION_FOLLOWING the message
        contents are concatenated with no separator.

        Raises AssertionError if the role order is violated, or if
        INSTRUCTION_FOLLOWING is requested and there is no user message. These
        checks are `assert` statements, so they are skipped under `python -O`.
        Raises ValueError if `prompt_format` is not a known format.
        """
        # Copy each message dict so we can modify the copies and so that the
        # caller can't modify the internal state of this object. A per-message
        # shallow copy is enough because the values are strings.
        messages = [message.copy() for message in self._messages]

        expected_next_role = "system"
        for message in messages:
            role = message["role"]
            assert role == expected_next_role or (
                allow_extra_system_messages and role == "system"
            ), f"Expected message from {expected_next_role} but got message from {role}"
            if role == "system":
                expected_next_role = "user"
            elif role == "user":
                expected_next_role = "assistant"
            elif role == "assistant":
                expected_next_role = "user"

        if prompt_format == PromptFormat.INSTRUCTION_FOLLOWING:
            # Find the last user message; it receives the end-of-prompt marker.
            last_user_message = None
            for message in messages:
                if message["role"] == "user":
                    last_user_message = message
            assert last_user_message is not None
            last_user_message["content"] += "<|endofprompt|>"

        if prompt_format == PromptFormat.HARMONY_V4:
            return messages
        elif prompt_format in [PromptFormat.NONE, PromptFormat.INSTRUCTION_FOLLOWING]:
            return "".join(message["content"] for message in messages)
        else:
            raise ValueError(f"Unknown prompt format: {prompt_format}")
