import copy
import random
from functools import partial
from typing import Callable, Dict, List, Optional

import torch
from datasets import load_dataset, DatasetDict, IterableDatasetDict
from torch import LongTensor
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizer


def make_data_block(
    samples: Dict[str, List[str]],
    prompt_col_name: str,
    label_col_name: str,
    tokenizer: PreTrainedTokenizer,
    preprocess_fn: Optional[Callable] = None,
    sample_max_len: int = 1024,
    block_max_len: int = 2048,
    add_eos_token: bool = False,
    truncate_prompt: bool = True,
    merge_prompt_label: bool = False,
) -> Dict[str, List[LongTensor]]:
    """A simple implementation of text generation oriented smart batching to maximize VRAM usage when evaluation

    :param samples: Dict[str, List[str]], samples that used to make data blocks
    :param prompt_col_name: str, name of the key in samples whose value stores prompt
    :param label_col_name: str, name of the key in samples whose value stores label
    :param tokenizer: transformers.PretrainedTokenizer, tokenizer that used to tokenize samples
    :param preprocess_fn: Optional[Callable], optional function that used to preprocess samples such as
        refactor the data structure of samples, note the output of this function must be a dict whose keys
        at least contains `prompt_col_name` and `label_col_name`
    :param sample_max_len: int, defaults to 1024, max tokens number of each sample (before padding)
    :param block_max_len: int, defaults to 2048, max tokens number of each data block (after padding)
    :param add_eos_token: bool, defaults to False, whether add eos_token or not to the label
    :param truncate_prompt: bool, defaults to True, whether to truncate prompt if the sample's total tokens
        number exceeds `sample_max_len`, if not, will truncate label and drop this sample when all tokens
        in label are truncated
    :param merge_prompt_label: bool, defaults to False, will merge label into prompt if set to True, usually
        this only required when doing language modeling task
    :return: Dict[str, List[torch.LongTensor]], a dict whose keys are `input_ids`, `attention_mask` and
        `label` and values are a list of torch.LongTensor
    """
    if preprocess_fn:
        samples = preprocess_fn(samples)

    prompts = samples[prompt_col_name]
    labels = samples[label_col_name]

    # tokenize samples
    tokenized_prompts = tokenizer(prompts, truncation=False)["input_ids"]
    tokenized_labels = tokenizer(labels, truncation=False)["input_ids"]

    # filter tokenized samples by length
    dropped_indices = []
    for idx, (tokenized_prompt, tokenized_label) in enumerate(
        zip(tokenized_prompts, tokenized_labels)
    ):
        if add_eos_token:
            tokenized_label += [tokenizer.eos_token_id]
        len_prompt = len(tokenized_prompt)
        len_label = len(tokenized_label)
        exceed_len = len_prompt + len_label - sample_max_len
        if exceed_len > 0:
            if truncate_prompt:
                tokenized_prompt = tokenized_prompt[exceed_len:]
            else:
                tokenized_label = tokenized_label[:-exceed_len]
        tokenized_prompts[idx] = tokenized_prompt
        tokenized_labels[idx] = tokenized_label
        if not tokenized_label:
            dropped_indices.append(idx)

    # make data blocks of samples
    tokenized_samples = sorted(
        [
            (p, l)
            for idx, (p, l) in enumerate(zip(tokenized_prompts, tokenized_labels))
            if idx not in dropped_indices
        ],
        key=lambda x: (len(x[0]) + len(x[1])) if merge_prompt_label else len(x[0]),
    )
    sample_blocks = []
    sample_block = []
    blk_max_len = 0
    blk_total_len = 0
    for tokenized_sample in tokenized_samples:
        prompt_ids, label_ids = tokenized_sample
        ori_sample_len = len(prompt_ids)
        if merge_prompt_label:
            ori_sample_len += len(label_ids)
        if ori_sample_len <= blk_max_len:
            additional_len = blk_max_len
            sample_len = blk_max_len
        else:
            additional_len = (
                len(sample_block) * (ori_sample_len - blk_max_len) + ori_sample_len
            )
            sample_len = ori_sample_len

        if blk_total_len + additional_len > block_max_len:
            sample_blocks.append((copy.copy(sample_block), blk_max_len))
            sample_block = []
            blk_max_len = 0
            blk_total_len = 0
            sample_len = ori_sample_len
            additional_len = ori_sample_len

        sample_block.append(tokenized_sample)
        blk_max_len = max(blk_max_len, sample_len)
        blk_total_len += additional_len

    if sample_block:
        sample_blocks.append((copy.copy(sample_block), blk_max_len))
    del sample_block
    del blk_max_len
    del blk_total_len

    new_samples = {"input_ids": [], "attention_mask": [], "labels": []}
    # padding each data block internally
    for block, blk_max_len in sample_blocks:
        input_ids = []
        attention_mask = []
        label_ids = []
        label_max_len = max([len(sample[1]) for sample in block])

        for sample in block:
            tokenized_prompt, tokenized_label = sample
            sample_len = len(tokenized_prompt)
            if merge_prompt_label:
                sample_len += len(tokenized_label)
            pad_num = blk_max_len - sample_len
            if merge_prompt_label:
                input_ids.append(
                    [tokenizer.pad_token_id] * pad_num
                    + tokenized_prompt
                    + tokenized_label
                )
                label_ids.append(
                    [-100] * (pad_num + len(tokenized_prompt)) + tokenized_label
                )
            else:
                input_ids.append([tokenizer.pad_token_id] * pad_num + tokenized_prompt)
                label_ids.append(
                    [-100] * (label_max_len - len(tokenized_label)) + tokenized_label
                )
            attention_mask.append([0] * pad_num + [1] * sample_len)

        new_samples["input_ids"].append(input_ids)
        new_samples["attention_mask"].append(attention_mask)
        new_samples["labels"].append(label_ids)

    return new_samples


