from functools import partial
import os
import numpy as np
import torch
from transformers import DataCollatorForLanguageModeling, BatchEncoding
from latte_trans.preproc.base import DataProcessing
from datasets import Dataset, DatasetDict, load_dataset, concatenate_datasets


def np_collate_fn(max_seq_len, original_batch):
    """Collate examples into a numpy causal-LM batch.

    Args:
        max_seq_len: each example's ``input_ids`` are sliced to this many
            tokens (``None`` keeps the full sequence).
        original_batch: sequence of mappings that each hold ``input_ids``.
            After slicing, all rows must have the same length for
            ``np.stack`` to succeed.

    Returns:
        BatchEncoding with ``input_ids`` and ``labels``; ``labels`` is an
        independent copy of ``input_ids``.
    """
    truncated_rows = [example["input_ids"][:max_seq_len] for example in original_batch]
    stacked = np.stack(truncated_rows)
    return BatchEncoding({"input_ids": stacked, "labels": stacked.copy()})


def torch_collate_fn(max_seq_len, original_batch):
    """Collate examples into a torch causal-LM batch.

    Args:
        max_seq_len: each example's fields are sliced to this many tokens
            (``None`` keeps the full sequence).
        original_batch: sequence of mappings holding ``input_ids`` and,
            optionally, ``attention_mask``. Values may be tensors, numpy
            arrays or plain python lists (the default HF datasets format).
            After slicing, all rows must have the same length.

    Returns:
        BatchEncoding with ``input_ids``, ``labels`` (a clone of
        ``input_ids``) and, when the first example has one, ``attention_mask``.
    """

    def _stack(key):
        # as_tensor converts lists/arrays; existing tensors are passed through
        # unchanged, so tensor inputs behave exactly as with a bare torch.stack.
        return torch.stack(
            [torch.as_tensor(example[key][:max_seq_len]) for example in original_batch]
        )

    input_ids = _stack("input_ids")
    batch = {"input_ids": input_ids, "labels": torch.clone(input_ids)}
    if "attention_mask" in original_batch[0]:
        batch["attention_mask"] = _stack("attention_mask")
    return BatchEncoding(batch)


class BooksDP(DataProcessing):
    """Tokenize BookCorpus row by row, without truncation or packing."""

    def __init__(self, tokenizer, cache_dir, num_load_procs) -> None:
        self._cache_dir = cache_dir
        self._tokenizer = tokenizer
        self._num_load_procs = num_load_procs

    @property
    def tokenizer(self):
        """The tokenizer passed at construction."""
        return self._tokenizer

    def get_collate_fn(self, return_type="torch", **kwargs):
        """Return a collate function callable as ``collate_fn(batch)``.

        Args:
            return_type: ``"torch"`` returns an HF causal-LM collator (``mlm=False``)
                that pads and returns torch tensors; any other value returns a
                numpy collator built on ``np_collate_fn``.
            **kwargs: ``max_seq_len`` (optional) is the truncation length used by
                the numpy collator; when absent, sequences are not truncated.
                Other keys are ignored.
        """
        if return_type == "torch":
            return DataCollatorForLanguageModeling(
                tokenizer=self._tokenizer,
                mlm=False,
                return_tensors="pt",
            )
        # np_collate_fn takes max_seq_len as its first positional argument, so it
        # must be bound here to be usable as collate_fn(batch). Slicing with None
        # keeps the whole sequence.
        # NOTE(review): np_collate_fn stacks without padding, so rows of a batch
        # must have equal length — this class does not pack or pad them.
        return partial(np_collate_fn, kwargs.get("max_seq_len"))

    def _tokenize(self, max_seq_len, elements):
        """Tokenize a batch of ``text`` rows with special tokens and no truncation.

        ``max_seq_len`` is accepted but unused.
        """
        return self._tokenizer(
            elements["text"],
            return_special_tokens_mask=False,
            add_special_tokens=True,
            truncation=False,
        )

    def tokenize(self, raw_dataset, max_seq_len):
        """Tokenize every split, caching results as ``tok_books_*.bin`` in cache_dir.

        The ``text`` column is removed. Every split present in ``raw_dataset``
        must be one of ``train``/``validation``/``test`` (the keys of the cache
        mapping below).
        """
        tok_data = raw_dataset.map(
            partial(self._tokenize, max_seq_len),
            batched=True,
            num_proc=self._num_load_procs,
            batch_size=10000,
            remove_columns=["text"],
            cache_file_names={
                "train": os.path.join(self._cache_dir, "tok_books_train.bin"),
                "validation": os.path.join(self._cache_dir, "tok_books_val.bin"),
                "test": os.path.join(self._cache_dir, "tok_books_test.bin"),
            },
        )
        return tok_data

    def get_raw_data(self):
        """Load ``bookcorpus/bookcorpus`` (non-streaming) into cache_dir."""
        # tokenize() later writes its cache files into this directory as well
        os.makedirs(self._cache_dir, exist_ok=True)
        dataset = load_dataset(
            "bookcorpus/bookcorpus",
            streaming=False,
            cache_dir=self._cache_dir,
            trust_remote_code=True,
            num_proc=self._num_load_procs,
        )
        return dataset


