from __future__ import annotations

from abc import abstractmethod
from dataclasses import dataclass
from functools import partial

from sentence_transformers import SentenceTransformer
from transformers import PreTrainedTokenizerBase

import datasets
from mow.common.data import prepare_graph_representation
from mow.dataset.builder import DatasetBuilder
from mow.utils.types import instanceof


@dataclass
class ChatDatasetInfo(datasets.DatasetInfo):
    """``datasets.DatasetInfo`` subclass that adds a ``type`` field."""

    # Tag stored on the info object; defaults to "chat".
    type: str = "chat"


class ChatDatasetBuilder(DatasetBuilder):
    """Dataset builder with chat-formatting and graph-preparation map steps.

    Subclasses describe how an example becomes a list of chat messages by
    implementing ``_convert_to_chat``. The public methods bind their extra
    arguments with ``functools.partial`` and hand the result to ``self.map``.
    """

    def as_chat(
        self,
        tokenizer: PreTrainedTokenizerBase,
        *,
        batched: bool = False,
        desc: str | None = None,
        **kwargs,
    ):
        """Apply ``convert_to_chat`` through ``self.map``.

        Args:
            tokenizer: Tokenizer whose chat template renders the messages.
            batched: Forwarded to ``self.map`` unchanged.
            desc: Description forwarded to ``self.map``; a default text is
                used when this is ``None`` or empty.
            **kwargs: Extra keyword arguments that ``convert_to_chat``
                passes on to ``_convert_to_chat``.

        Returns:
            Whatever ``self.map`` returns.
        """
        converter = partial(self.convert_to_chat, tokenizer=tokenizer, **kwargs)
        description = desc or "Converting to chat format"
        return self.map(converter, batched=batched, desc=description)

    @classmethod
    def convert_to_chat(
        cls, example, /, *, tokenizer: PreTrainedTokenizerBase, **kwargs
    ) -> dict:
        """Return a copy of ``example`` with a rendered ``"text"`` entry.

        The messages are produced by ``_convert_to_chat`` (the hook that
        subclasses implement) and rendered to a string with
        ``tokenizer.apply_chat_template(..., tokenize=False)``.

        Args:
            example: Mapping to convert; all of its entries are kept.
            tokenizer: Tokenizer providing ``apply_chat_template``.
            **kwargs: Forwarded to ``_convert_to_chat``.

        Returns:
            A new dict holding the entries of ``example`` plus ``"text"``.
        """
        # Copy first so the result reflects ``example`` as it was passed in.
        converted = {**example}
        messages = cls._convert_to_chat(example, **kwargs)
        converted["text"] = tokenizer.apply_chat_template(messages, tokenize=False)
        return converted

    @classmethod
    @abstractmethod
    def _convert_to_chat(cls, example, /, **kwargs) -> list[dict[str, str]]:
        """Build the chat messages for ``example``; subclasses override this."""
        raise NotImplementedError

    def prepare_graph_representation(
        self,
        /,
        *,
        sentence_transformer: SentenceTransformer,
        batched: bool = False,
        batch_size: int | None = None,
        desc: str | None = None,
    ):
        """Apply the graph-representation step through ``self.map``.

        Args:
            sentence_transformer: Model forwarded to the graph helper.
            batched: Selects the batched worker when truthy, the per-example
                worker otherwise; also forwarded to ``self.map``.
            batch_size: Forwarded to ``self.map`` unchanged.
            desc: Description forwarded to ``self.map``; a default text is
                used when this is ``None`` or empty.

        Returns:
            Whatever ``self.map`` returns.
        """
        if batched:
            worker = self.__prepare_graph_representation_batched
        else:
            worker = self._prepare_graph_representation
        return self.map(
            partial(worker, sentence_transformer=sentence_transformer),
            batched=batched,
            batch_size=batch_size,
            desc=desc or "Preparing graph representation",
        )

    @classmethod
    def _prepare_graph_representation(
        cls,
        example,
        /,
        *,
        sentence_transformer: SentenceTransformer,
    ):
        """Call the module-level graph helper with fields read from ``example``.

        ``example`` must provide ``"instruction"`` and ``"observation"``;
        ``"labels"`` is optional and ``None`` is passed when it is absent.
        """
        instruction = example["instruction"]
        observation = example["observation"]
        if "labels" in example:
            labels = example["labels"]
        else:
            labels = None
        # Inside a method body this name resolves to the imported module-level
        # function, not to the method of the same name on this class.
        return prepare_graph_representation(
            instruction,
            observation,
            labels,
            sentence_transformer=sentence_transformer,
        )

    @classmethod
    def __prepare_graph_representation_batched(
        cls,
        examples,
        /,
        *,
        sentence_transformer: SentenceTransformer,
    ):
        """Run the per-example worker once on a batch and repeat its result.

        The whole ``examples`` mapping is passed to
        ``_prepare_graph_representation`` in a single call. The row count is
        the length of the first column, and the result is repeated once per
        row: per key when it is a dict, as a flat list otherwise.
        """
        result = cls._prepare_graph_representation(
            examples,
            sentence_transformer=sentence_transformer,
        )
        first_column = next(iter(examples.values()))
        row_count = len(first_column)
        # NOTE(review): ``instanceof`` is a project helper, presumably
        # equivalent to ``isinstance`` here - confirm.
        if instanceof(result, dict):
            return {key: [value] * row_count for key, value in result.items()}
        return [result] * row_count
