from functools import partial
from typing import Dict, Iterable
import os
from os import PathLike
from pathlib import Path
from tqdm import tqdm
import torch
from numpy.typing import NDArray
import numpy as np
from datasets import load_dataset, Dataset
from transformers import BatchEncoding
from datasets import IterableDatasetDict, DatasetDict
from latte_trans.preproc.base import DataProcessing
from latte_trans.preproc.toks import SpecialToksGPT2TokenizerFast
from datasets import load_dataset, DownloadConfig


def np_collate_fn(max_seq_len, original_batch):
    """Collate dataset rows into a NumPy language-modelling batch.

    Args:
        max_seq_len: Every row's ``"input_ids"`` is cut to at most this many
            tokens before stacking.
        original_batch: Sequence of rows, each providing ``"input_ids"``.

    Returns:
        A ``BatchEncoding`` with ``"input_ids"`` and ``"labels"``; the labels
        are an independent copy of the stacked input ids.
    """
    truncated_rows = []
    for row in original_batch:
        truncated_rows.append(row["input_ids"][:max_seq_len])
    # np.stack needs every truncated row to have the same length.
    token_ids = np.stack(truncated_rows)
    batch = {"input_ids": token_ids, "labels": np.copy(token_ids)}
    return BatchEncoding(batch)


def torch_collate_fn(max_seq_len, original_batch):
    """Collate dataset rows into a PyTorch language-modelling batch.

    Args:
        max_seq_len: Every per-row tensor is cut to at most this many tokens
            before stacking.
        original_batch: Sequence of rows, each providing ``"input_ids"`` and
            optionally ``"attention_mask"``.

    Returns:
        A ``BatchEncoding`` with ``"input_ids"`` and ``"labels"`` (a clone of
        the input ids), plus ``"attention_mask"`` when the first row of the
        batch carries one.
    """

    def _stack(field):
        # Truncate each row for `field`, then stack along a new batch axis.
        return torch.stack([row[field][:max_seq_len] for row in original_batch])

    token_ids = _stack("input_ids")
    batch = {"input_ids": token_ids, "labels": torch.clone(token_ids)}

    # Only the first row is inspected to decide whether masks are present.
    if "attention_mask" in original_batch[0]:
        batch["attention_mask"] = _stack("attention_mask")
    return BatchEncoding(batch)


def gen_from_iterable_dataset(iterable_ds):
    """Yield every item of ``iterable_ds`` in order, as a plain generator."""
    for item in iterable_ds:
        yield item


