from typing import Tuple, List, Optional
import torch
from omegaconf import DictConfig, OmegaConf

from pado.core.base.transform import PadoTransform
from pado.data.transforms import register_transform
from pado.data.dictionary import Dictionary, ENGLISH_GRAPHEMES

__all__ = ["GraphemeTokenizer"]


@register_transform("GraphemeTokenizer")
class GraphemeTokenizer(PadoTransform):
    """Character-level (grapheme) tokenizer.

    Converts a script (text) into a sequence of integer indices, one per
    character, and back. Spaces are represented by ``"_"`` and the script is
    case-normalized (upper case by default, lower case if ``lowercase=True``)
    before lookup.
    """

    def __init__(self,
                 vocab: Optional[Tuple[str, ...]] = None, *,
                 add_bos: bool = True,
                 add_eos: bool = True,
                 pad_token: str = "<b>",
                 bos_token: str = "<s>",
                 eos_token: str = "</s>",
                 lowercase: bool = False):
        """Build the vocabulary and the backing dictionary.

        Args:
            vocab: Grapheme tokens. Any sequence of strings is accepted
                (e.g. a list coming from a config); defaults to
                ``ENGLISH_GRAPHEMES``.
            add_bos: Prepend ``bos_token`` to the vocabulary and to encoded
                sequences.
            add_eos: Add ``eos_token`` to the vocabulary and append it to
                encoded sequences.
            pad_token: Padding token; always placed first in the vocabulary.
            bos_token: Beginning-of-sequence token.
            eos_token: End-of-sequence token.
            lowercase: Normalize vocabulary and scripts to lower case
                instead of upper case.
        """
        super().__init__()
        if vocab is None:
            vocab = ENGLISH_GRAPHEMES
        # Config-driven construction (see ``from_config``) delivers sequences
        # as lists; coerce so the tuple concatenations below cannot raise
        # ``TypeError``.
        vocab = tuple(vocab)

        self.lowercase = lowercase
        self.add_bos = add_bos
        self.add_eos = add_eos
        self.pad_token = pad_token
        self.bos_token = bos_token
        self.eos_token = eos_token

        # handle special tokens
        # order: pad, bos, eos, ...
        if add_eos:
            vocab = (eos_token,) + vocab
        if add_bos:
            vocab = (bos_token,) + vocab
        vocab = (pad_token,) + vocab

        # NOTE(review): the case normalization is applied to the special
        # tokens as well, so with lowercase=False the vocabulary holds e.g.
        # "<S>" while encode()/decode() use the original "<s>". Left as-is
        # because changing it may shift token indices - confirm against
        # Dictionary before touching.
        if lowercase:
            vocab = tuple(v.lower() for v in vocab)
        else:
            vocab = tuple(v.upper() for v in vocab)

        self.vocab = vocab
        self.dictionary = Dictionary(bos=bos_token, pad=pad_token, eos=eos_token)
        for c in vocab:
            self.dictionary.add_token(c, n=0)
        self.dictionary.finalize()  # order will be preserved.

    @property
    def vocab_size(self):
        """Number of entries in ``self.vocab`` (special tokens included)."""
        return len(self.vocab)

    def encode(self, script: str) -> List[int]:
        """Convert a script into a list of token indices.

        The script is stripped, newlines are removed, spaces become ``"_"``
        and the case is normalized. Each remaining character is one token.
        BOS/EOS tokens are added when enabled; an empty script therefore
        encodes to just the enabled BOS/EOS tokens.

        Args:
            script: Input text.

        Returns:
            Indices looked up in ``self.dictionary`` with
            ``allow_unknown=False`` (presumably rejecting characters that
            are not in the vocabulary - confirm against ``Dictionary``).
        """
        script = script.strip().replace("\n", "").replace(" ", "_")
        if self.lowercase:
            script = script.lower()
        else:
            script = script.upper()

        chars = list(script)
        # ``not chars`` guards the empty-script case, where indexing
        # chars[0] / chars[-1] would raise IndexError.
        if self.add_bos and (not chars or chars[0] != self.bos_token):
            chars = [self.bos_token] + chars
        if self.add_eos and (not chars or chars[-1] != self.eos_token):
            chars = chars + [self.eos_token]

        indices = [self.dictionary.get_token_idx(c, allow_unknown=False) for c in chars]
        return indices

    def decode(self, sequence: List[int]) -> str:
        """Convert a list of token indices back into a script.

        Tokens are concatenated, ``"_"`` is turned back into a space, and
        occurrences of the pad/bos/eos token strings are removed.

        Args:
            sequence: Token indices.

        Returns:
            The decoded text.
        """
        chars = [self.dictionary.get_idx_token(i) for i in sequence]
        script = "".join(chars).replace("_", " ")
        script = script.replace(self.pad_token, "").replace(self.bos_token, "").replace(self.eos_token, "")
        return script

    @torch.no_grad()
    def forward(self, script: str) -> torch.Tensor:
        """Encode ``script`` and return the indices as a 1-D ``torch.long`` tensor."""
        indices = self.encode(script)
        return torch.tensor(indices, dtype=torch.long)

    @classmethod
    def from_config(cls, cfg: DictConfig) -> "GraphemeTokenizer":
        """Instantiate from an OmegaConf config.

        The config is resolved into a plain container and its entries are
        passed to ``__init__`` as keyword arguments.
        """
        cfg = OmegaConf.to_container(cfg, resolve=True)
        return cls(**cfg)
