from functools import partial
import os
import numpy as np
from transformers import DataCollatorForLanguageModeling, BatchEncoding
from latte_trans.preproc.base import DataProcessing
from datasets import Dataset, DatasetDict, load_dataset, concatenate_datasets


def np_collate_fn(original_batch):
    """Collate tokenized examples into a NumPy language-modeling batch.

    Args:
        original_batch: Sequence of examples, each exposing an
            ``"input_ids"`` entry. The entries are combined with
            ``np.stack``.

    Returns:
        A ``BatchEncoding`` holding ``"input_ids"`` (the stacked ids) and
        ``"labels"`` (an independent copy of that same array).
    """
    # Only token ids are collated; no padding/attention mask is produced.
    token_rows = [sample["input_ids"] for sample in original_batch]
    batch_ids = np.stack(token_rows)
    payload = {"input_ids": batch_ids, "labels": np.copy(batch_ids)}
    return BatchEncoding(payload)


class PileTok(DataProcessing):
    """Tokenize the uncopyrighted Pile and pack it into fixed-length sequences.

    The train split is loaded in streaming mode, tokenized lazily and then
    written to an on-disk ``Dataset``; the validation split is loaded and
    tokenized eagerly.
    """

    def __init__(self, tokenizer, cache_dir, num_load_procs, mlm_prob=0.15) -> None:
        """
        Args:
            tokenizer: Callable tokenizer applied to the ``"text"`` column.
            cache_dir: Directory used for dataset downloads and caches.
            num_load_procs: Number of processes passed to the non-streaming
                ``map`` and ``from_generator`` calls.
            mlm_prob: Stored only; collation in this class always uses
                ``mlm=False``.
        """
        self._cache_dir = cache_dir
        self._tokenizer = tokenizer
        self._num_load_procs = num_load_procs
        self._mlm_prob = mlm_prob

    @property
    def tokenizer(self):
        """The tokenizer given at construction time."""
        return self._tokenizer

    def get_collate_fn(self, return_type="torch", **kwargs):
        """Return the batch collator for the requested tensor type.

        Args:
            return_type: ``"torch"`` selects a ``DataCollatorForLanguageModeling``
                (``mlm=False``) producing PyTorch tensors; any other value
                selects the NumPy collator ``np_collate_fn``.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        if return_type == "torch":
            return DataCollatorForLanguageModeling(
                tokenizer=self._tokenizer,
                mlm=False,
                return_tensors="pt",
            )
        return np_collate_fn

    @staticmethod
    def _pack_sequences(encoded, max_seq_len):
        """Concatenate tokenized examples and cut them into equal-length chunks.

        Args:
            encoded: Mapping of feature name to a list of per-example token
                lists (all features must be aligned token by token).
            max_seq_len: Length of every emitted chunk.

        Returns:
            Dict mapping each feature name to a list of chunks, each exactly
            ``max_seq_len`` long. Tokens left over at the end of the batch
            (fewer than ``max_seq_len``) are dropped so that all rows keep
            the same length.

        Raises:
            ValueError: If ``max_seq_len`` is smaller than 1.
        """
        if max_seq_len < 1:
            raise ValueError(f"max_seq_len must be >= 1, got {max_seq_len}")
        features = list(encoded.keys())
        packed = {k: [] for k in features}
        buffers = {k: [] for k in features}
        for row in zip(*(encoded[k] for k in features)):
            for k, values in zip(features, row):
                buffers[k].extend(values)
            # Emit every complete chunk and keep the true remainder, so no
            # token is lost or duplicated between consecutive examples.
            full_len = len(buffers[features[0]]) // max_seq_len * max_seq_len
            if full_len:
                for k in features:
                    buf = buffers[k]
                    packed[k].extend(
                        buf[start : start + max_seq_len]
                        for start in range(0, full_len, max_seq_len)
                    )
                    buffers[k] = buf[full_len:]
        return packed

    def _tokenize(self, max_seq_len, elements):
        """Tokenize a batch of texts and pack it into ``max_seq_len`` chunks.

        Args:
            max_seq_len: Length of every returned sequence.
            elements: Batch dict with a ``"text"`` list, as passed by a
                batched ``map``.

        Returns:
            Dict of tokenizer features, each a list of packed chunks.
        """
        encoded = self._tokenizer(
            elements["text"],
            return_special_tokens_mask=False,
            add_special_tokens=True,
            truncation=False,
        )
        # Group examples together so that no padding is needed.
        return self._pack_sequences(encoded, max_seq_len)

    def tokenize(self, raw_dataset, max_seq_len):
        """Tokenize and pack both splits.

        Args:
            raw_dataset: Mapping with ``"train"`` and ``"validation"``
                entries, as returned by ``get_raw_data``.
            max_seq_len: Length of every packed sequence.

        Returns:
            ``DatasetDict`` with NumPy-formatted ``"validation"`` and
            ``"train"`` datasets.
        """
        # Pre-tokenize the validation data.
        tok_data_val = raw_dataset["validation"].map(
            partial(self._tokenize, max_seq_len),
            batched=True,
            num_proc=self._num_load_procs,
            batch_size=10000,
            remove_columns=["text", "meta"],
        )
        # Tokenize the streamed train data on the fly.
        tok_data_train = raw_dataset["train"].map(
            partial(self._tokenize, max_seq_len),
            batched=True,
            batch_size=10000,
            remove_columns=["text", "meta"],
        )

        def gen_from_iterable_dataset(iterable_ds):
            yield from iterable_ds

        # Turn the tokenized stream into a regular Dataset cached under
        # ``<cache_dir>/proc``.
        tok_data_train = Dataset.from_generator(
            partial(gen_from_iterable_dataset, tok_data_train),
            cache_dir=os.path.join(self._cache_dir, "proc"),
        )
        return DatasetDict(
            {
                "validation": tok_data_val.with_format("np"),
                "train": tok_data_train.with_format("np"),
            }
        )

    def group_examples(self, tok_data, max_seq_len):
        """
        Concatenate all the examples and split by max_seq_len to reduce the number
        of padded tokens.

        Every complete ``max_seq_len`` chunk is emitted; the final incomplete
        chunk of each split is dropped.

        Args:
            tok_data: DatasetDict of tokenized examples; the feature names are
                taken from its ``"train"`` split.
            max_seq_len: Length of every emitted example.

        Returns:
            DatasetDict with the same split names holding the regrouped data.
        """
        features = tok_data["train"].column_names

        def gen(split):
            buffers = {k: [] for k in features}
            for example in tok_data[split]:
                for k in features:
                    buffers[k].extend(example[k])
                # Flush all complete chunks, keeping the real remainder.
                while len(buffers[features[0]]) >= max_seq_len:
                    yield {k: buffers[k][:max_seq_len] for k in features}
                    buffers = {k: buffers[k][max_seq_len:] for k in features}

        return DatasetDict(
            {
                split: Dataset.from_generator(
                    partial(gen, split),
                    cache_dir=os.path.join(self._cache_dir, split),
                    num_proc=self._num_load_procs,
                )
                for split in tok_data
            }
        )

    def get_raw_data(self):
        """Load the raw Pile: streamed train split, in-memory validation file.

        Returns:
            Plain dict with ``"train"`` (streaming) and ``"validation"``
            (loaded from ``val.jsonl.zst``) entries.
        """
        train_datasets = load_dataset(
            "monology/pile-uncopyrighted",
            split="train",
            streaming=True,
            cache_dir=self._cache_dir,
            trust_remote_code=True,
        )
        # The single validation file is requested through ``split="train"``.
        validation = load_dataset(
            "monology/pile-uncopyrighted",
            split="train",
            streaming=False,
            cache_dir=self._cache_dir,
            data_files="val.jsonl.zst",
            trust_remote_code=True,
        )
        # A DatasetDict cannot mix an iterable (streamed) and a regular
        # dataset, hence the plain dict.
        return {"train": train_datasets, "validation": validation}


class PileTok2(DataProcessing):
    """Tokenize the fully downloaded (non-streaming) uncopyrighted Pile.

    Tokenized splits are cached in explicitly named files inside
    ``cache_dir``.
    """

    def __init__(self, tokenizer, cache_dir, num_load_procs, max_seq_len=1024) -> None:
        """
        Args:
            tokenizer: Callable tokenizer applied to the ``"text"`` column.
            cache_dir: Directory used for dataset downloads and caches.
            num_load_procs: Number of processes used by ``map``.
            max_seq_len: Length of every packed sequence.
        """
        self._cache_dir = cache_dir
        self._tokenizer = tokenizer
        self._num_load_procs = num_load_procs
        self.max_seq_len = max_seq_len

    def get_collate_fn(self, return_type="torch", **kwargs):
        """Return the batch collator for the requested tensor type.

        Args:
            return_type: ``"torch"`` selects a ``DataCollatorForLanguageModeling``
                (``mlm=False``) producing PyTorch tensors; any other value
                selects the NumPy collator ``np_collate_fn``.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        if return_type == "torch":
            return DataCollatorForLanguageModeling(
                tokenizer=self._tokenizer,
                mlm=False,
                return_tensors="pt",
            )
        return np_collate_fn

    @property
    def tokenizer(self):
        """The tokenizer given at construction time."""
        return self._tokenizer

    @staticmethod
    def _pack_sequences(encoded, max_seq_len):
        """Concatenate tokenized examples and cut them into equal-length chunks.

        Args:
            encoded: Mapping of feature name to a list of per-example token
                lists (all features must be aligned token by token).
            max_seq_len: Length of every emitted chunk.

        Returns:
            Dict mapping each feature name to a list of chunks, each exactly
            ``max_seq_len`` long. Tokens left over at the end of the batch
            (fewer than ``max_seq_len``) are dropped so that all rows keep
            the same length.

        Raises:
            ValueError: If ``max_seq_len`` is smaller than 1.
        """
        if max_seq_len < 1:
            raise ValueError(f"max_seq_len must be >= 1, got {max_seq_len}")
        features = list(encoded.keys())
        packed = {k: [] for k in features}
        buffers = {k: [] for k in features}
        for row in zip(*(encoded[k] for k in features)):
            for k, values in zip(features, row):
                buffers[k].extend(values)
            # Emit every complete chunk and keep the true remainder, so no
            # token is lost or duplicated between consecutive examples.
            full_len = len(buffers[features[0]]) // max_seq_len * max_seq_len
            if full_len:
                for k in features:
                    buf = buffers[k]
                    packed[k].extend(
                        buf[start : start + max_seq_len]
                        for start in range(0, full_len, max_seq_len)
                    )
                    buffers[k] = buf[full_len:]
        return packed

    def _tokenize(self, elements):
        """Tokenize a batch of texts and pack it into ``self.max_seq_len`` chunks.

        Args:
            elements: Batch dict with a ``"text"`` list, as passed by a
                batched ``map``.

        Returns:
            Dict of tokenizer features, each a list of packed chunks.
        """
        encoded = self._tokenizer(
            elements["text"],
            return_special_tokens_mask=False,
            add_special_tokens=True,
            truncation=False,
        )
        # Group examples together so that no padding is needed.
        return self._pack_sequences(encoded, self.max_seq_len)

    def tokenize(self, raw_dataset, force_tok=False):
        """Tokenize and pack every split of ``raw_dataset``.

        Args:
            raw_dataset: Dataset dict with ``"train"`` and ``"validation"``
                splits containing ``"text"`` and ``"meta"`` columns.
            force_tok: When True, ignore existing cache files and tokenize
                again. Cache files have fixed names, so files written by an
                earlier run are reused unless this flag is set.

        Returns:
            The mapped dataset dict.
        """
        return raw_dataset.map(
            self._tokenize,
            batched=True,
            num_proc=self._num_load_procs,
            batch_size=10000,
            remove_columns=["text", "meta"],
            cache_file_names={
                "train": os.path.join(self._cache_dir, "tok_pile_train.bin"),
                "validation": os.path.join(self._cache_dir, "tok_pile_val.bin"),
            },
            load_from_cache_file=not force_tok,
        )

    def get_raw_data(self):
        """Load the raw train split and the validation file (no streaming).

        Returns:
            ``DatasetDict`` with ``"train"`` and ``"validation"`` splits.
        """
        train_datasets = load_dataset(
            "monology/pile-uncopyrighted",
            split="train",
            streaming=False,
            cache_dir=self._cache_dir,
            trust_remote_code=True,
        )
        # A single ``data_files`` entry is exposed under the "train" split,
        # so that is the split to request for the validation file as well.
        validation = load_dataset(
            "monology/pile-uncopyrighted",
            split="train",
            streaming=False,
            cache_dir=self._cache_dir,
            data_files="val.jsonl.zst",
            trust_remote_code=True,
        )
        return DatasetDict({"train": train_datasets, "validation": validation})

    def get_raw_data_test(self):
        """Load only the validation/test files for evaluation.

        Returns:
            ``DatasetDict`` whose ``"train"`` entry holds the contents of
            ``test.jsonl.zst`` and whose ``"validation"`` entry holds the
            contents of ``val.jsonl.zst``.
        """
        test = load_dataset(
            "monology/pile-uncopyrighted",
            streaming=False,
            cache_dir=self._cache_dir,
            data_files="test.jsonl.zst",
            trust_remote_code=True,
        )
        validation = load_dataset(
            "monology/pile-uncopyrighted",
            streaming=False,
            cache_dir=self._cache_dir,
            data_files="val.jsonl.zst",
            trust_remote_code=True,
        )
        # Without an explicit ``split`` each load returns a dict whose only
        # entry is "train".
        return DatasetDict({"train": test["train"], "validation": validation["train"]})