class BookCorpusLong(DataProcessing):
    """Tokenize BookCorpus and pack the token stream into fixed-length rows.

    Within each ``map`` batch, rows are tokenized, concatenated in order and
    cut into chunks of exactly ``max_seq_len`` tokens, so batches need no
    padding. The incomplete tail of each ``map`` batch is dropped.
    """

    def __init__(self, tokenizer, cache_dir, num_load_procs, max_seq_len) -> None:
        self._cache_dir = cache_dir
        self._tokenizer = tokenizer
        self._num_load_procs = num_load_procs
        self._max_seq_len = max_seq_len

    @property
    def tokenizer(self):
        """The tokenizer passed at construction."""
        return self._tokenizer

    def get_collate_fn(self, return_type="torch", **kwargs):
        """Return a collate function that truncates rows to ``max_seq_len``.

        ``return_type="torch"`` selects ``torch_collate_fn``; any other value
        selects ``np_collate_fn``. Extra keyword arguments are ignored.
        """
        if return_type == "torch":
            return partial(torch_collate_fn, self._max_seq_len)
        else:
            return partial(np_collate_fn, self._max_seq_len)

    def _tokenize(self, max_seq_len, elements):
        """Tokenize a batch of ``text`` rows and pack it into fixed-length chunks.

        Returns a dict holding only ``input_ids``, kept alone to limit the size
        of the cached dataset. Its row count generally differs from the input
        batch.
        """
        encoded = self._tokenizer(
            elements["text"],
            return_special_tokens_mask=False,
            add_special_tokens=True,
            truncation=False,
        )
        return {"input_ids": self._pack_sequences(encoded["input_ids"], max_seq_len)}

    @staticmethod
    def _pack_sequences(sequences, max_seq_len):
        """Concatenate ``sequences`` and split the stream into ``max_seq_len`` chunks.

        Every kept token lands in exactly one chunk, in order. A single long
        sequence can produce several chunks. Tokens left after the last full
        chunk are discarded.

        Raises:
            ValueError: if ``max_seq_len`` is smaller than 1.
        """
        if max_seq_len < 1:
            raise ValueError(f"max_seq_len must be >= 1, got {max_seq_len}")
        chunks = []
        buffer = []
        for seq in sequences:
            buffer.extend(seq)
            # Emit every complete chunk now; only the remainder carries over.
            # Slicing once per sequence keeps the whole pass linear.
            cut = len(buffer) - len(buffer) % max_seq_len
            if cut:
                chunks.extend(
                    buffer[start : start + max_seq_len]
                    for start in range(0, cut, max_seq_len)
                )
                buffer = buffer[cut:]
        return chunks

    def tokenize(self, raw_dataset: DatasetDict, force_tok=False) -> DatasetDict:
        """Tokenize and pack every split, then split ``train`` 90/10 into train/validation.

        Args:
            raw_dataset: splits with a ``text`` column; must contain ``train``.
            force_tok: re-tokenize even if the cache files already exist.

        Returns:
            DatasetDict with ``train`` and ``validation`` splits, both drawn
            from the packed ``train`` split (seed 2357).
        """
        # max_seq_len is part of the cache file names. With an explicit cache
        # path, datasets loads an existing file without checking which chunk
        # length produced it. The new names also avoid reusing files written by
        # earlier versions of this method.
        tag = f"L{self._max_seq_len}"
        tok_data = raw_dataset.map(
            partial(self._tokenize, self._max_seq_len),
            batched=True,
            batch_size=32000,
            remove_columns=raw_dataset["train"].column_names,
            cache_file_names={
                "train": os.path.join(self._cache_dir, f"tok_book_train_{tag}.bin"),
                "test": os.path.join(self._cache_dir, f"tok_book_test_{tag}.bin"),
                "validation": os.path.join(self._cache_dir, f"tok_book_val_{tag}.bin"),
            },
            num_proc=self._num_load_procs,
            load_from_cache_file=not force_tok,
        )
        # the original data only has a train split
        tok_data = tok_data["train"].train_test_split(
            test_size=0.1, seed=2357, shuffle=True
        )
        tok_data["validation"] = tok_data.pop("test")  # rename test -> validation
        return tok_data

    def get_raw_data(self):
        """Load ``bookcorpus/bookcorpus`` (non-streaming) into cache_dir."""
        os.makedirs(self._cache_dir, exist_ok=True)
        dataset = load_dataset(
            "bookcorpus/bookcorpus",
            streaming=False,
            cache_dir=self._cache_dir,
            trust_remote_code=True,
            num_proc=self._num_load_procs,
        )
        return dataset


