# modified from https://github.com/hpcaitech/ColossalAI-Examples/tree/main/language/bert/sequene_parallel

import os
from itertools import chain

import torch
from datasets import load_dataset, load_from_disk
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

from .distributed_utils import DistGroups
from .tasks.base.dataset import broadcast_data, get_sp_method


def get_mlm_batch(data_iterator):
    """Fetch one MLM batch, broadcast it, and return this rank's sequence shard.

    Ranks that hold ``data_iterator`` pull the next batch; ``broadcast_data``
    shares it (as ``int64``) with ranks that pass ``None``. The sequence
    dimension is then split evenly across the ``DistGroups["sp"]`` group and
    each rank keeps its own contiguous slice.

    Args:
        data_iterator: iterator over collated batches with keys ``input_ids``,
            ``token_type_ids``, ``attention_mask`` and ``labels``, or ``None``
            on ranks that only receive the broadcast.

    Returns:
        dict with ``input_ids``, ``token_type_ids``, ``labels`` and
        ``position_ids`` restricted to this rank's slice of the sequence, plus
        ``attention_mask``, which is kept full-length for the ``colai``,
        ``single`` and ``megatron`` methods and sliced for ``qasp*`` methods.

    Raises:
        ValueError: if ``get_sp_method()`` returns an unsupported method.
    """
    keys = ["input_ids", "token_type_ids", "attention_mask", "labels"]
    datatype = torch.int64
    data = next(data_iterator) if data_iterator is not None else None
    data_b = broadcast_data(keys, data, datatype)

    # Split the sequence dimension evenly across the sequence-parallel group.
    local_world_size = DistGroups["sp"].size()
    local_rank = DistGroups["sp"].rank()
    seq_length = data_b["input_ids"].size(1)
    # NOTE(review): integer division drops trailing tokens when seq_length is not
    # a multiple of the sp group size — confirm inputs are padded accordingly.
    sub_seq_length = seq_length // local_world_size
    sub_seq_start = local_rank * sub_seq_length
    sub_seq_end = sub_seq_start + sub_seq_length

    input_ids = data_b["input_ids"][:, sub_seq_start:sub_seq_end].long().contiguous()
    token_type_ids = data_b["token_type_ids"][:, sub_seq_start:sub_seq_end].long().contiguous()

    sp_method = get_sp_method()
    if sp_method in ("colai", "single", "megatron"):
        # These methods are given the full-length attention mask.
        attention_mask = data_b["attention_mask"].long().contiguous()
    elif isinstance(sp_method, str) and sp_method.startswith("qasp"):
        # qasp variants are given only this rank's slice of the mask.
        attention_mask = data_b["attention_mask"][:, sub_seq_start:sub_seq_end].long().contiguous()
    else:
        raise ValueError(f"Unsupported sequence-parallel method: {sp_method!r}")

    labels = data_b["labels"][:, sub_seq_start:sub_seq_end].long().contiguous()

    # Global positions of this rank's tokens, repeated for every sample in the batch.
    position_ids = (
        torch.arange(sub_seq_start, sub_seq_end)
        .expand((data_b["input_ids"].size(0), sub_seq_length))
        .long()
        .contiguous()
    )

    return {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "position_ids": position_ids,
    }


def get_megatron_lm_batch(data_iterator):
    """Fetch one batch, broadcast it, and rename its fields for a Megatron BERT forward.

    Unlike the sequence-parallel getters in this module, nothing is sliced:
    every rank gets the whole sequence.

    Args:
        data_iterator: iterator yielding batches keyed by ``input_ids``,
            ``token_type_ids``, ``attention_mask`` and ``labels``, or ``None``
            on ranks that only receive the broadcast.

    Returns:
        dict with ``bert_model_input``, ``tokentype_ids``, ``attention_mask``
        and ``lm_labels`` (``int64``, contiguous), plus ``position_ids``: a
        1-D ``arange`` over the sequence length (not expanded per sample).
    """
    field_names = ["input_ids", "token_type_ids", "attention_mask", "labels"]
    batch = None if data_iterator is None else next(data_iterator)
    received = broadcast_data(field_names, batch, torch.int64)

    # Dataloader key -> keyword expected on the Megatron side.
    rename = {
        "input_ids": "bert_model_input",
        "token_type_ids": "tokentype_ids",
        "attention_mask": "attention_mask",
        "labels": "lm_labels",
    }
    renamed = {target: received[source].long().contiguous() for source, target in rename.items()}

    total_positions = received["input_ids"].size(1)
    renamed["position_ids"] = torch.arange(total_positions).long().contiguous()
    return renamed


