import logging
from typing import List, Optional, Union

import sentencepiece as spm
import torch
from omegaconf import DictConfig, OmegaConf

from pado.core.base.transform import PadoTransform
from pado.data.transforms import register_transform

# Public names exported by a star-import of this module.
__all__ = ["SentencePieceTokenizer"]

# Logger registered under the name "pado"; the tokenizer below writes debug output to it.
logger = logging.getLogger("pado")


@register_transform("SentencePieceTokenizer")
class SentencePieceTokenizer(PadoTransform):
    """
    Text <-> token-id transform backed by a trained SentencePiece model.

    SentencePiece is mostly used for language modeling and speech recognition.
    The model must already be trained: ``vocab_model`` has to name an existing
    ``*.model`` file.
    """

    def __init__(self,
                 vocab_model: str, *,
                 add_bos: bool = True,
                 add_eos: bool = True,
                 enable_sampling: bool = False,
                 alpha: float = 0.1):
        """
        Load the SentencePiece model and remember its unknown-token piece.

        Args:
            vocab_model: Path to the SentencePiece model; must end in ``.model``.
            add_bos: Forwarded to the SentencePiece processor as ``add_bos``.
            add_eos: Forwarded to the SentencePiece processor as ``add_eos``.
            enable_sampling: Forwarded to the processor as ``enable_sampling``.
            alpha: Forwarded to the processor as ``alpha``.

        Raises:
            ValueError: If ``vocab_model`` does not end with ``.model``.
        """
        super().__init__()
        if not vocab_model.endswith(".model"):
            raise ValueError(f"SentencePiece model {vocab_model} is invalid.")

        # The BOS/EOS settings should match the tokens used to generate the SPM model.
        self.tokenizer = spm.SentencePieceProcessor()
        self.tokenizer.Init(model_file=vocab_model, add_bos=add_bos, add_eos=add_eos,
                            enable_sampling=enable_sampling, alpha=alpha)

        # Piece string the model uses for unknown tokens (e.g. <UNK>).
        unk_id = self.tokenizer.unk_id()
        self.unk_piece = self.tokenizer.IdToPiece(unk_id)
        # Lazy %-style arguments: the message is only built if DEBUG is enabled.
        logger.debug("SentencePiece UNK: %s (id: %s)", self.unk_piece, unk_id)

    def encode(self, script: str, sampling: Optional[bool] = None) -> List[int]:
        """
        Encode text into a list of token ids.

        Any literal occurrence of the model's unknown piece in ``script`` is
        first replaced by "⁇" (SentencePiece's pre-defined unknown surface) so
        that it mirrors what ``decode`` maps back.

        Args:
            script: Text to tokenize.
            sampling: Passed to ``Encode`` as ``enable_sampling``. ``None`` is
                forwarded unchanged (presumably the processor then uses the
                value given at construction -- confirm against sentencepiece).

        Returns:
            The token ids produced by the SentencePiece processor.
        """
        normalized = script.replace(self.unk_piece, "⁇")  # pre-defined one
        return self.tokenizer.Encode(normalized, enable_sampling=sampling)

    def decode(self, sequence: Union[List[int], torch.Tensor]) -> str:
        """
        Decode token ids back into text.

        Args:
            sequence: Token ids, either as a list of ints or as a tensor of ids
                such as the one returned by ``forward``.

        Returns:
            The decoded text with "⁇" replaced by the model's unknown piece and
            surrounding whitespace stripped.
        """
        # ``forward`` hands out tensors; convert them to a plain Python list
        # before calling ``Decode`` so the output of ``forward`` can be decoded.
        if isinstance(sequence, torch.Tensor):
            sequence = sequence.tolist()
        decoding: str = self.tokenizer.Decode(sequence)
        return decoding.replace("⁇", self.unk_piece).strip()

    def forward(self, script: str, sampling: Optional[bool] = None) -> torch.Tensor:
        """
        Encode text and return the ids as a ``torch.long`` tensor.

        Args:
            script: Text to tokenize.
            sampling: Forwarded to ``encode``.

        Returns:
            Tensor of dtype ``torch.long`` holding the token ids.
        """
        token_ids = self.encode(script, sampling=sampling)
        return torch.tensor(token_ids, dtype=torch.long)

    @classmethod
    def from_config(cls, cfg: DictConfig) -> "SentencePieceTokenizer":
        """
        Build a tokenizer from an OmegaConf config.

        The config is resolved into a plain container and expanded as keyword
        arguments of ``__init__``, so its keys must match those parameter names.
        """
        kwargs = OmegaConf.to_container(cfg, resolve=True)
        return cls(**kwargs)
