from typing import override

from sentence_transformers import SentenceTransformer

from mow.common.data import prepare_graph_representation
from mow.dataset.auto import AutoChatDatasetBuilder
from mow.dataset.chat import ChatDatasetBuilder
from mow.dataset.history import ChatHistoryMixin

# Skill templates offered to the agent. `{obj}` is a literal placeholder shown
# to the model, not a str.format field. The skill count in BASE_PROMPT is
# derived from this tuple so the two cannot drift apart (the prompt previously
# claimed "9 skills" while enumerating 12).
_SKILLS = (
    "look",
    "go to {obj}",
    "open {obj}",
    "close {obj}",
    "take {obj} from {obj}",
    "cool {obj} with {obj}",
    "move {obj} to {obj}",
    "heat {obj} with {obj}",
    "examine {obj}",
    "clean {obj} with {obj}",
    "inventory",
    "use {obj}",
)

# System prompt used as the first ("system") message of every ALFWorld chat.
# NOTE(review): this text is model-facing; data or checkpoints produced with
# the old "9 skills" wording will not match it byte-for-byte.
BASE_PROMPT = (
    f"You are a home robot agent. You can use {len(_SKILLS)} skills "
    f"({', '.join(_SKILLS)}). Room: livingroom, bathroom, kitchen, bedroom."
)


class AlfworldDatasetBuilder(ChatHistoryMixin, ChatDatasetBuilder):
    """Dataset builder for ALFWorld trajectories.

    Converts an example into a chat transcript and into the arguments passed
    to ``prepare_graph_representation``.
    """

    @override
    @classmethod
    def _convert_to_chat(
        cls, example, /, *, action_only: bool = False, **_
    ) -> list[dict[str, str]]:
        """Render ``example["history"]`` as a list of chat messages.

        The transcript starts with the system prompt and a user message made
        of the first step's ``observation`` followed by an "Action:" cue.
        Every step then adds an assistant message holding its ``action``.
        For every step after the first, an "Action:" user cue precedes that
        message. After the action come an "Observation:" user cue and an
        assistant message holding the step's ``next_observation``.

        Args:
            example: mapping whose ``"history"`` is an indexable sequence of
                step mappings with ``"action"`` and ``"next_observation"``
                keys; the first step must also have ``"observation"``. An
                empty history raises ``IndexError``.
            action_only: if true, the final step stops after its action (no
                observation cue or message is emitted for it).

        Returns:
            Messages as ``{"role": ..., "content": ...}`` dicts.
        """
        steps = example["history"]
        messages: list[dict[str, str]] = [
            {"role": "system", "content": BASE_PROMPT},
            {"role": "user", "content": f"{steps[0]['observation']}\n\nAction:"},
        ]
        final_idx = len(steps) - 1
        for idx, step in enumerate(steps):
            # The first action cue is already embedded in the opening user
            # message, so only later steps get a standalone cue.
            if idx:
                messages.append({"role": "user", "content": "Action: \n\n"})
            messages.append({"role": "assistant", "content": step["action"]})
            if action_only and idx == final_idx:
                continue
            messages += [
                {"role": "user", "content": "Observation: \n\n"},
                {"role": "assistant", "content": step["next_observation"]},
            ]
        return messages

    @override
    @classmethod
    def _prepare_graph_representation(
        cls, example, /, *, sentence_transformer: SentenceTransformer
    ):
        """Forward an example's task, graph(s) and labels to the graph helper.

        If ``example["task"]`` is a list (presumably a batch — confirm with
        callers), ``example["history"]`` is treated as one history per task,
        and the ``observation_graph`` of each history's last step is
        collected into a list. Otherwise the last step's
        ``observation_graph`` of the single history is used. Labels are
        passed through when the ``"labels"`` key exists, else ``None``.
        """
        task = example["task"]
        histories = example["history"]
        if isinstance(task, list):
            graphs = [steps[-1]["observation_graph"] for steps in histories]
        else:
            graphs = histories[-1]["observation_graph"]
        labels = example["labels"] if "labels" in example else None
        return prepare_graph_representation(
            task, graphs, labels, sentence_transformer=sentence_transformer
        )


# Runs at import time: registers AlfworldDatasetBuilder under the "alfworld"
# key. Presumably this lets AutoChatDatasetBuilder resolve the builder by that
# name — confirm against mow.dataset.auto.
AutoChatDatasetBuilder.register("alfworld", AlfworldDatasetBuilder)
