from functools import partial
import os
import numpy as np
import torch
from transformers import DataCollatorForLanguageModeling, BatchEncoding
from latte_trans.preproc.base import DataProcessing
from datasets import Dataset, DatasetDict, load_dataset, concatenate_datasets


def np_collate_fn(max_seq_len, original_batch):
    """Collate examples into NumPy ``input_ids`` / ``labels`` arrays.

    Args:
        max_seq_len: Every example's ``input_ids`` is cut to at most this
            many leading entries before stacking.
        original_batch: Sequence of examples, each indexable by
            ``"input_ids"``. The truncated rows must share one shape,
            otherwise ``np.stack`` raises.

    Returns:
        BatchEncoding with ``input_ids`` and ``labels``; ``labels`` is an
        independent copy of ``input_ids``.
    """
    truncated_rows = [example["input_ids"][:max_seq_len] for example in original_batch]
    token_ids = np.stack(truncated_rows)
    return BatchEncoding({"input_ids": token_ids, "labels": token_ids.copy()})


def torch_collate_fn(max_seq_len, original_batch):
    """Collate examples into torch ``input_ids`` / ``labels`` tensors.

    Args:
        max_seq_len: Every per-example tensor is cut to at most this many
            leading entries before stacking.
        original_batch: Sequence of examples holding an ``"input_ids"``
            tensor and, optionally, an ``"attention_mask"`` tensor.

    Returns:
        BatchEncoding with ``input_ids`` and ``labels`` (a clone of
        ``input_ids``). ``attention_mask`` is included only when the first
        example of the batch carries one.
    """

    def _stack_field(field):
        # Truncate each example's tensor, then stack along a new batch axis.
        return torch.stack([example[field][:max_seq_len] for example in original_batch])

    token_ids = _stack_field("input_ids")
    payload = {"input_ids": token_ids, "labels": token_ids.clone()}

    if "attention_mask" in original_batch[0]:
        payload["attention_mask"] = _stack_field("attention_mask")

    return BatchEncoding(payload)


