import os
import random
import torch

from datasets import load_dataset
from torch.utils.data import (
    DataLoader,
    Dataset,
)
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase
from typing import (
    Dict,
    List,
    Tuple,
)


def get_calibration_dataset(
    tokenizer: PreTrainedTokenizerBase,
    data_name: str,
    data_number: int,
    sequence_length: int,
    batch_size: int,
    cache_dir: str,
    local_rank: int,
    only_return_name: bool = False,
) -> Tuple[List[Dict[str, torch.Tensor]], str]:
    """ Get the calibration dataset.

    Args:
        tokenizer (PreTrainedTokenizerBase): The tokenizer.
        data_name (str): The name of calibration dataset.
        data_number (int): The number of samples to be generated.
        sequence_length (int): The sequence length. It is recommended to be half of the model maximum sequence length.
        batch_size (int): The batch size.
        cache_dir (str): The cache directory.
        local_rank (int): The local rank of the process.
        only_return_name (bool): If True, only return the name of the calibration dataset.

    Raises:
        NotImplementedError: If the dataset is NOT implemented.

    Returns:
        Tuple[List[Dict[str, torch.Tensor]], str]: The calibration dataset and its name.
    """

    calibration_dataset_name = \
        f'{data_name}_{data_number}_{sequence_length}_{batch_size}'
    cache_path = os.path.join(
        cache_dir,
        f'{calibration_dataset_name}.pt',
    )

    # Try to load the calibration dataset from the cache.
    if os.path.exists(path=cache_path):
        calibration_dataset = None

        if not only_return_name:
            calibration_dataset = torch.load(
                f=cache_path,
                weights_only=True,
            )

        return (
            calibration_dataset,
            calibration_dataset_name,
        )

    # Load the source dataset.
    match data_name:
        case 'alpaca':
            _dataset = load_dataset(
                path='yahma/alpaca-cleaned',
                split='train',
            )

            template_dict = {
                'input':
                'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n',
                'no_input':
                'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n',
            }

            text_column = []
            for data_idx in tqdm(
                    iterable=range(len(_dataset)),
                    desc='[Concatenating the texts]',
                    dynamic_ncols=True,
            ):
                if _dataset[data_idx]['input']:
                    text = template_dict['input'].format(
                        instruction=_dataset[data_idx]['instruction'],
                        input=_dataset[data_idx]['input'],
                    )
                else:
                    text = template_dict['no_input'].format(
                        instruction=_dataset[data_idx]['instruction'])

                text += _dataset[data_idx]['output']
                text_column.append(text)

            _dataset = _dataset.add_column(
                name='text',
                column=text_column,
            )
        case 'c4':
            _dataset = load_dataset(
                path='allenai/c4',
                data_files={'train': 'en/c4-train.00000-of-01024.json.gz'},
                split='train',
            )
        case 'openbookqa':
            _dataset = load_dataset(
                path='allenai/openbookqa',
                data_dir='main',
                split='train',
            )

            text_column = []
            for data_idx in tqdm(
                    iterable=range(len(_dataset)),
                    desc='[Concatenating the texts]',
                    dynamic_ncols=True,
            ):
                text = f'Question: {_dataset[data_idx]["question_stem"]}\n\nAnswer Choices:\n'

                _choices = _dataset[data_idx]['choices']['text']
                _labels = _dataset[data_idx]['choices']['label']
                for _label, _choice in zip(
                        _labels,
                        _choices,
                ):
                    text += f'({_label}) {_choice}\n'

                text += '\nAnswer: '

                for _label, _choice in zip(
                        _labels,
                        _choices,
                ):
                    if _label == _dataset[data_idx]['answerKey']:
                        text += f'({_label}) {_choice}'

                text_column.append(text)

            _dataset = _dataset.add_column(
                name='text',
                column=text_column,
            )
        case 'piqa':
            _dataset = load_dataset(
                path='ybisk/piqa',
                split='train',
                trust_remote_code=True,
            )

            text_column = []
            for data_idx in tqdm(
                    iterable=range(len(_dataset)),
                    desc='[Concatenating the texts]',
                    dynamic_ncols=True,
            ):
                goal = _dataset[data_idx]['goal']
                solution = _dataset[data_idx]['sol1'] \
                    if _dataset[data_idx]['label'] == 0 \
                        else _dataset[data_idx]['sol2']

                text_column.append(f'Question: {goal}\nAnswer: {solution}')

            _dataset = _dataset.add_column(
                name='text',
                column=text_column,
            )
        case 'wikitext2':
            _dataset = load_dataset(
                path='Salesforce/wikitext',
                data_dir='wikitext-2-raw-v1',
                split='train',
            )
        case _:
            raise NotImplementedError(
                f'The dataset {data_name} is NOT implemented yet.')

    texts = '\n\n'.join(_dataset['text'])

    ## Check the number of samples and characters.
    # print(f'The number of samples: {len(_dataset)}')
    # print(f'The number of characters: {len(texts)}')
    # input('Press Enter to continue.')
    ## -----

    # Build the calibration dataset.
    calibration_dataset = []
    input_ids = None
    for data_idx in range(data_number):
        i = random.randint(
            a=0,
            b=(len(texts) - sequence_length - 1),
        )
        j = i + sequence_length * 10

        encodings = tokenizer(
            text=texts[i:j],
            return_tensors='pt',
        )

        if encodings.input_ids.shape[1] < sequence_length:
            data_idx -= 1
            continue

        if data_idx % batch_size == 0:
            if data_idx != 0:
                attention_mask = torch.ones_like(input=input_ids)
                calibration_dataset.append({
                    'input_ids': input_ids,
                    'attention_mask': attention_mask,
                })

            input_ids = encodings.input_ids[:, :sequence_length]
        else:
            input_ids = torch.cat(
                tensors=(input_ids, encodings.input_ids[:, :sequence_length]),
                dim=0,
            )

    # Save the calibration dataset to the cache.
    if local_rank == 0:
        torch.save(
            obj=calibration_dataset,
            f=cache_path,
        )

    return (
        calibration_dataset,
        calibration_dataset_name,
    )


