import os
import pickle
import torch

from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase
from typing import (
    Callable,
    Dict,
    List,
    Optional,
    Union,
)


class LaMini(Dataset):
    """ The LaMini dataset class for [MBZUAI/LaMini-instruction](https://huggingface.co/datasets/MBZUAI/LaMini-instruction).

    On construction the requested split is loaded, samples whose pre-computed
    token count reaches ``config['model']['max_length']`` are dropped, and at
    most ``dataset_config['subset_data_number'][source]`` samples are kept for
    each ``instruction_source``.
    """

    def __init__(
        self,
        config: Dict[str, Union[Dict, int, str]],
        dataset_config: Dict[str, str],
        mode: Optional[str] = None,
        local_rank: int = 0,
    ) -> None:
        """ Initialize the LaMini dataset.

        Args:
            config (Dict[str, Union[Dict, int, str]]): The configuration. Reads
                ``config['model']['base_name']``, ``config['model']['max_length']``
                and ``config[mode]['dataset']['name']``.
            dataset_config (Dict[str, str]): The dataset configuration. Reads
                ``path`` and ``split`` (passed to ``load_dataset``) and
                ``subset_data_number``, a mapping from ``instruction_source`` to
                the maximum number of samples to keep for that source.
            mode (Optional[str], optional): The mode; used as a key into
                ``config`` and in the token-count cache file name. Defaults to None.
            local_rank (int, optional): The local rank; the progress bar is
                disabled on every rank other than 0. Defaults to 0.

        Raises:
            Exception: Any error raised while reading the token-count cache
                (e.g. ``FileNotFoundError``) is re-raised.
            KeyError: If a sample's ``instruction_source`` is not a key of
                ``dataset_config['subset_data_number']``.
        """

        # Try to initialize the multi-process logger; fall back to `print` if unavailable.
        logger = None
        source = None
        try:
            from ...logging import get_logger
            from ...utils import get_path

            logger = get_logger()
            source = f'{get_path(source_file=__file__)}.{LaMini.__init__.__name__}'
        except Exception:
            logger = None

        # Load the dataset.
        self.dataset = load_dataset(
            path=dataset_config['path'],
            split=dataset_config['split'],
        )
        original_data_number = len(self.dataset)

        ## Prune the dataset.
        max_length = config['model']['max_length']
        # A set gives O(1) membership checks in the filtering loop below.
        # Initialized before the `try` so the "use full dataset" escape hatch still works.
        data_indices_to_remove = set()
        try:
            # Resolved relative to the current working directory.
            root_path = 'ds_src/initialize/datasets'
            model_name = config['model']['base_name'].split('/')[-1]
            dataset_name = config[mode]['dataset']['name']

            with open(
                    file=os.path.join(
                        root_path,
                        f'{model_name}_{dataset_name}_{mode}_token_count.pkl',
                    ),
                    mode='rb',
            ) as f:
                # NOTE: unpickling can execute arbitrary code; only load trusted,
                # locally generated cache files here.
                data_tokens = pickle.load(file=f)

            # Per-sample token counts are assumed to be aligned with the dataset
            # row order — TODO confirm against the script that writes the cache.
            data_indices_to_remove = {
                data_idx
                for data_idx, data_token in enumerate(data_tokens)
                if data_token >= max_length
            }

            message_0 = f'There are {len(data_tokens)} data in the dataset.'
            message_1 = f'There are {len(data_indices_to_remove)} data that exceed the maximum sequence length ({max_length}).'
            for message in (message_0, message_1):
                if logger is None:
                    print(message)
                    continue
                try:
                    logger.log(
                        message=message,
                        source=source,
                    )
                except Exception:
                    print(message)
        except Exception:
            # Only for debugging.
            raise

            # If you want to use full dataset without pruning, replace the `raise` above with:
            # pass

        data_indices_to_keep = []
        subsets_data_number_limit = dataset_config['subset_data_number']
        subsets_data_number = {
            data_source: 0 for data_source in subsets_data_number_limit
        }

        # Iterate only the `instruction_source` column instead of decoding every full row.
        for index, data_source in enumerate(tqdm(
                iterable=self.dataset['instruction_source'],
                desc='[Filtering & Pruning Dataset]',
                disable=local_rank > 0,
                dynamic_ncols=True,
        )):
            if index in data_indices_to_remove:
                continue

            if subsets_data_number[data_source] < subsets_data_number_limit[data_source]:
                data_indices_to_keep.append(index)
                subsets_data_number[data_source] += 1

        self.dataset = self.dataset.select(indices=data_indices_to_keep)
        pruned_data_number = len(self.dataset)
        ## -----

        self.message = f'LaMini dataset has been initialized with {pruned_data_number} samples from {original_data_number} samples.'

    def __len__(self) -> int:
        """ Get the length of the dataset.

        Returns:
            int: The number of samples kept after pruning.
        """

        return len(self.dataset)

    def __getitem__(
        self,
        index: int,
    ) -> Dict[str, str]:
        """ Get a data.

        Args:
            index (int): The index of the data.

        Returns:
            Dict[str, str]: ``input`` (the prompt only), ``label`` (the prompt
            followed by the response) and ``source`` (the instruction source).
        """

        data = self.dataset[index]

        return {
            'input': self.__get_prompt(data=data),
            'label': self.__get_prompt(
                data=data,
                label=data['response'],
            ),
            'source': data['instruction_source'],
        }

    def __get_prompt(
        self,
        data: Dict[str, str],
        label: Optional[str] = None,
    ) -> str:
        """ Get a prompt.
        Reference: [afaji/lm-evaluation-harness/lm_eval/evaluator.py](https://github.com/afaji/lm-evaluation-harness/blob/master/lm_eval/evaluator.py#L222-L227)

        Args:
            data (Dict[str, str]): The data; ``data['instruction']`` is inserted into the template.
            label (Optional[str], optional): Text appended after the template (the response). Defaults to None.

        Returns:
            str: The prompt.
        """

        prompt = f'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{data["instruction"]}\n\n### Response:\n'

        if label:
            prompt += label

        return prompt

    @staticmethod
    def get_collate_fn(
        config: Dict[str, Union[Dict, int, str]],
        tokenizer: PreTrainedTokenizerBase,
    ) -> Callable:
        """ Get the collate function.

        Args:
            config (Dict[str, Union[Dict, int, str]]): The configuration. Reads
                ``config['model']['max_length']`` and
                ``config['model']['tokenizer']['padding' | 'truncation']``.
            tokenizer (PreTrainedTokenizerBase): The tokenizer.

        Returns:
            Callable: The collate function.
        """

        def collate_fn(batch: List[Dict[str, str]]) -> Dict[str, torch.Tensor]:
            """ Collate the batch.

            Args:
                batch (List[Dict[str, str]]): The batch, as produced by ``LaMini.__getitem__``.

            Returns:
                Dict[str, torch.Tensor]: ``input_ids`` and ``attention_mask`` of
                the full (prompt + response) text, and ``labels`` equal to
                ``input_ids`` with prompt and padding positions set to -100.
            """

            max_length = config['model']['max_length']

            # Tokenize the prompts alone, unpadded, only to learn how many leading
            # tokens of each full sequence belong to the prompt.
            batch_input_tokens = tokenizer.batch_encode_plus(
                batch_text_or_text_pairs=[data['input'] for data in batch],
                padding=False,
            )

            # Tokenize the full text.
            batch_full_tokens = tokenizer.batch_encode_plus(
                batch_text_or_text_pairs=[data['label'] for data in batch],
                padding=config['model']['tokenizer']['padding'],
                truncation=config['model']['tokenizer']['truncation'],
                max_length=max_length,
                return_tensors='pt',
            )
            input_ids = batch_full_tokens['input_ids']
            attention_mask = batch_full_tokens['attention_mask']

            # Create the label tokens; -100 is the default `ignore_index` of
            # `torch.nn.CrossEntropyLoss`, so only response tokens are supervised.
            batch_label_tokens = input_ids.clone()

            # Mask padding by position rather than by token id, so real tokens that
            # share the pad id (e.g. EOS when pad_token == eos_token) are kept.
            batch_label_tokens[attention_mask == 0] = -100

            # Mask the prompt tokens. With left padding the prompt starts after the pads.
            left_padded = getattr(tokenizer, 'padding_side', 'right') == 'left'
            for row, prompt_ids in enumerate(batch_input_tokens['input_ids']):
                start = int((attention_mask[row] == 0).sum()) if left_padded else 0
                batch_label_tokens[row, start:start + len(prompt_ids)] = -100

            return {
                'input_ids': input_ids,
                'attention_mask': attention_mask,
                'labels': batch_label_tokens,
            }

        return collate_fn