class OpenWeb(DataProcessing):
    """OpenWebText processor that packs documents into fixed-length rows.

    Tokenized documents are concatenated and cut into rows of exactly
    ``max_seq_len`` tokens, so the collate functions can stack rows without
    adding any padding.
    """

    def __init__(self, tokenizer, cache_dir, num_load_procs, max_seq_len) -> None:
        """Store the configuration; no data is loaded here.

        Args:
            tokenizer: Callable applied to a list of raw texts in ``_tokenize``.
            cache_dir: Directory used for the downloaded dataset and for the
                tokenization cache files.
            num_load_procs: Number of processes passed to ``load_dataset``
                and ``map``.
            max_seq_len: Number of tokens in every packed row.
        """
        self._cache_dir = cache_dir
        self._tokenizer = tokenizer
        self._num_load_procs = num_load_procs
        self.max_seq_len = max_seq_len

    def get_collate_fn(self, return_type="torch", **kwargs):
        """Return the collate function bound to ``self.max_seq_len``.

        Args:
            return_type: ``"torch"`` selects ``torch_collate_fn``; any other
                value selects ``np_collate_fn``.
            **kwargs: Accepted and ignored.
        """
        if return_type == "torch":
            return partial(torch_collate_fn, self.max_seq_len)
        return partial(np_collate_fn, self.max_seq_len)

    @property
    def tokenizer(self):
        """The tokenizer given at construction time."""
        return self._tokenizer

    def _tokenize(self, max_seq_len, elements):
        """Tokenize a batch of texts and pack the tokens into fixed-length rows.

        The token streams of all documents in the batch are concatenated (per
        feature returned by the tokenizer) and sliced into consecutive rows of
        exactly ``max_seq_len`` items. Tokens left over at the end of the
        batch that do not fill a whole row are dropped, so every returned row
        has the same length.

        Args:
            max_seq_len: Row length; must be positive.
            elements: Batch mapping with a ``"text"`` list of raw documents.

        Returns:
            Dict mapping each tokenizer feature to a list of rows.

        Raises:
            ValueError: If ``max_seq_len`` is not positive.
        """
        if max_seq_len <= 0:
            raise ValueError(f"max_seq_len must be positive, got {max_seq_len}")

        encoded = self._tokenizer(
            elements["text"],
            return_special_tokens_mask=False,
            add_special_tokens=True,
            truncation=False,
        )
        features = list(encoded.keys())
        packed = {k: [] for k in features}
        if not features:
            return packed

        lead = features[0]
        # Tokens carried over from previous documents, not yet emitted.
        carry = {k: [] for k in features}
        for document in zip(*(encoded[k] for k in features)):
            for k, tokens in zip(features, document):
                carry[k].extend(tokens)

            # Emit every complete row currently available and keep only the
            # unconsumed remainder, so no token is lost or duplicated.
            cut = (len(carry[lead]) // max_seq_len) * max_seq_len
            if cut:
                for k in features:
                    stream = carry[k]
                    packed[k].extend(
                        stream[start : start + max_seq_len]
                        for start in range(0, cut, max_seq_len)
                    )
                    carry[k] = stream[cut:]

        return packed

    def tokenize(self, raw_dataset, force_tok=True):
        """Tokenize and pack every split of ``raw_dataset``.

        Args:
            raw_dataset: Mapping of splits (``"train"`` and ``"validation"``
                cache files are configured) with a ``"text"`` column.
            force_tok: When true (the default) existing cache files are not
                reused and the data is tokenized again.

        Returns:
            The mapped dataset, with ``"text"`` replaced by the packed
            tokenizer features.
        """
        return raw_dataset.map(
            partial(self._tokenize, self.max_seq_len),
            batched=True,
            num_proc=self._num_load_procs,
            batch_size=10000,
            remove_columns=["text"],
            cache_file_names={
                "train": os.path.join(self._cache_dir, "tok_openweb_train.bin"),
                "validation": os.path.join(self._cache_dir, "tok_openweb_val.bin"),
            },
            load_from_cache_file=not force_tok,
        )

    def get_raw_data(self):
        """Load ``openwebtext`` and carve a validation split out of ``train``.

        Returns:
            Split mapping with ``"train"`` and ``"validation"`` entries.
        """
        os.makedirs(self._cache_dir, exist_ok=True)
        dataset = load_dataset(
            "openwebtext", cache_dir=self._cache_dir, num_proc=self._num_load_procs
        )
        # The original data only has a train split.
        split_dataset = dataset["train"].train_test_split(
            test_size=0.0005, seed=2357, shuffle=True
        )
        # Rename the held-out "test" split to "validation".
        split_dataset["validation"] = split_dataset.pop("test")
        return split_dataset


def torch_sampl_lm_collate_fn(pad_id, max_seq_len, original_batch):
    """Truncate/pad examples to ``max_seq_len`` and build torch LM batches.

    Args:
        pad_id: Value appended to examples shorter than ``max_seq_len``.
        max_seq_len: Length of every row in the returned tensors.
        original_batch: Sequence of examples whose ``"input_ids"`` is a
            Python list (it is concatenated with a padding list).

    Returns:
        BatchEncoding with ``input_ids`` and ``labels``. ``labels`` copies
        ``input_ids`` with every entry equal to ``pad_id`` replaced by -100.
        The replacement is by value, so a genuine token that equals
        ``pad_id`` is replaced as well. No attention mask is returned.
    """
    rows = []
    for example in original_batch:
        ids = example["input_ids"]
        missing = max(0, max_seq_len - len(ids))
        rows.append(torch.tensor(ids[:max_seq_len] + [pad_id] * missing))
    input_ids = torch.stack(rows)

    labels = torch.clone(input_ids)
    pad_positions = labels == pad_id
    labels[pad_positions] = -100

    return BatchEncoding({"input_ids": input_ids, "labels": labels})


def np_sampl_lm_collate_fn(pad_id, max_seq_len, original_batch):
    """Truncate/pad examples to ``max_seq_len`` and build NumPy LM batches.

    Args:
        pad_id: Value appended to examples shorter than ``max_seq_len``.
        max_seq_len: Length of every row in the returned arrays.
        original_batch: Sequence of examples whose ``"input_ids"`` is a
            Python list (it is concatenated with a padding list).

    Returns:
        BatchEncoding with ``input_ids`` and ``labels``. ``labels`` equals
        ``input_ids`` with every entry equal to ``pad_id`` replaced by -100;
        the replacement is by value, not by position.
    """
    rows = []
    for example in original_batch:
        ids = example["input_ids"]
        missing = max(0, max_seq_len - len(ids))
        rows.append(ids[:max_seq_len] + [pad_id] * missing)

    input_ids = np.array(rows)
    # np.where allocates a fresh array, so labels never aliases input_ids.
    labels = np.where(input_ids == pad_id, -100, input_ids)
    return BatchEncoding({"input_ids": input_ids, "labels": labels})


class OpenWebTextDP2(DataProcessing):
    """
    OpenWebText data processor without sequence concatenation: every document
    stays its own row, and the collate function truncates or pads it to
    ``max_seq_len``.
    """

    def __init__(self, tokenizer, num_load_procs, cache_dir, max_seq_len=1024) -> None:
        """Store the configuration; no data is loaded here."""
        self.cache_dir = cache_dir
        self.num_load_procs = num_load_procs
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def get_collate_fn(self, return_type="torch"):
        """Return a collate function bound to the pad id and ``max_seq_len``.

        ``"torch"`` selects ``torch_sampl_lm_collate_fn``; any other value
        selects ``np_sampl_lm_collate_fn``.
        """
        collate = (
            torch_sampl_lm_collate_fn
            if return_type == "torch"
            else np_sampl_lm_collate_fn
        )
        return partial(collate, self.tokenizer.pad_token_id, self.max_seq_len)

    def tokenize(self, data, add_special_tokens=True):
        """
        Tokenize the ``"text"`` column row by row. Rows stay independent, so
        truncation and padding to the maximum length are left to the collate
        function.
        Args:
            data: DatasetDict
            add_special_tokens: forwarded to the tokenizer
        """

        def _encode(batch):
            return self.tokenizer(
                batch["text"],
                add_special_tokens=add_special_tokens,
                return_attention_mask=True,
            )

        train_cache = os.path.join(self.cache_dir, "tok_openwebtxt_train.bin")
        val_cache = os.path.join(self.cache_dir, "tok_openwebtxt_val.bin")
        return data.map(
            _encode,
            remove_columns=["text"],
            batched=True,
            batch_size=10000,
            num_proc=self.num_load_procs,
            cache_file_names={"train": train_cache, "validation": val_cache},
        )

    def get_raw_data(self):
        """Load ``openwebtext`` and carve a validation split out of ``train``."""
        # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769)
        os.makedirs(self.cache_dir, exist_ok=True)
        raw = load_dataset(
            "openwebtext", cache_dir=self.cache_dir, num_proc=self.num_load_procs
        )
        # The original data only has a train split; hold out a small part.
        splits = raw["train"].train_test_split(
            test_size=0.0005, seed=2357, shuffle=True
        )
        # Rename the held-out "test" split to "validation".
        splits["validation"] = splits.pop("test")
        return splits


########### testing ############
# TODO: move to tests
def test3(data_dir="/data_user/data/input", max_seq_len=1024):
    """Manual smoke test: load, tokenize and collate one OpenWebText batch.

    Downloads the tokenizer and dataset, so it needs network/disk access and
    is meant to be run by hand.

    Args:
        data_dir: Root directory; the dataset cache lives in ``<data_dir>/openweb``.
        max_seq_len: Packed row length handed to ``OpenWeb``.
    """
    from transformers import AutoTokenizer
    from torch.utils.data import DataLoader

    tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
    cache_dir = os.path.join(data_dir, "openweb")

    # max_seq_len is a required constructor argument; tokenize() reads it
    # from the instance and does not accept it as a keyword.
    dp = OpenWeb(
        tokenizer=tokenizer,
        cache_dir=cache_dir,
        num_load_procs=8,
        max_seq_len=max_seq_len,
    )
    dataset = dp.get_raw_data()
    print(dataset)
    dataset = dp.tokenize(dataset)
    print(dataset)
    print(dataset["train"][0])
    print(len(dataset["train"][0]["input_ids"]))

    dataloader = DataLoader(
        dataset["train"],
        collate_fn=DataCollatorForLanguageModeling(
            tokenizer, mlm=False, return_tensors="np"
        ),
    )
    it = iter(dataloader)
    print(next(it))


def save(data_dir="/data_user/data/input", max_seq_len=1024):
    """Tokenize OpenWebText with ``OpenWeb`` and save the result to disk.

    The tokenized dataset is written to ``<data_dir>/tok_openweb``.

    Args:
        data_dir: Root directory; the dataset cache lives in ``<data_dir>/openweb``.
        max_seq_len: Packed row length handed to ``OpenWeb``.
    """
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
    cache_dir = os.path.join(data_dir, "openweb")

    # max_seq_len is a required constructor argument; tokenize() reads it
    # from the instance and does not accept it as a keyword.
    dp = OpenWeb(
        tokenizer=tokenizer,
        cache_dir=cache_dir,
        num_load_procs=8,
        max_seq_len=max_seq_len,
    )
    dataset = dp.get_raw_data()
    print(dataset)
    dataset = dp.tokenize(dataset)

    dataset.save_to_disk(os.path.join(data_dir, "tok_openweb"))


if __name__ == "__main__":
    # Script entry point: tokenize OpenWebText and save it to disk.
    # Swap in test3() to run the manual smoke test instead.
    # test3()
    save()

# latte_trans.preproc.pile
