from functools import partial
import os
from transformers import BatchEncoding
import torch
import numpy as np
from latte_trans.preproc.base import DataProcessing
from datasets import load_dataset, DatasetDict


def torch_sampl_lm_collate_fn(pad_id, max_seq_len, original_batch):
    """Collate tokenized examples into fixed-length torch tensors for LM training.

    Every example's ``"input_ids"`` is truncated to ``max_seq_len`` and then
    right-padded with ``pad_id`` up to that length. ``labels`` mirror
    ``input_ids`` except that padded positions hold ``-100`` (the default
    ``ignore_index`` of PyTorch's cross-entropy loss).

    Padding is masked by *position*, not by value: a genuine token whose id
    happens to equal ``pad_id`` (for instance when the pad token is aliased
    to another special token) keeps its label instead of being dropped from
    the loss.

    Args:
        pad_id: Token id written into padded positions.
        max_seq_len: Length of every row in the returned tensors.
        original_batch: Iterable of mappings, each holding an ``"input_ids"``
            sequence of ints.

    Returns:
        BatchEncoding with ``"input_ids"`` and ``"labels"`` tensors.

    Raises:
        ValueError: If a sequence needs padding but ``pad_id`` is ``None``.
    """
    id_rows = []
    label_rows = []
    for example in original_batch:
        tokens = list(example["input_ids"][:max_seq_len])
        n_pad = max_seq_len - len(tokens)
        if n_pad > 0 and pad_id is None:
            # Fail early with a clear message instead of an opaque tensor
            # construction error on a list containing None.
            raise ValueError(
                "pad_id must not be None when sequences need padding"
            )
        # A non-positive n_pad yields an empty padding list.
        id_rows.append(tokens + [pad_id] * n_pad)
        label_rows.append(tokens + [-100] * n_pad)

    input_ids = torch.tensor(id_rows)
    labels = torch.tensor(label_rows)
    return BatchEncoding({"input_ids": input_ids, "labels": labels})


def np_sampl_lm_collate_fn(pad_id, max_seq_len, original_batch):
    """Collate tokenized examples into fixed-length NumPy arrays for LM training.

    Every example's ``"input_ids"`` is truncated to ``max_seq_len`` and then
    right-padded with ``pad_id`` up to that length. ``labels`` mirror
    ``input_ids`` except that padded positions hold ``-100``.

    Padding is masked by *position*, not by value: a genuine token whose id
    happens to equal ``pad_id`` (for instance when the pad token is aliased
    to another special token) keeps its label instead of being replaced
    by ``-100``.

    Args:
        pad_id: Token id written into padded positions.
        max_seq_len: Length of every row in the returned arrays.
        original_batch: Iterable of mappings, each holding an ``"input_ids"``
            sequence of ints.

    Returns:
        BatchEncoding with ``"input_ids"`` and ``"labels"`` arrays.

    Raises:
        ValueError: If a sequence needs padding but ``pad_id`` is ``None``;
            also raised by ``np.stack`` when the batch is empty.
    """
    id_rows = []
    label_rows = []
    for example in original_batch:
        tokens = list(example["input_ids"][:max_seq_len])
        n_pad = max_seq_len - len(tokens)
        if n_pad > 0 and pad_id is None:
            # Without this guard NumPy would silently build object-dtype
            # arrays containing None.
            raise ValueError(
                "pad_id must not be None when sequences need padding"
            )
        # A non-positive n_pad yields an empty padding list.
        id_rows.append(tokens + [pad_id] * n_pad)
        label_rows.append(tokens + [-100] * n_pad)

    input_ids = np.stack(id_rows, axis=0)
    labels = np.stack(label_rows, axis=0)
    return BatchEncoding({"input_ids": input_ids, "labels": labels})


class Wiki103DP(DataProcessing):
    """
    Load, filter and tokenize the WikiText-103 dataset.

    The instance keeps the tokenizer, the cache directory, the number of
    worker processes and the maximum sequence length used when collating.
    """

    def __init__(
        self, tokenizer, cache_dir, num_load_procs=7, char_level=False, max_seq_len=1024
    ):
        """
        Args:
            tokenizer: callable tokenizer exposing ``pad_token_id``
            cache_dir: directory for the raw dataset and tokenization caches
            num_load_procs: process count passed to dataset loading and mapping
            char_level: stored on the instance; not read by any method of
                this class
            max_seq_len: row length handed to the collate functions
        """
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.char_level = char_level
        self.cache_dir = cache_dir
        self.num_load_procs = num_load_procs

    def get_collate_fn(self, return_type="torch"):
        """
        Return a collate function bound to this tokenizer's pad id and
        ``max_seq_len``.

        Args:
            return_type: "torch" selects the torch collate function; any
                other value selects the numpy one.
        """
        if return_type == "torch":
            collate = torch_sampl_lm_collate_fn
        else:
            collate = np_sampl_lm_collate_fn
        return partial(collate, self.tokenizer.pad_token_id, self.max_seq_len)

    def tokenize(self, data: DatasetDict, add_special_tokens=False, force_tok=False):
        """
        Tokenize the "text" column of every split and drop that column.

        Rows are tokenized independently of each other and no attention
        mask is produced; truncation and padding to ``max_seq_len`` are left
        to the collate function returned by ``get_collate_fn``.

        Args:
            data: DatasetDict with a "text" column
            add_special_tokens: forwarded to the tokenizer
            force_tok: when True, existing cache files are not loaded
        """
        # NOTE(review): the cache file names are fixed, so a cache written
        # with a different tokenizer or add_special_tokens value is
        # presumably reused unless force_tok=True -- verify against the
        # `datasets` caching behaviour.
        file_per_split = {
            "train": "tok_wiki103_train.bin",
            "validation": "tok_wiki103_val.bin",
            "test": "tok_wiki103_test.bin",
        }
        cache_paths = {
            split: os.path.join(self.cache_dir, file_name)
            for split, file_name in file_per_split.items()
        }
        use_cache = not force_tok

        return data.map(
            lambda x: self.tokenizer(
                x["text"],
                add_special_tokens=add_special_tokens,
                return_attention_mask=False,
            ),
            batched=True,
            batch_size=10000,
            remove_columns=["text"],
            num_proc=self.num_load_procs,
            cache_file_names=cache_paths,
            load_from_cache_file=use_cache,
        )

    def get_raw_data(self):
        """
        Load the "wikitext-103-v1" configuration into ``cache_dir`` (created
        if missing) and drop every row whose "text" is empty.
        """
        os.makedirs(self.cache_dir, exist_ok=True)

        # Only the tokenized variant is loaded; switching to
        # "wikitext-103-raw-v1" when self.char_level is set is disabled:
        # if self.char_level:
        # name = "wikitext-103-raw-v1"
        # else:
        config_name = "wikitext-103-v1"
        raw = load_dataset(
            path="wikitext",
            name=config_name,
            cache_dir=self.cache_dir,
            num_proc=self.num_load_procs,
        )

        non_empty = raw.filter(lambda example: len(example["text"]) > 0)

        # overfit check
        # dataset['train'] = dataset['train'].select(np.arange(100))
        # dataset["validation"] = dataset["train"]
        # dataset["validation"] = dataset["validation"].select(np.arange(100))
        return non_empty