def get_glue_batch(data_iterator):
    """Fetch one GLUE classification batch, broadcast it, and keep this rank's shard.

    Ranks that hold ``data_iterator`` pull the next batch; ``broadcast_data``
    shares it (as ``int64``) with ranks that pass ``None``. Token-level fields
    are split evenly along the sequence dimension across ``DistGroups["sp"]``.

    Args:
        data_iterator: iterator over collated batches with keys ``input_ids``,
            ``token_type_ids``, ``attention_mask`` and ``labels``, or ``None``.

    Returns:
        dict with ``input_ids``, ``token_type_ids`` and ``position_ids``
        restricted to this rank's slice; ``attention_mask`` (full-length for
        ``colai``/``megatron``/``single``, sliced for ``qasp*``); and
        ``labels``, returned unsliced.

    Raises:
        ValueError: if ``get_sp_method()`` returns an unsupported method.
    """
    keys = ["input_ids", "token_type_ids", "attention_mask", "labels"]
    datatype = torch.int64
    data = next(data_iterator) if data_iterator is not None else None
    data_b = broadcast_data(keys, data, datatype)

    # Split the sequence dimension evenly across the sequence-parallel group.
    local_world_size = DistGroups["sp"].size()
    local_rank = DistGroups["sp"].rank()
    seq_length = data_b["input_ids"].size(1)
    # NOTE(review): integer division drops trailing tokens when seq_length is not
    # a multiple of the sp group size — confirm inputs are padded accordingly.
    sub_seq_length = seq_length // local_world_size
    sub_seq_start = local_rank * sub_seq_length
    sub_seq_end = sub_seq_start + sub_seq_length

    input_ids = data_b["input_ids"][:, sub_seq_start:sub_seq_end].long().contiguous()
    token_type_ids = data_b["token_type_ids"][:, sub_seq_start:sub_seq_end].long().contiguous()

    sp_method = get_sp_method()
    if sp_method in ("colai", "megatron", "single"):
        # These methods are given the full-length attention mask.
        attention_mask = data_b["attention_mask"].long().contiguous()
    elif isinstance(sp_method, str) and sp_method.startswith("qasp"):
        # qasp variants are given only this rank's slice of the mask.
        attention_mask = data_b["attention_mask"][:, sub_seq_start:sub_seq_end].long().contiguous()
    else:
        raise ValueError(f"Unsupported sequence-parallel method: {sp_method!r}")

    # Presumably one class index per sample (batch,), not one-hot, so there is
    # no sequence dimension to slice — confirm against the GLUE collator.
    labels = data_b["labels"]

    # Global positions of this rank's tokens, repeated for every sample in the batch.
    position_ids = (
        torch.arange(sub_seq_start, sub_seq_end)
        .expand((data_b["input_ids"].size(0), sub_seq_length))
        .long()
        .contiguous()
    )

    return {
        "input_ids": input_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "position_ids": position_ids,
    }


def get_pathx_batch(data_iterator):
    """Fetch one Path-X batch, broadcast it, and keep this rank's sequence shard.

    Args:
        data_iterator: iterator yielding batches keyed by ``input_ids`` and
            ``labels``, or ``None`` on ranks that only receive the broadcast.

    Returns:
        dict with ``input_ids`` and ``position_ids`` covering this rank's
        contiguous slice of the sequence (split evenly over
        ``DistGroups["sp"]``), and ``labels`` returned unsliced.
    """
    batch = next(data_iterator) if data_iterator is not None else None
    received = broadcast_data(["input_ids", "labels"], batch, torch.int64)

    sp_group = DistGroups["sp"]
    full_ids = received["input_ids"]
    chunk = full_ids.size(1) // sp_group.size()
    start = sp_group.rank() * chunk
    stop = start + chunk

    # Global positions of the local tokens, one row per sample.
    positions = torch.arange(start, stop).expand((full_ids.size(0), chunk))

    return {
        "input_ids": full_ids[:, start:stop].long().contiguous(),
        "labels": received["labels"],
        "position_ids": positions.long().contiguous(),
    }


class InfinityDataloader(object):
    """Wrap a DataLoader so that iterating it never terminates.

    Each full pass over the wrapped loader counts as one epoch. When
    ``shuffle`` is true and the loader's sampler exposes ``set_epoch`` (as
    ``DistributedSampler`` does), the current epoch is forwarded to it before
    every pass so each epoch draws a fresh permutation. Samplers without
    ``set_epoch`` are simply iterated as-is.
    """

    def __init__(self, dataloader, shuffle=False):
        self.dataloader = dataloader
        self.epoch = 0  # number of completed passes over `dataloader`
        self.shuffle = shuffle

    def __iter__(self):
        # itertools.cycle would be wrong here: it replays cached batches and never
        # reshuffles, so the underlying loader is restarted for every epoch instead.
        while True:
            if self.shuffle:
                sampler = getattr(self.dataloader, "sampler", None)
                set_epoch = getattr(sampler, "set_epoch", None)
                if set_epoch is not None:
                    set_epoch(self.epoch)
            yield from self.dataloader
            self.epoch += 1

    def __len__(self):
        """Return the number of batches in one pass of the wrapped loader."""
        return len(self.dataloader)


