import torch
from transformers import PreTrainedTokenizer
from transformers.tokenization_utils_base import BatchEncoding
import sentencepiece as spm
import typing

class TokenizerWrapper(PreTrainedTokenizer):
    """Adapter exposing a SentencePiece model through a tokenizer-style API.

    The wrapper holds an ``spm.SentencePieceProcessor`` and implements only
    the methods defined in this class. Encoded batches are returned as
    ``BatchEncoding`` objects with ``int64`` ``input_ids`` and
    ``attention_mask`` tensors.

    NOTE(review): ``PreTrainedTokenizer.__init__`` is never invoked here, so
    whatever state the base class would normally set up is absent -- confirm
    that this is intended before relying on inherited methods.
    """

    def __init__(
        self,
        tokenizer_checkpoint: str
    ) -> None:
        """Load the SentencePiece model stored at ``tokenizer_checkpoint``.

        Args:
            tokenizer_checkpoint: Path passed as ``model_file`` to
                ``spm.SentencePieceProcessor``.
        """
        self.tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_checkpoint)

    @property
    def pad_token_id(
        self
    ) -> int:
        """Return the padding id reported by the SentencePiece model."""
        return self.tokenizer.pad_id()

    def save_pretrained(
        self,
        *args,
        **kwargs
    ) -> None:
        """Do nothing; this wrapper does not persist any tokenizer files.

        Positional and keyword arguments are accepted and ignored so that
        both ``save_pretrained(path)`` and ``save_pretrained(save_directory=path)``
        are harmless no-ops.
        """
        pass

    def pad(
        self,
        data: typing.List[typing.Dict[str, typing.Any]],
        **kwargs
    ) -> BatchEncoding:
        """Collate already-encoded examples into one right-padded batch.

        Args:
            data: One mapping per example, each providing ``'input_ids'`` and
                ``'attention_mask'`` sequences of the same length.
            **kwargs: Accepted and ignored.

        Returns:
            A ``BatchEncoding`` with ``'input_ids'`` padded with the model's
            pad id and ``'attention_mask'`` padded with ``0``, both ``int64``
            and laid out batch-first.
        """
        input_ids = torch.nn.utils.rnn.pad_sequence(
            sequences=[torch.tensor(cur['input_ids'], dtype=torch.int64) for cur in data],
            batch_first=True,
            padding_value=self.tokenizer.pad_id()
        )
        # Padded positions must be masked out, so the fill value is 0 rather
        # than the pad *token* id (which is only 0 by coincidence, if at all).
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            sequences=[torch.tensor(cur['attention_mask'], dtype=torch.int64) for cur in data],
            batch_first=True,
            padding_value=0
        )

        return BatchEncoding(
            {
                'input_ids': input_ids,
                'attention_mask': attention_mask
            }
        )

    def encode(
        self,
        data: typing.Union[str, typing.List[str]],
        add_special_tokens: bool = True,
        max_length: int = 256,
    ) -> BatchEncoding:
        """Tokenize text into a padded, truncated batch of ids.

        Args:
            data: A single string (treated as a batch of one) or a list of
                strings.
            add_special_tokens: When true, BOS and EOS are added to every
                sequence.
            max_length: Maximum number of columns kept in the result.

        Returns:
            A ``BatchEncoding`` with ``'input_ids'`` (right-padded with the
            model's pad id) and ``'attention_mask'`` (1 where the id differs
            from the pad id, 0 elsewhere), both ``int64``.
        """
        texts = [data] if isinstance(data, str) else data

        if add_special_tokens:
            token_ids = self.tokenizer.encode(texts, add_bos=True, add_eos=True)
        else:
            token_ids = self.tokenizer.encode(texts)

        pad_id = self.tokenizer.pad_id()
        # Truncation happens after padding and on the right, so any trailing
        # tokens past ``max_length`` (including an appended EOS) are dropped.
        input_ids = torch.nn.utils.rnn.pad_sequence(
            sequences=[torch.tensor(ids, dtype=torch.int64) for ids in token_ids],
            batch_first=True,
            padding_value=pad_id
        )[:, :max_length]
        attention_mask = (input_ids != pad_id).to(torch.int64)

        return BatchEncoding(
            {
                'input_ids': input_ids,
                'attention_mask': attention_mask
            }
        )

    def __call__(
        self,
        data: typing.Union[str, typing.List[str]],
        **kwargs
    ) -> BatchEncoding:
        """Encode ``data`` using the default settings of :meth:`encode`.

        NOTE(review): ``**kwargs`` are accepted but not forwarded, so options
        such as ``max_length`` or ``add_special_tokens`` passed here have no
        effect -- confirm that callers expect this.
        """
        return self.encode(data)

    def decode(
        self,
        data: torch.Tensor
    ) -> typing.Union[typing.List[str], str]:
        """Convert token ids back to text.

        Args:
            data: A 1-D tensor of ids (one sequence) or a 2-D tensor with one
                sequence per row.

        Returns:
            A single string for 1-D input, or a list with one string per row
            for 2-D input.

        Raises:
            ValueError: If ``data`` is neither 1-D nor 2-D.
        """
        rank = data.dim()
        # ``Tensor.tolist()`` works regardless of the tensor's device, whereas
        # going through ``.numpy()`` fails for tensors that are not on the CPU.
        if rank == 1:
            return self.tokenizer.decode(data.tolist())
        if rank == 2:
            return [self.tokenizer.decode(row) for row in data.tolist()]
        raise ValueError("Expected 1-dim or 2-dim data")