class PileStream(DataProcessing):
    """Tokenize the Pile with a streamed train split.

    Unlike the other processors in this module, the tokenized train split
    stays a lazily evaluated stream and is never written to disk here.
    """

    def __init__(self, tokenizer, cache_dir, num_load_procs, mlm_prob=0.15) -> None:
        """
        Args:
            tokenizer: Callable tokenizer applied to the ``"text"`` column.
            cache_dir: Directory used for dataset downloads and caches.
            num_load_procs: Number of processes used for the non-streaming
                load and ``map`` calls.
            mlm_prob: Stored only; collation in this class always uses
                ``mlm=False``.
        """
        self._cache_dir = cache_dir
        self._tokenizer = tokenizer
        self._num_load_procs = num_load_procs
        self._mlm_prob = mlm_prob

    @property
    def tokenizer(self):
        """The tokenizer given at construction time."""
        return self._tokenizer

    def get_collate_fn(self, return_type="torch", **kwargs):
        """Return the batch collator for the requested tensor type.

        Args:
            return_type: ``"torch"`` selects a ``DataCollatorForLanguageModeling``
                (``mlm=False``) producing PyTorch tensors; any other value
                selects the NumPy collator ``np_collate_fn``.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        if return_type == "torch":
            return DataCollatorForLanguageModeling(
                tokenizer=self._tokenizer,
                mlm=False,
                return_tensors="pt",
            )
        return np_collate_fn

    @staticmethod
    def _pack_sequences(encoded, max_seq_len):
        """Concatenate tokenized examples and cut them into equal-length chunks.

        Args:
            encoded: Mapping of feature name to a list of per-example token
                lists (all features must be aligned token by token).
            max_seq_len: Length of every emitted chunk.

        Returns:
            Dict mapping each feature name to a list of chunks, each exactly
            ``max_seq_len`` long. Tokens left over at the end of the batch
            (fewer than ``max_seq_len``) are dropped so that all rows keep
            the same length.

        Raises:
            ValueError: If ``max_seq_len`` is smaller than 1.
        """
        if max_seq_len < 1:
            raise ValueError(f"max_seq_len must be >= 1, got {max_seq_len}")
        features = list(encoded.keys())
        packed = {k: [] for k in features}
        buffers = {k: [] for k in features}
        for row in zip(*(encoded[k] for k in features)):
            for k, values in zip(features, row):
                buffers[k].extend(values)
            # Emit every complete chunk and keep the true remainder, so no
            # token is lost or duplicated between consecutive examples.
            full_len = len(buffers[features[0]]) // max_seq_len * max_seq_len
            if full_len:
                for k in features:
                    buf = buffers[k]
                    packed[k].extend(
                        buf[start : start + max_seq_len]
                        for start in range(0, full_len, max_seq_len)
                    )
                    buffers[k] = buf[full_len:]
        return packed

    def _tokenize(self, max_seq_len, elements):
        """Tokenize a batch of texts and pack it into ``max_seq_len`` chunks.

        Args:
            max_seq_len: Length of every returned sequence.
            elements: Batch dict with a ``"text"`` list, as passed by a
                batched ``map``.

        Returns:
            Dict of tokenizer features, each a list of packed chunks.
        """
        encoded = self._tokenizer(
            elements["text"],
            return_special_tokens_mask=False,
            add_special_tokens=True,
            truncation=False,
        )
        # Group examples together so that no padding is needed.
        return self._pack_sequences(encoded, max_seq_len)

    def tokenize(self, raw_dataset, max_seq_len):
        """Tokenize and pack both splits.

        Args:
            raw_dataset: Mapping with ``"train"`` and ``"validation"``
                entries, as returned by ``get_raw_data``.
            max_seq_len: Length of every packed sequence.

        Returns:
            Plain dict with NumPy-formatted ``"train"`` and ``"validation"``
            datasets.
        """
        # Pre-tokenize the validation data.
        tok_data_val = raw_dataset["validation"].map(
            partial(self._tokenize, max_seq_len),
            batched=True,
            num_proc=self._num_load_procs,
            batch_size=10000,
            remove_columns=["text", "meta"],
        )
        # Tokenize the streamed train data on the fly.
        tok_data_train = raw_dataset["train"].map(
            partial(self._tokenize, max_seq_len),
            batched=True,
            batch_size=10000,
            remove_columns=["text", "meta"],
        )
        return {
            "train": tok_data_train.with_format("np"),
            "validation": tok_data_val.with_format("np"),
        }

    def get_raw_data(self):
        """Load the uncopyrighted Pile: streamed train, in-memory validation.

        Returns:
            Plain dict with ``"train"`` (streaming) and ``"validation"``
            (loaded from ``val.jsonl.zst``) entries.
        """
        # ``num_proc`` is deliberately not passed here: ``load_dataset``
        # does not support parallel loading together with ``streaming=True``.
        train_datasets = load_dataset(
            "monology/pile-uncopyrighted",
            split="train",
            streaming=True,
            cache_dir=self._cache_dir,
            trust_remote_code=True,
        )
        # The single validation file is requested through ``split="train"``.
        validation = load_dataset(
            "monology/pile-uncopyrighted",
            split="train",
            streaming=False,
            cache_dir=self._cache_dir,
            data_files="val.jsonl.zst",
            trust_remote_code=True,
            num_proc=self._num_load_procs,
        )
        # A DatasetDict cannot mix an iterable (streamed) and a regular
        # dataset, hence the plain dict.
        return {"train": train_datasets, "validation": validation}

    def get_raw_data2(self):
        """Same as ``get_raw_data`` but reading from ``EleutherAI/pile``.

        Returns:
            Plain dict with ``"train"`` (streaming) and ``"validation"``
            (loaded from ``val.jsonl.zst``) entries.
        """
        train_datasets = load_dataset(
            "EleutherAI/pile",
            split="train",
            streaming=True,
            cache_dir=self._cache_dir,
            trust_remote_code=True,
        )
        validation = load_dataset(
            "EleutherAI/pile",
            split="train",
            streaming=False,
            cache_dir=self._cache_dir,
            data_files="val.jsonl.zst",
            trust_remote_code=True,
        )
        # A DatasetDict cannot mix an iterable (streamed) and a regular
        # dataset, hence the plain dict.
        return {"train": train_datasets, "validation": validation}


########### testing ############
# TODO: move to tests


def test():
    """Manual smoke test: load, tokenize and re-map the Pile with ``PileTok``.

    Downloads data and a tokenizer; meant to be run by hand, not in CI.
    """
    import multiprocessing

    from datasets import disable_caching
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
    disable_caching()
    data_dir = "/data_user/data/input"
    cache_dir = os.path.join(data_dir, "pile")

    num_proc = multiprocessing.cpu_count()
    print(f"Num processes is: {num_proc}")
    # NOTE(review): the original referenced an undefined ``PilePretok`` and
    # passed ``tokenizer=None``. ``PileTok`` matches the calls below
    # (``tokenize(..., max_seq_len=...)`` returning a mappable DatasetDict);
    # confirm this is the intended class.
    dp = PileTok(tokenizer=tokenizer, cache_dir=cache_dir, num_load_procs=num_proc)
    raw_data = dp.get_raw_data()
    print(raw_data)
    tok_data = dp.tokenize(raw_data, max_seq_len=1024)
    print(tok_data)
    # Identity map to check that the tokenized data can be processed further.
    tok_data.map(lambda x: x)


def test2():
    """Manual smoke test: iterate batches of the streamed ``PileStream`` data.

    Downloads data and a tokenizer; meant to be run by hand, not in CI.
    """
    from datasets import disable_caching
    from transformers import AutoTokenizer
    from torch.utils.data import DataLoader

    tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
    disable_caching()
    data_dir = "/data_user/data/input"
    cache_dir = os.path.join(data_dir, "pile")

    dp = PileStream(tokenizer=tokenizer, cache_dir=cache_dir, num_load_procs=2)
    dataset = dp.get_raw_data()
    print(dataset)
    dataset = dp.tokenize(dataset, max_seq_len=1024)
    # ``tokenize`` returns a plain dict of already NumPy-formatted splits, so
    # pick the split to iterate instead of calling ``with_format`` on the dict.
    train_data = dataset["train"]
    dataloader = DataLoader(
        train_data,
        collate_fn=DataCollatorForLanguageModeling(
            tokenizer, mlm=False, return_tensors="np"
        ),
    )
    it = iter(dataloader)
    print(next(it))
    for i, batch in enumerate(dataloader):
        if i > 100:
            break
        print(i, batch)


def test3():
    """Manual smoke test: tokenize the Pile with ``PileTok2`` and fetch a batch.

    Downloads data and a tokenizer; meant to be run by hand, not in CI.
    """
    from transformers import AutoTokenizer
    from torch.utils.data import DataLoader

    tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
    data_dir = "/data_user/data/input"
    cache_dir = os.path.join(data_dir, "pile3")

    # ``PileTok2`` takes the sequence length in its constructor;
    # ``PileTok2.tokenize`` has no ``max_seq_len`` parameter.
    dp = PileTok2(
        tokenizer=tokenizer,
        cache_dir=cache_dir,
        num_load_procs=8,
        max_seq_len=1024,
    )
    dataset = dp.get_raw_data()  # dp.get_raw_data_test()
    print(dataset)
    dataset = dp.tokenize(dataset)
    print(dataset)
    print(dataset["train"][0])
    print(len(dataset["train"][0]["input_ids"]))

    dataloader = DataLoader(
        dataset["train"],
        collate_fn=DataCollatorForLanguageModeling(
            tokenizer, mlm=False, return_tensors="np"
        ),
    )
    it = iter(dataloader)
    print(next(it))


# Running this module as a script executes the PileTok2 smoke test.
if __name__ == "__main__":
    test3()

# latte_trans.preproc.pile