def _load_raw_mlm_datasets(datasets_path, train_file, validation_split_percentage):
    """Load raw text splits from ``train_file`` or a ``save_to_disk`` directory.

    Returns:
        ``(raw_datasets, from_disk)`` where ``from_disk`` is True when the data
        was read from ``datasets_path``.

    Raises:
        ValueError: if neither ``train_file`` nor ``datasets_path`` is given.
    """
    if train_file is not None:
        extension = train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        data_files = {"train": train_file}
        raw_datasets = load_dataset(extension, data_files=data_files)
        # No validation split provided: take the first `validation_split_percentage`%
        # of the train file as validation and keep the remainder for training.
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{validation_split_percentage}%]",
            )
            raw_datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{validation_split_percentage}%:]",
            )
        return raw_datasets, False
    if datasets_path is not None:
        return load_from_disk(datasets_path), True
    raise ValueError("get_mlm_dataloader requires either `train_file` or `datasets_path`.")


def get_mlm_dataloader(
    datasets_path=None,
    train_file=None,
    model_name_or_path=None,
    validation_split_percentage=5,
    preprocessing_num_workers=None,
    mlm_probability=0.15,
    per_device_train_batch_size=10,
    per_device_eval_batch_size=10,
    max_seq_length=8192,
):
    """Build train/validation/test DataLoaders for masked-language-model training.

    Raw text comes from ``train_file`` (takes precedence) or from a dataset
    saved with ``save_to_disk`` at ``datasets_path``. Texts are tokenized,
    concatenated and cut into ``max_seq_length``-token chunks; masking is
    applied per batch by ``DataCollatorForLanguageModeling``. The train loader
    is sharded across the data-parallel group with a ``DistributedSampler``.

    Returns:
        ``(train_dataloader, eval_dataloader, test_dataloader, tokenizer)``.
        ``test_dataloader`` is ``None`` when the data has no ``"test"`` split
        (always the case for ``train_file`` input).

    Raises:
        ValueError: if neither ``train_file`` nor ``datasets_path`` is given.
    """
    raw_datasets, from_disk = _load_raw_mlm_datasets(datasets_path, train_file, validation_split_percentage)

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    if DistGroups["dp"].rank() == 0:
        print("tokenizer", tokenizer)

    column_names = raw_datasets["train"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    def tokenize_function(examples):
        return tokenizer(examples[text_column_name], return_special_tokens_mask=True)

    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=True,
        desc="Running tokenizer on every text in dataset",
    )

    def group_texts(examples):
        """Concatenate a batch of tokenized texts and re-split it into max_seq_length chunks."""
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # Drop the remainder that does not fill a whole chunk (kept as-is when the
        # entire batch is shorter than one chunk).
        if total_length >= max_seq_length:
            total_length = (total_length // max_seq_length) * max_seq_length
        return {
            k: [t[slice(i, i + max_seq_length)] for i in range(0, total_length, max_seq_length)]
            for k, t in concatenated_examples.items()
        }

    # With `batched=True` each call sees 1,000 texts, so the remainder is dropped per
    # group of 1,000 texts. When the data was loaded from disk, cache the grouped
    # chunks next to it (one file per split and chunk length); otherwise use the
    # library's default cache location.
    cache_file_names = None
    if from_disk:
        cache_file_names = {
            split: os.path.join(datasets_path, split, "mlm-%s.arrow" % (max_seq_length))
            for split in tokenized_datasets
        }

    tokenized_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=preprocessing_num_workers,
        load_from_cache_file=True,
        cache_file_names=cache_file_names,
        desc=f"Grouping texts in chunks of {max_seq_length}"
        if DistGroups["dp"].rank() == 0 and DistGroups["sp"].rank() == 0
        else None,
    )

    train_dataset = tokenized_datasets["train"]
    eval_dataset = tokenized_datasets["validation"]
    test_dataset = tokenized_datasets["test"] if "test" in tokenized_datasets else None

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=mlm_probability)

    print("sampler", DistGroups["dp"].size(), DistGroups["dp"].rank(), len(train_dataset), len(eval_dataset))
    sampler = DistributedSampler(
        train_dataset, shuffle=True, num_replicas=DistGroups["dp"].size(), rank=DistGroups["dp"].rank()
    )
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=False,  # shuffling is done by the DistributedSampler
        sampler=sampler,
        collate_fn=data_collator,
        batch_size=per_device_train_batch_size,
        pin_memory=True,
    )

    eval_dataloader = DataLoader(
        eval_dataset,
        collate_fn=data_collator,
        batch_size=per_device_eval_batch_size,
        pin_memory=True,
    )
    test_dataloader = None
    if test_dataset is not None:
        test_dataloader = DataLoader(
            test_dataset,
            collate_fn=data_collator,
            batch_size=per_device_eval_batch_size,
            pin_memory=True,
        )
    return train_dataloader, eval_dataloader, test_dataloader, tokenizer
