"""
chat dataset module
"""
import os
from typing import Callable, Optional, Union

from datasets import Dataset
from transformers import PreTrainedTokenizer

from axolotl.core.chat.messages import ChatFormattedChats


class TokenizedChatDataset(Dataset):
    """
    Tokenized chat dataset

    Wraps a source ``Dataset`` by mapping every example through an optional
    message transform, wrapping the result in ``ChatFormattedChats`` and
    calling its ``tokenized`` method with ``model_transform``. The original
    columns are dropped, so only the fields produced by tokenization remain.
    """

    def __init__(
        self,
        data: Dataset,
        model_transform: Union[PreTrainedTokenizer, Callable],
        *args,
        message_transform: Optional[Callable] = None,
        formatter=None,
        process_count: Optional[int] = None,
        keep_in_memory: Optional[bool] = False,
        **kwargs,
    ):
        """
        Tokenize ``data`` and initialise the dataset from the result.

        Args:
            data: Source dataset; each row is passed to the mapping function.
            model_transform: Tokenizer or callable handed to
                ``ChatFormattedChats.tokenized``.
            *args: Extra positional arguments forwarded to
                ``Dataset.__init__``.
            message_transform: Optional callable applied to each row before
                it is unpacked into ``ChatFormattedChats``.
            formatter: Optional formatter forwarded to ``ChatFormattedChats``;
                when ``None`` the argument is omitted entirely.
            process_count: Desired number of worker processes. When falsy
                (``None`` or ``0``) the machine's CPU count is used. The
                effective value is capped at 64.
            keep_in_memory: Forwarded to ``Dataset.map``.
            **kwargs: Extra keyword arguments forwarded to
                ``Dataset.__init__``.
        """

        def map_fn(ex):
            if message_transform is not None:
                ex = message_transform(ex)
            # Only pass ``formatter`` when one was supplied, so that
            # ChatFormattedChats can fall back to its own default.
            if formatter is not None:
                chats = ChatFormattedChats(formatter=formatter, **ex)
            else:
                chats = ChatFormattedChats(**ex)
            return chats.tokenized(model_transform)

        # os.cpu_count() may return None when the count cannot be determined;
        # fall back to a single process instead of failing in min().
        process_or_cpu_count: int = process_count or os.cpu_count() or 1
        num_proc = min(64, process_or_cpu_count)
        # Dataset.map documents remove_columns as a str or list of str, so
        # pass a concrete list rather than a dict keys view.
        features = list(data.features.keys())
        tokenized_data = data.map(
            map_fn,
            num_proc=num_proc,
            keep_in_memory=keep_in_memory,
            remove_columns=features,
            desc="Tokenizing Chats",
        )
        super().__init__(tokenized_data.data, *args, **kwargs)