def test_book_corpus_long():
    """Manual smoke run of BookCorpusLong on the full BookCorpus.

    Uses hard-coded local paths. The tokenizer and dataset come from the
    Hugging Face Hub unless already cached, so this is not a hermetic unit test.
    """
    from pathlib import Path
    from transformers import AutoTokenizer

    base_dir = Path("/user_all_data/data/")
    tokenizer = AutoTokenizer.from_pretrained(
        "google/gemma-2-2b",
        cache_dir=base_dir / "input/cache_hugg",
        truncation_side="right",
        padding_side="right",
    )
    data_dir = "/data_user/data/input"
    cache_dir = os.path.join(data_dir, "bookcorpus")
    # NOTE(review): 4048 may be a typo for 4096 — confirm the intended length.
    dp = BookCorpusLong(
        tokenizer, cache_dir=cache_dir, num_load_procs=8, max_seq_len=4048
    )

    dataset = dp.get_raw_data()
    print(dataset)
    tok_data = dp.tokenize(dataset)
    print(tok_data)


def test_bookdp():
    """Manual smoke run of BooksDP: tokenize BookCorpus and print some statistics.

    Prints the mean tokenized row length, the first tokenized row and the
    first four raw texts. Uses hard-coded local paths and the Hugging Face Hub,
    so it is not a hermetic unit test.
    """
    from pathlib import Path
    from transformers import AutoTokenizer

    base_dir = Path("/user_all_data/data/")
    MAX_SEQ_LEN = 32000
    tokenizer = AutoTokenizer.from_pretrained(
        "google/gemma-2-2b",
        cache_dir=base_dir / "input/cache_hugg",
        truncation_side="right",
        padding_side="right",
    )
    data_dir = "/data_user/data/input"
    cache_dir = os.path.join(data_dir, "bookcorpus")
    dp = BooksDP(tokenizer, cache_dir=cache_dir, num_load_procs=8)

    dataset = dp.get_raw_data()
    print(dataset)
    tok_data = dp.tokenize(dataset, max_seq_len=MAX_SEQ_LEN)
    print(tok_data)
    lengths = tok_data["train"].map(lambda x: {"len": len(x["input_ids"])}, num_proc=10)
    print(lengths)
    print(np.mean(lengths["len"]))
    print(tok_data["train"][0]["input_ids"])
    for idx in range(4):
        print(dataset["train"][idx]["text"])


if __name__ == "__main__":
    # Script entry point: run the BooksDP smoke check.
    test_bookdp()
