from typing import Iterable, List

from fastchat.conversation import get_conv_template

from ..message import Message, Role


class BaseModel:
    """Interface shared by the chat model wrappers in this module.

    Both ``__init__`` and ``__call__`` raise ``NotImplementedError`` here;
    subclasses are expected to override them.
    """

    # Class-level flag; this base class sets it to False.
    supports_system_message = False

    def __init__(self, **kwargs):
        """Reject direct instantiation; subclasses provide their own setup."""
        raise NotImplementedError()

    def __call__(self, messages: List[Message], api_key: str = None):
        """Produce a response for ``messages``; must be overridden.

        Args:
            messages: The conversation so far.
            api_key: Optional credential, defaulting to ``None``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()


class MockModel(BaseModel):
    """Testing model which returns the user's input as the response.

    The reply is read interactively from standard input via ``input``, so
    the conversation passed in is never inspected.
    """

    supports_system_message = False

    def __init__(self, **kwargs):
        """Accept and ignore any keyword arguments."""
        pass

    def __call__(self, messages: List[Message], api_key: str = None):
        """Prompt on stdin and return the typed line as the sole response.

        The signature mirrors ``BaseModel.__call__`` so the mock can be
        called exactly like the other models, including with ``api_key``
        omitted or passed by keyword.

        Args:
            messages: The conversation so far (unused).
            api_key: Optional credential (unused).

        Returns:
            A single-element list containing the text entered by the user.
        """
        response = input("[Response]: ")
        return [response]


class UselessModel(BaseModel):
    """Model that ignores message content and only reports the message count."""

    supports_system_message = False

    def __init__(self, **kwargs):
        """Accept and ignore any keyword arguments."""
        pass

    def __call__(self, messages: List[Message], api_key: str = None):
        """Return a one-element list with a fixed sentence naming ``len(messages)``.

        Args:
            messages: The conversation so far; only its length is used.
            api_key: Optional credential (unused).
        """
        unread = len(messages)
        reply = f"I have ({unread}) unread messages."
        return [reply]


def print_and_concat_stream(response: Iterable, role: Role = Role.ASSISTANT):
    """Echo a streamed response to stdout as it arrives and return the full text.

    A ``[Role]: `` header (the role name in title case) is printed first,
    then each chunk is printed without a separator and flushed immediately,
    and finally a newline terminates the line.

    Args:
        response: Iterable yielding the chunks of the response.
        role: Role whose name is shown in the header.

    Returns:
        All chunks joined into a single string.
    """
    header = f"[{role.name.title()}]: "
    print(header, end="", flush=True)

    pieces = []
    for piece in response:
        print(piece, end="", flush=True)
        pieces.append(piece)

    # Terminate the streamed line.
    print(flush=True)
    return "".join(pieces)


def concat_stream(response: Iterable):
    """Consume ``response`` and return its chunks joined into one string."""
    # str.join accepts any iterable directly; no intermediate list is needed.
    return "".join(response)


def _simple_template(messages: List[Message]):
    """Render ``messages`` as a plain-text transcript prompt.

    The result is a fixed instruction line, then one ``Role>content`` line
    per message (role name in title case), then a trailing ``Assistant>``
    line with no content, all joined with newlines.
    """
    preamble = "The following is a conversation between a user and an AI assistant. Please respond to the user as the assistant."
    turns = [f"{msg.role.name.title()}>{msg.content}" for msg in messages]
    cue = f"{Role.ASSISTANT.name.title()}>"
    return "\n".join([preamble, *turns, cue])


def build_prompt(messages: List[Message], template_name: str = None):
    """Build a text prompt from ``messages``.

    Args:
        messages: The conversation so far.
        template_name: Name passed to ``get_conv_template``. When ``None``,
            the built-in plain-text transcript format is used instead.

    Returns:
        The prompt string.
    """
    if template_name is not None:
        conv = get_conv_template(template_name)
        for msg in messages:
            if msg.role == Role.SYSTEM:
                # System messages with empty content are skipped.
                if msg.content:
                    conv.set_system_message(msg.content)
                continue
            if msg.role == Role.USER:
                conv.append_message(conv.roles[0], msg.content)
            elif msg.role == Role.ASSISTANT:
                conv.append_message(conv.roles[1], msg.content)
            # Messages with any other role are ignored.

        # Append a second-role entry with no content before rendering;
        # presumably this leaves the assistant turn open for the model
        # to complete -- confirm against the template library.
        conv.append_message(conv.roles[1], None)
        return conv.get_prompt()

    return _simple_template(messages)