def get_test_dataset(
    tokenizer: PreTrainedTokenizerBase,
    data_name: str,
    sequence_length: int,
    batch_size: int,
) -> DataLoader:
    """ Get the test dataloader.

    Args:
        tokenizer (PreTrainedTokenizerBase): The tokenizer.
        data_name (str): The name of the test dataset.
        sequence_length (int): The sequence length.
        batch_size (int): The batch size.

    Raises:
        NotImplementedError: If the dataset is NOT implemented.

    Returns:
        DataLoader: The test dataloader.
    """

    class BasicDataset(Dataset):

        def __init__(
            self,
            tensors: torch.Tensor,
        ):
            self.tensors = tensors

        def __getitem__(
            self,
            index: int,
        ) -> torch.Tensor:
            return self.tensors[index]

        def __len__(self) -> int:
            return len(self.tensors)

    def process_data(
        tokenizer: PreTrainedTokenizerBase,
        dataset: Dataset,
        sequence_length: int,
        data_key: str,
    ) -> BasicDataset:
        """ Process the dataset.

        Args:
            tokenizer (PreTrainedTokenizerBase): The tokenizer.
            dataset (Dataset): The dataset.
            sequence_length (int): The sequence length.
            data_key (str): The key of the dataset.

        Returns:
            BasicDataset: The processed dataset.
        """

        texts = '\n\n'.join(dataset[data_key])
        input_ids = tokenizer(
            text=texts,
            return_tensors='pt',
        ).input_ids[0]

        batches_input_ids = []
        data_number = input_ids.numel() // sequence_length
        for data_idx in range(data_number):
            batch_input_ids = input_ids[(data_idx *
                                         sequence_length):((data_idx + 1) *
                                                           sequence_length)]
            batches_input_ids.append(batch_input_ids)

        batches_input_ids = torch.stack(tensors=batches_input_ids)

        return BasicDataset(tensors=batches_input_ids)

    match data_name:
        case 'c4':
            dataset = load_dataset(
                path='json',
                data_files='ds_src/SVDLLM/test_datasets/c4-validation.json',
            )['train']
            dataset = process_data(
                dataset=dataset[0:2000],
                tokenizer=tokenizer,
                sequence_length=sequence_length,
                data_key='text',
            )
        case 'wikitext2':
            dataset = load_dataset(
                # path='Salesforce/wikitext',
                path='wikitext',
                data_dir='wikitext-2-raw-v1',
                split='test',
            )
            dataset = process_data(
                dataset=dataset,
                tokenizer=tokenizer,
                sequence_length=sequence_length,
                data_key='text',
            )
        case 'arc_challenge':
            dataset = load_dataset(
                path='json',
                data_files='ds_src/SVDLLM/test_datasets/arc_challenge_test.json',
            )['train']
            dataset = process_data(
                dataset=dataset,
                tokenizer=tokenizer,
                sequence_length=sequence_length,
                data_key='question',
            )
        case _:
            raise NotImplementedError(
                f'The dataset {data_name} is NOT implemented yet.')

    dataloader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
    )

    return dataloader