class SlimPajama(DataProcessing):
    """Tokenize SlimPajama and pack the tokens into fixed-length sequences.

    Documents are tokenized without truncation, concatenated, and cut into
    chunks of exactly ``max_seq_len`` tokens so that no padding is needed.
    """

    def __init__(self, tokenizer, cache_dir, num_load_procs, max_seq_len) -> None:
        """Store the configuration; nothing is loaded or created here.

        Args:
            tokenizer: Callable applied to a list of strings; its result must
                provide an ``"input_ids"`` entry (one token list per string).
            cache_dir: Directory used both as the download cache and as the
                location of the tokenized cache files.
            num_load_procs: Number of processes passed to ``load_dataset``
                and ``map``.
            max_seq_len: Number of tokens in every packed sequence.
        """
        self._cache_dir = cache_dir
        self._tokenizer = tokenizer
        self._num_load_procs = num_load_procs
        self._max_seq_len = max_seq_len

    @property
    def tokenizer(self):
        """The tokenizer given at construction time."""
        return self._tokenizer

    def get_collate_fn(self, return_type="torch", **kwargs):
        """Return a collate function bound to this instance's ``max_seq_len``.

        Args:
            return_type: ``"torch"`` selects ``torch_collate_fn``; any other
                value selects ``np_collate_fn``.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        if return_type == "torch":
            return partial(torch_collate_fn, self._max_seq_len)
        else:
            return partial(np_collate_fn, self._max_seq_len)

    @staticmethod
    def _pack_sequences(sequences, max_seq_len):
        """Concatenate token sequences and cut the stream into full chunks.

        Every token is emitted at most once and in its original order. Long
        sequences contribute as many chunks as they fill. Tokens left over
        at the end that do not fill a whole chunk are dropped.

        Raises:
            ValueError: If ``max_seq_len`` is not positive.
        """
        if max_seq_len <= 0:
            raise ValueError(
                f"max_seq_len must be a positive integer, got {max_seq_len!r}"
            )
        chunks = []
        buffer = []
        for seq in sequences:
            buffer.extend(seq)
            if len(buffer) < max_seq_len:
                continue
            # Emit every complete chunk currently in the buffer, then keep
            # only the remainder for the next documents.
            usable = len(buffer) - len(buffer) % max_seq_len
            chunks.extend(
                buffer[start : start + max_seq_len]
                for start in range(0, usable, max_seq_len)
            )
            buffer = buffer[usable:]
        return chunks

    def _tokenize(self, max_seq_len, elements):
        """Tokenize a batch of documents and pack them into full sequences.

        Args:
            max_seq_len: Number of tokens in every emitted sequence.
            elements: Batch mapping with a ``"text"`` entry holding the
                documents to tokenize.

        Returns:
            ``{"input_ids": [...]}`` where each item holds exactly
            ``max_seq_len`` token ids. The number of items generally differs
            from the number of input documents; the incomplete tail of the
            batch is discarded.
        """
        encoded = self._tokenizer(
            elements["text"],
            return_special_tokens_mask=False,
            add_special_tokens=True,
            truncation=False,
        )
        # only keep input_ids because of size
        return {
            "input_ids": self._pack_sequences(encoded["input_ids"], max_seq_len)
        }

    def tokenize(self, raw_dataset: DatasetDict, force_tok=False) -> DatasetDict:
        """Tokenize and pack every split of ``raw_dataset``.

        Args:
            raw_dataset: Splits to process. A ``"train"`` split is required
                (its column names are the ones removed), and every split
                must be one of ``"train"``, ``"test"``, ``"validation"``.
            force_tok: When true, recompute instead of reusing cache files.

        Returns:
            A ``DatasetDict`` whose only column is ``"input_ids"``.
        """
        print(raw_dataset.column_names)
        # Make sure the directory for the cache files exists before the
        # workers try to write into it (mirrors get_raw_data).
        os.makedirs(self._cache_dir, exist_ok=True)
        # NOTE(review): the cache file names are fixed, so files written by an
        # earlier version of `_tokenize` may be reused as-is; pass
        # force_tok=True after changing the packing logic — confirm against
        # the `datasets` caching rules.
        tok_data = raw_dataset.map(
            partial(self._tokenize, self._max_seq_len),
            batched=True,
            batch_size=32000,
            remove_columns=raw_dataset["train"].column_names,
            cache_file_names={
                "train": os.path.join(self._cache_dir, "tok_pajama_train.bin"),
                "test": os.path.join(self._cache_dir, "tok_pajama_test.bin"),
                "validation": os.path.join(self._cache_dir, "tok_pajama_val.bin"),
            },
            num_proc=self._num_load_procs,
            load_from_cache_file=not force_tok,
        )

        return tok_data

    def get_raw_data(self):
        """Load ``cerebras/SlimPajama-627B`` through ``load_dataset``.

        The cache directory is created if needed, and the download is
        configured with ``resume_download=True``.
        """
        os.makedirs(self._cache_dir, exist_ok=True)
        dataset = load_dataset(
            "cerebras/SlimPajama-627B",  # "roneneldan/TinyStories",  #
            cache_dir=self._cache_dir,
            num_proc=self._num_load_procs,
            download_config=DownloadConfig(resume_download=True),
            # split="train[:2%]",
            # streaming=True,
        )
        return dataset


def preprocess_data():
    """Convert the locally downloaded SlimPajama files to an on-disk dataset.

    Expected workflow (from the original author's notes):

    1. Download the data with the Hugging Face CLI, because it is faster::

        huggingface-cli download cerebras/SlimPajama-627B --repo-type dataset --local-dir ./slim_pajama

    2. Load the downloaded directory with ``datasets.load_dataset`` (done
       here) and save the result with ``save_to_disk``.

    All paths and the process count are hard-coded below.
    """
    n_procs = 16
    raw_splits = load_dataset(
        "/raw_data_user/data/input/slim_pajama",
        cache_dir="/user_back_data/data/input/cache_pajama",
        num_proc=n_procs,
    )
    print(raw_splits)
    raw_splits.save_to_disk("/user_all_data/data/input/pajama_raw", num_proc=n_procs)


def tokenize():
    """Tokenize the saved raw SlimPajama dataset with the Gemma-2 tokenizer.

    Loads the ``DatasetDict`` stored at ``/user_all_data/data/input/pajama_raw``
    and runs ``SlimPajama.tokenize`` on it with a sequence length of 32000
    and 16 processes. The result is not returned; it only exists in the
    cache files written by the processing step.
    """
    from transformers import AutoTokenizer

    sequence_length = 32000
    root = Path("/user_all_data/data/")

    gemma_tokenizer = AutoTokenizer.from_pretrained(
        "google/gemma-2-2b",
        cache_dir=root / "input/cache_hugg",
        truncation_side="right",
        padding_side="right",
    )
    processor = SlimPajama(
        tokenizer=gemma_tokenizer,
        cache_dir=root / "input/pajama_raw/cache",
        max_seq_len=sequence_length,
        num_load_procs=16,
    )

    raw_splits = DatasetDict.load_from_disk("/user_all_data/data/input/pajama_raw")
    processor.tokenize(raw_splits)


def test():
    """Manual smoke test: download, tokenize, and print one NumPy batch.

    Run with ``pdm run python3 -m latte_trans.preproc.slim_pajama`` after
    pointing the ``__main__`` guard at this function. Paths are hard-coded.
    """
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    base_dir = "/data_user/data"
    MAX_SEQ_LEN = 1024
    cache_dir = os.path.join(base_dir, "input", "slim_pajama")
    tokenizer = AutoTokenizer.from_pretrained(
        "google/gemma-2-2b",
        cache_dir=Path(base_dir) / "input/cache_hugg",
        truncation_side="right",
        padding_side="right",
    )

    print(
        "Tokenizer bos: ",
        tokenizer.bos_token_id,
        tokenizer.pad_token_id,
        tokenizer.eos_token_id,
        tokenizer.unk_token_id,
        tokenizer.mask_token_id,
    )

    # os.cpu_count() returns None when the count cannot be determined; fall
    # back to a single CPU instead of failing on `None - 20`.
    available_cpus = os.cpu_count() or 1
    dp = SlimPajama(
        tokenizer=tokenizer,
        cache_dir=cache_dir,
        max_seq_len=MAX_SEQ_LEN,
        # Leave 20 cores free, but always use at least one process.
        num_load_procs=max(1, available_cpus - 20),
    )
    raw_data = dp.get_raw_data()
    print("Raw data is", raw_data)

    dataset = dp.tokenize(raw_data)

    dl = DataLoader(
        dataset=dataset["train"],
        batch_size=10,
        shuffle=True,
        collate_fn=dp.get_collate_fn(return_type="np"),
    )
    print(next(iter(dl)))


if __name__ == "__main__":
    # Run with: pdm run python3 -m latte_trans.preproc.slim_pajama
    # Uncomment preprocess_data() to first rebuild the raw on-disk dataset
    # that tokenize() loads.
    # preprocess_data()
    tokenize()


# from transformers import AutoTokenizer
# from datasets import DatasetDict
# import os
# from pathlib import Path
# import numpy as np
# from functools import partial

# base_dir = "/user_all_data/data/"
# MAX_SEQ_LEN = 32000
# # tokenizer = SpecialToksGPT2TokenizerFast.from_pretrained("gpt2")
# cache_dir = Path(base_dir) / "input/pajama_raw/cache"
# tokenizer = AutoTokenizer.from_pretrained(
#     "google/gemma-2-2b",
#     cache_dir=Path(base_dir) / "input/cache_hugg",
#     truncation_side="right",
#     padding_side="right",
# )


# def _tokenize(tokenizer, max_seq_len, elements):
#     elements = tokenizer(
#         elements["text"],
#         return_special_tokens_mask=False,
#         add_special_tokens=True,
#         truncation=False,
#     )
#     # only keep input_ids because of size
#     elements = {"input_ids": elements["input_ids"]}
#     # group elements together to reduce padding
#     features = list(elements.keys())
#     trailing_sequence = {k: [] for k in features}
#     new_elements = {k: [] for k in features}
#     for i in range(len(elements[features[0]])):
#         trail_len = len(trailing_sequence[features[0]])
#         if trail_len >= max_seq_len:
#             tmp = {k: trailing_sequence[k][0:max_seq_len] for k in features}
#             for k in features:
#                 new_elements[k].append(tmp[k])
#             # reset trailing features
#             trailing_sequence = {k: elements[k][i][max_seq_len:] for k in features}
#         for k in features:
#             trailing_sequence[k].extend(elements[k][i])
#     return new_elements


# # ds = {k: ds[k].select(100000) for k in ds.keys()}
# data_path = Path(base_dir) / "input/pajama_raw/"
# ds = DatasetDict.load_from_disk(data_path)
# tok_data = ds.map(
#     partial(_tokenize, tokenizer, MAX_SEQ_LEN),
#     batched=True,
#     batch_size=10000,
#     remove_columns=ds["train"].column_names,
#     cache_file_names={
#         "train": os.path.join(cache_dir, "tok_pajama_train.bin"),
#         "test": os.path.join(cache_dir, "tok_pajama_test.bin"),
#         "validation": os.path.join(cache_dir, "tok_pajama_val.bin"),
#     },
#     num_proc=16,
# )
