import torch

from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizerBase
from typing import (
    Callable,
    Dict,
    List,
    Optional,
    Union,
)


class C4(Dataset):
    """ The C4 dataset class for [allenai/c4](https://huggingface.co/datasets/allenai/c4).

    Every item exposes the raw `text` field of a record twice, once as
    `input` and once as `label`. Tokenization is deferred to the collate
    function returned by `get_collate_fn`.
    """

    def __init__(
        self,
        config: Dict[str, Union[Dict, int, str]],
        dataset_config: Dict[str, Union[int, str]],
        mode: Optional[str] = None,
        local_rank: int = 0,
    ) -> None:
        """ Initialize the C4 dataset.

        Args:
            config (Dict[str, Union[Dict, int, str]]): The configuration.
                Not read by this constructor.
            dataset_config (Dict[str, Union[int, str]]): The dataset
                configuration. The keys `path`, `data_dir`, `data_files` and
                `split` are forwarded to `load_dataset`; `data_number` is the
                number of leading samples to keep, or -1 to keep all of them.
            mode (Optional[str], optional): The mode. Not read by this
                constructor. Defaults to None.
            local_rank (int, optional): The local rank. Not read by this
                constructor. Defaults to 0.
        """

        super().__init__()

        # Load the dataset.
        self.dataset = load_dataset(
            path=dataset_config['path'],
            data_dir=dataset_config['data_dir'],
            data_files=dataset_config['data_files'],
            split=dataset_config['split'],
        )
        original_data_number = len(self.dataset)

        # Prune the dataset to the specified number of samples.
        self.dataset = self._limit_samples(
            dataset=self.dataset,
            data_number=dataset_config['data_number'],
        )
        pruned_data_number = len(self.dataset)

        # Kept as an attribute so that the caller can decide how to report it.
        self.message = f'C4 dataset has been initialized with {pruned_data_number} samples from {original_data_number} samples.'

    @staticmethod
    def _limit_samples(
        dataset,
        data_number: int,
    ):
        """ Keep at most the first `data_number` samples of a dataset.

        Args:
            dataset: The dataset to prune. It must support `len()` and
                `select(indices=...)`.
            data_number (int): The number of leading samples to keep, or -1
                to keep the dataset untouched.

        Returns:
            The dataset itself when `data_number` is -1, otherwise the
            result of `dataset.select` on the leading indices.
        """

        if data_number == -1:
            return dataset

        # Clamp the request to the available size: asking for more samples
        # than were loaded keeps everything instead of raising an IndexError
        # on the out-of-range indices.
        sample_count = min(data_number, len(dataset))

        return dataset.select(indices=range(sample_count))

    def __len__(self) -> int:
        """ Get the length of the dataset.

        Returns:
            int: The number of samples left after pruning.
        """

        return len(self.dataset)

    def __getitem__(
        self,
        index: int,
    ) -> Dict[str, str]:
        """ Get a data.

        Args:
            index (int): The index of the data.

        Returns:
            Dict[str, str]: The data, with the `text` field of the record
                stored under both `input` and `label`.
        """

        text = self.dataset[index]['text']

        return {
            'input': text,
            'label': text,
        }

    @staticmethod
    def get_collate_fn(
        config: Dict[str, Union[Dict, int, str]],
        tokenizer: PreTrainedTokenizerBase,
    ) -> Callable:
        """ Get the collate function.

        Args:
            config (Dict[str, Union[Dict, int, str]]): The configuration.
                `config['model']['max_length']` and the `padding` and
                `truncation` entries of `config['model']['tokenizer']` are
                read on every call of the returned function.
            tokenizer (PreTrainedTokenizerBase): The tokenizer.

        Returns:
            Callable: The collate function.
        """

        def collate_fn(batch: List[Dict[str, str]]) -> Dict[str, torch.Tensor]:
            """ Collate the batch.

            Args:
                batch (List[Dict[str, str]]): The batch. Only the `input`
                    entry of each item is tokenized.

            Returns:
                Dict[str, torch.Tensor]: The collated batch with the keys
                    `input_ids`, `attention_mask` and `labels`.
            """

            # Try to initialize the multi-process logger. It is only consumed
            # by the optional length check below, so a failure here must not
            # break collation. `Exception` (rather than a bare `except`) keeps
            # KeyboardInterrupt and SystemExit propagating.
            try:
                from ...logging import get_logger
                from ...utils import get_path

                logger = get_logger()
                source = f'{get_path(source_file=__file__)}.{collate_fn.__name__}'
            except Exception:
                logger = None
                source = None

            max_length = config['model']['max_length']

            # Uncomment if you want to check the length of tokens.
            # over_type = None

            # Tokenize the input text.
            batch_input = [data['input'] for data in batch]
            batch_input_tokens = tokenizer.batch_encode_plus(
                batch_text_or_text_pairs=batch_input,
                padding=config['model']['tokenizer']['padding'],
                # Uncomment if you want to check the length of the full tokens.
                # truncation=False,
                truncation=config['model']['tokenizer']['truncation'],
                max_length=max_length,
                return_tensors='pt',
            )

            ## Check the length of the input tokens.
            # Notice that the `batch_size` should be 1 and the `padding` should not be `max_length` when checking the length of tokens.
            # for index, input_tokens in enumerate(
            #         batch_input_tokens['input_ids']):
            #     if len(input_tokens) > max_length:
            #         over_type = 'input'

            #         message = f'The input tokens of {index} is longer than the maximum length with {len(input_tokens)} tokens.'

            #         try:
            #             logger.log(
            #                 message=message,
            #                 level='warning',
            #                 name=source,
            #             )
            #         except Exception:
            #             print(message)
            ## -----

            # Create the label tokens. They are an independent copy of the
            # input ids, so editing one tensor does not affect the other.
            # NOTE(review): padding positions are copied verbatim and are not
            # masked here -- confirm that the training code handles them.
            batch_label_tokens = batch_input_tokens['input_ids'].clone()

            return {
                'input_ids': batch_input_tokens['input_ids'],
                'attention_mask': batch_input_tokens['attention_mask'],
                'labels': batch_label_tokens,
                # Uncomment if you want to check the length of tokens.
                # 'over_type': over_type,
            }

        return collate_fn