def collate_data(
    blocks: List[Dict[str, List[List[int]]]], pad_token_id: int
) -> Dict[str, LongTensor]:
    def pad_block(block, pads):
        return torch.cat((pads.to(block.device), block), dim=-1)

    input_ids_blocks = [LongTensor(block["input_ids"]) for block in blocks]
    attention_mask_blocks = [LongTensor(block["attention_mask"]) for block in blocks]
    label_blocks = [LongTensor(block["labels"]) for block in blocks]

    bsz = len(blocks)
    inp_max_len = max([block.size(-1) for block in input_ids_blocks])
    label_max_len = max([block.size(-1) for block in label_blocks])

    for i in range(bsz):
        block_bsz, block_inp_len = input_ids_blocks[i].shape
        block_label_len = label_blocks[i].shape[-1]
        pad_num = inp_max_len - block_inp_len
        if pad_num > 0:
            input_ids_blocks[i] = pad_block(
                input_ids_blocks[i], torch.ones((block_bsz, pad_num)) * pad_token_id
            )
            attention_mask_blocks[i] = pad_block(
                attention_mask_blocks[i], torch.zeros((block_bsz, pad_num))
            )
        label_pad_num = label_max_len - block_label_len
        if label_pad_num > 0:
            label_blocks[i] = pad_block(
                label_blocks[i], torch.ones((block_bsz, label_pad_num)) * -100
            )

    return {
        "input_ids": torch.cat(input_ids_blocks, dim=0).long(),
        "attention_mask": torch.cat(attention_mask_blocks, dim=0).long(),
        "labels": torch.cat(label_blocks, dim=0).long(),
    }


def get_dataloader(
    data_path_or_name: str,
    prompt_col_name: str,
    label_col_name: str,
    tokenizer: PreTrainedTokenizer,
    load_fn: Optional[Callable] = None,
    preprocess_fn: Optional[Callable] = None,
    num_samples: int = 128,
    sample_max_len: int = 1024,
    block_max_len: int = 2048,
    add_eos_token: bool = False,
    truncate_prompt: bool = True,
    merge_prompt_label: bool = False,
    load_fn_kwargs: Optional[dict] = None,
    preprocess_fn_kwargs: Optional[dict] = None,
    **kwargs,
) -> DataLoader:
    """load dataset and build dataloader

    :param data_path_or_name: str, dataset name in hf-hub or local file path
    :param prompt_col_name: str, see `make_data_block`
    :param label_col_name: str, see `make_data_block`
    :param tokenizer: str, see `make_data_block`
    :param load_fn: Optional[Callable], defaults to None, function used to load dataset, if not specified,
        use `datasets.load_dataset`
    :param preprocess_fn: Optional[Callable], see `make_data_block`
    :param num_samples: int, defaults to 128, total samples used to evaluation
    :param sample_max_len: int, see `make_data_block`
    :param block_max_len: int, see `make_data_block`
    :param add_eos_token: bool, see `make_data_block`
    :param truncate_prompt: bool, see `make_data_block`
    :param merge_prompt_label: bool, see `make_data_block`
    :param load_fn_kwargs: Optional[dict], defaults to None, keyword arguments used
        for `load_fn` or `datasets.load_dataset`
    :param preprocess_fn_kwargs: Optional[dict], defaults to None, keyword arguments used
        for `preprocess_fn`
    :param kwargs: additional keyword arguments will be passed to torch's `DataLoader` initialization,
        note values of `batch_size`, `shuffle` and `collate_fn` will always be overridden to fixed value
    :return: torch.utils.data.DataLoader
    """

    if not load_fn_kwargs:
        load_fn_kwargs = dict()
    if not preprocess_fn_kwargs:
        preprocess_fn_kwargs = dict()

    if load_fn:
        ds = load_fn(data_path_or_name, **load_fn_kwargs)
    else:
        ds = load_dataset(data_path_or_name, **load_fn_kwargs)
    if isinstance(ds, (DatasetDict, IterableDatasetDict)):
        if "evaluation" in ds:
            ds = ds["evaluation"]
        elif "test" in ds:
            ds = ds["test"]
        else:
            ds = ds["train"]

    ds = ds.select(
        indices=random.sample(range(len(ds)), min(len(ds), num_samples)),
        keep_in_memory=True,
    )
    ds = ds.map(
        make_data_block,
        batched=True,
        batch_size=len(ds),
        num_proc=1,
        remove_columns=ds.column_names,
        keep_in_memory=True,
        load_from_cache_file=False,
        fn_kwargs={
            "prompt_col_name": prompt_col_name,
            "label_col_name": label_col_name,
            "tokenizer": tokenizer,
            "preprocess_fn": partial(preprocess_fn, **preprocess_fn_kwargs),
            "sample_max_len": sample_max_len,
            "block_max_len": block_max_len,
            "add_eos_token": add_eos_token,
            "truncate_prompt": truncate_prompt,
            "merge_prompt_label": merge_prompt_label,
        },
    )

    # override some arguments' values in kwargs despite user specified
    kwargs["batch_size"] = 1
    kwargs["shuffle"] = False
    kwargs["collate_fn"] = partial(collate_data, pad_token_id=tokenizer.pad_token_id)
    dl = DataLoader(ds, **kwargs)

    return dl


__all__ = ["make_data_block", "collate_data", "get_dataloader"]
