import abc
import numpy as np
from jax import numpy as jnp
import torch
from transformers import BatchEncoding


def torch_collate_fn(original_batch):
    """Collate a list of sample dicts into a ``BatchEncoding`` of torch tensors.

    Args:
        original_batch: Sequence of samples, each indexable by ``"input_ids"``
            and ``"labels"`` (and optionally ``"pad_mask"``).

    Returns:
        A ``BatchEncoding`` with ``"input_ids"`` and ``"labels"`` tensors.
        ``"pad_mask"`` is added only when the first sample contains that key.
    """

    def _stack(field):
        # One tensor per field, built from that field of every sample.
        return torch.tensor([sample[field] for sample in original_batch])

    collated = {
        "input_ids": _stack("input_ids"),
        "labels": _stack("labels"),
    }
    # Only the first sample is inspected to decide whether a mask is present.
    if "pad_mask" in original_batch[0]:
        collated["pad_mask"] = _stack("pad_mask")
    return BatchEncoding(collated)


def np_collate_fn(original_batch):
    """Collate a list of sample dicts into a ``BatchEncoding`` of NumPy arrays.

    Args:
        original_batch: Sequence of samples, each indexable by ``"input_ids"``
            and ``"labels"`` (and optionally ``"pad_mask"``).

    Returns:
        A ``BatchEncoding`` with ``"input_ids"`` and ``"labels"`` arrays.
        ``"pad_mask"`` is added only when the first sample contains that key.
    """

    def _stack(field):
        # One array per field, built from that field of every sample.
        return np.array([sample[field] for sample in original_batch])

    collated = {
        "input_ids": _stack("input_ids"),
        "labels": _stack("labels"),
    }
    # Only the first sample is inspected to decide whether a mask is present.
    if "pad_mask" in original_batch[0]:
        collated["pad_mask"] = _stack("pad_mask")
    return BatchEncoding(collated)


class Tokenizer(abc.ABC):
    """Abstract tokenizer base holding the special tokens and their ids.

    ``mapping`` maps token string -> integer id. The four special tokens are
    assigned ids in declaration order: ``<unk>``=0, ``<pad>``=1, ``<bos>``=2,
    ``<eos>``=3.
    """

    def __init__(self) -> None:
        self.special_tokens = {
            "<unk>": "<unk>",
            "<pad>": "<pad>",
            "<bos>": "<bos>",
            "<eos>": "<eos>",
        }
        # token -> id. The id properties below look tokens up by their string,
        # so the keys must be the token strings (not the integer ids).
        self.mapping = {
            tok: id_ for id_, tok in enumerate(self.special_tokens)
        }

    @abc.abstractmethod
    def __call__(self):
        """Tokenize input; must be implemented by concrete subclasses."""
        pass

    @property
    def pad_token_id(self):
        """Id of the ``<pad>`` token in ``mapping``."""
        return self.mapping["<pad>"]

    @property
    def unk_token_id(self):
        """Id of the ``<unk>`` token in ``mapping``."""
        return self.mapping["<unk>"]

    @property
    def bos_token_id(self):
        """Id of the ``<bos>`` token in ``mapping``."""
        return self.mapping["<bos>"]

    @property
    def eos_token_id(self):
        """Id of the ``<eos>`` token in ``mapping``."""
        return self.mapping["<eos>"]

    @property
    def vocab_size(self):
        """Number of entries currently in ``mapping``."""
        return len(self.mapping)


class DataProcessing:
    """Base class with shared dataset-preparation helpers.

    Offers selection of a collate function, per-split subsampling and column
    clean-up. ``tokenize`` and ``get_raw_data`` are placeholders meant to be
    overridden.

    NOTE(review): this class does not derive from ``abc.ABC``, so the
    ``@abc.abstractmethod`` markers below are not enforced at instantiation
    time. Left unchanged to stay compatible with existing subclasses.
    """

    def get_collate_fn(self, return_type="torch", **kwargs):
        """Return the collate function for the requested output type.

        Args:
            return_type: ``"torch"`` selects ``torch_collate_fn``; any other
                value selects ``np_collate_fn``.
            **kwargs: Accepted and ignored here.

        Returns:
            The collate callable.
        """
        if return_type == "torch":
            return torch_collate_fn
        return np_collate_fn

    @staticmethod
    def subsample(dataset, nr_samples):
        """Keep at most the first ``nr_samples`` rows of every split.

        Args:
            dataset: Mapping from split name to an object supporting ``len()``
                and ``select(indices)``. It is updated in place.
            nr_samples: Maximum number of rows to keep per split.

        Returns:
            The same ``dataset`` object, with each split replaced by its
            selected prefix.
        """
        for split_name in dataset.keys():
            split = dataset[split_name]
            # Cap per split with a local value: a small split must not shrink
            # the limit applied to the splits that follow it.
            n_keep = min(nr_samples, len(split))
            dataset[split_name] = split.select(list(range(n_keep)))
        return dataset

    @staticmethod
    def clean_colnames(tokenized_data, keep_col_names):
        """Remove every column that is not listed in ``keep_col_names``.

        Args:
            tokenized_data: [Dataset, DictDataset]
                Object exposing ``column_names`` (either a list of names, or
                a dict mapping split name to a list of names) and
                ``remove_columns``.
            keep_col_names: List[str]
                Column names to keep

        Returns:
            Whatever ``tokenized_data.remove_columns`` returns.
        """
        col_names = tokenized_data.column_names
        if isinstance(col_names, dict):
            # Multi-split container: take the union of columns over splits.
            all_col_names = set()
            for split_cols in col_names.values():
                all_col_names.update(split_cols)
        else:
            # Single dataset: column_names is already a flat collection.
            all_col_names = set(col_names)
        to_remove = all_col_names - set(keep_col_names)
        # Sorted so the removal order is deterministic across runs.
        return tokenized_data.remove_columns(sorted(to_remove))

    @abc.abstractmethod
    def tokenize(self):
        """Tokenize the data; to be provided by subclasses."""
        pass

    @abc.abstractmethod
    def get_raw_data(self):
        """Load the raw data; to be provided by subclasses."""
        pass
