from typing import List, Union, Optional
from fast_compression import batch_encode, decode
from transformers import PreTrainedTokenizerBase, AutoTokenizer, BatchEncoding


class LZWTokenizer:
    """Wrap a Hugging Face tokenizer so that its token ids are LZW-compressed.

    Construction monkey-patches the wrapped tokenizer *in place*: its private
    ``_batch_encode_plus`` and ``_decode`` hooks are replaced with versions that
    run ``fast_compression.batch_encode`` / ``fast_compression.decode`` around
    the original implementations. Public methods that route through those hooks
    therefore produce and consume compressed ids. For fast tokenizers this
    presumably includes single-text ``encode``; verify for slow tokenizers.

    Because the patch mutates the wrapped object, any other reference to the
    same tokenizer also compresses. Attributes not defined on this class are
    forwarded to the wrapped tokenizer.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        initial_vocab_size: int,
        max_codebook_size: int,
        max_subtokens: int,
        disabled_ids: Optional[List[int]] = None,
    ) -> None:
        """Store the compression settings and patch ``tokenizer``'s hooks.

        Args:
            tokenizer: Tokenizer to wrap; its ``_batch_encode_plus`` and
                ``_decode`` attributes are overwritten.
            initial_vocab_size: Base vocabulary size passed to the codec
                (``from_pretrained`` uses ``len(tokenizer)``).
            max_codebook_size: Forwarded unchanged to ``batch_encode``/``decode``.
            max_subtokens: Forwarded unchanged to ``batch_encode``/``decode``.
            disabled_ids: Forwarded unchanged to ``batch_encode``/``decode``.
        """
        self.tokenizer = tokenizer
        self.initial_vocab_size = initial_vocab_size
        self.max_codebook_size = max_codebook_size
        self.max_subtokens = max_subtokens
        self.disabled_ids = disabled_ids

        # Keep the original bound hooks so the replacements can delegate to them.
        self.old_batch_encode_plus = self.tokenizer._batch_encode_plus
        self.tokenizer._batch_encode_plus = self._batch_encode_plus

        self.old_decode = self.tokenizer._decode
        self.tokenizer._decode = self._decode

    def __getattr__(self, attr):
        """Forward unknown attribute lookups to the wrapped tokenizer."""
        # __getattr__ runs only after normal lookup fails. Read "tokenizer"
        # straight from the instance dict: on an instance created without
        # __init__ (copy.copy / deepcopy / unpickling probe e.g. __setstate__
        # before restoring state), going through self.tokenizer would
        # re-enter __getattr__ and recurse until RecursionError.
        try:
            tokenizer = self.__dict__["tokenizer"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            ) from None
        return getattr(tokenizer, attr)

    def __call__(self, *args, **kwargs) -> BatchEncoding:
        """Tokenize via the wrapped tokenizer, whose hooks are patched."""
        return self.tokenizer(*args, **kwargs)

    def _batch_encode_plus(self, *args, **kwargs) -> BatchEncoding:
        """Tokenize with the original hook, then LZW-compress the ids.

        Installed as the wrapped tokenizer's ``_batch_encode_plus``. All
        arguments go to the original implementation except ``return_tensors``,
        which is applied only after compression has replaced the ids.
        """
        # TODO: we don't support padding here
        # Defer tensor conversion: the compressed ids replace the originals,
        # and batch_encode presumably expects plain list input (confirm).
        return_tensors = kwargs.pop("return_tensors", None)

        encoding = self.old_batch_encode_plus(*args, **kwargs)

        # NOTE(review): only input_ids/attention_mask are replaced. Any other
        # per-token fields the tokenizer returned (token_type_ids,
        # special_tokens_mask, offsets, ...) still describe the uncompressed
        # sequence. The mask is also set even if it was not requested.
        encoding["input_ids"], encoding["attention_mask"] = batch_encode(
            encoding["input_ids"],
            initial_vocab_size=self.initial_vocab_size,
            max_codebook_size=self.max_codebook_size,
            max_subtokens=self.max_subtokens,
            disabled_ids=self.disabled_ids,
        )

        if return_tensors:
            encoding = encoding.convert_to_tensors(return_tensors)

        return encoding

    def _decode(
        self,
        token_ids: Union[int, List[int]],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: Optional[bool] = None,
        **kwargs,
    ) -> str:
        """Expand compressed ids back to base ids, then decode them to text.

        Installed as the wrapped tokenizer's ``_decode``. The remaining
        arguments are passed through to the original ``_decode``.
        """
        # A lone id is treated as a one-element sequence.
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        token_ids = decode(
            token_ids,
            initial_vocab_size=self.initial_vocab_size,
            max_codebook_size=self.max_codebook_size,
            max_subtokens=self.max_subtokens,
            disabled_ids=self.disabled_ids,
        )

        return self.old_decode(
            token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        max_codebook_size: int,
        max_subtokens: int,
        disabled_ids: Optional[List[int]] = None,
        *args,
        **kwargs,
    ) -> "LZWTokenizer":
        """Load a tokenizer with ``AutoTokenizer`` and wrap it.

        ``initial_vocab_size`` is taken from ``len(tokenizer)``. Extra
        ``*args``/``**kwargs`` are passed to ``AutoTokenizer.from_pretrained``.

        NOTE(review): ``disabled_ids`` precedes ``*args``. A fourth positional
        argument therefore binds to ``disabled_ids``, not to the loader, so
        pass ``disabled_ids`` by keyword when also forwarding extra args.
        """
        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path, *args, **kwargs
        )

        return cls(
            tokenizer,
            len(tokenizer),
            max_codebook_size,
            max_subtokens,
            disabled_ids,
        )


if __name__ == "__main__":
    # Smoke test: round-trip a source file through the compressing tokenizer.
    # The path is relative to the current working directory.
    tokenizer = LZWTokenizer.from_pretrained(
        "microsoft/Phi-3.5-mini-instruct",
        max_codebook_size=2048,
        max_subtokens=4,
    )
    # Rust source files are UTF-8. Without an explicit encoding, open() uses
    # the platform's locale encoding (e.g. cp1252 on Windows), which can fail
    # on or mis-decode non-ASCII characters.
    with open("fast_compression/src/lib.rs", "r", encoding="utf-8") as f:
        text = f.read()
    compressed_ids = tokenizer.encode(text)
    # Explicit check rather than `assert`, which `python -O` strips out.
    if tokenizer.decode(compressed_ids) != text:
        raise AssertionError("LZW round trip did not reproduce the input text")
