from __future__ import annotations

import abc
from typing import Any


class Tokenizer(abc.ABC):
    """Abstract base class for tokenizers backed by a token-to-ID dictionary.

    Subclasses must implement :attr:`unk_idx`, :meth:`build` and
    :meth:`tokenize`. Encoding and decoding are provided here on top of the
    ``dictionary`` (token -> ID) and its inverse ``tokens`` (ID -> token).
    """

    # Placeholder emitted by ``decode`` for IDs missing from ``tokens``.
    UNK_TOKEN: str = "<unk>"

    def __init__(
        self, tokenizer: Any, dictionary: dict[str, int], *args, **kwargs
    ) -> None:
        """Store the backend tokenizer and build the inverse dictionary.

        Args:
            tokenizer (Any): The underlying tokenizer object.
            dictionary (dict[str, int]): Mapping from token to token ID.
            *args: Accepted for signature compatibility; ignored here.
            **kwargs: Accepted for signature compatibility; ignored here.
        """
        self.tokenizer = tokenizer
        self.dictionary = dictionary
        # Inverse mapping used by ``decode``. If several tokens share an ID,
        # the one that appears last in ``dictionary`` wins.
        self.tokens: dict[int, str] = {idx: token for token, idx in dictionary.items()}

    def __len__(self) -> int:
        """Return the vocabulary size, i.e. the largest token ID plus one.

        Returns:
            int: ``max(token IDs) + 1``, or 0 when the dictionary is empty.
        """
        # ``default=-1`` makes an empty dictionary report size 0 instead of
        # raising ``ValueError`` from ``max`` on an empty iterable.
        return max(self.tokens, default=-1) + 1

    @property
    @abc.abstractmethod
    def unk_idx(self) -> int:
        """Return the unknown index."""

    @classmethod
    @abc.abstractmethod
    def build(cls, name_or_path: str, *args, **kwargs) -> Tokenizer:
        """Build a tokenizer instance.

        Args:
            name_or_path (str): Model name or path.

        Returns:
            Tokenizer: This class.
        """

    @classmethod
    def init_worker(
        cls, name_or_path: str, *args, **kwargs
    ) -> TokenizerMultiprocessingWrapper:
        """Build a tokenizer and wrap it in a multiprocessing wrapper.

        All arguments are forwarded unchanged to :meth:`build`.

        Args:
            name_or_path (str): Model name or path.

        Returns:
            TokenizerMultiprocessingWrapper: Tokenizer wrapper.
        """
        return TokenizerMultiprocessingWrapper(cls.build(name_or_path, *args, **kwargs))

    @abc.abstractmethod
    def tokenize(self, line: str) -> list[str]:
        """Tokenize the input line.

        Args:
            line (str): An input line.

        Returns:
            list[str]: The tokenized line.
        """

    def encode(self, tokens: list[str]) -> list[int]:
        """Encode tokens into token IDs.

        Tokens missing from the dictionary are mapped to :attr:`unk_idx`.

        Args:
            tokens (list[str]): Input tokens.

        Returns:
            list[int]: The token ID sequence.
        """
        # Read the property once instead of once per token.
        unk_idx = self.unk_idx
        return [self.dictionary.get(tok, unk_idx) for tok in tokens]

    def decode(self, indices: list[int]) -> list[str]:
        """Decode token IDs into tokens.

        IDs missing from the inverse dictionary are mapped to
        :attr:`UNK_TOKEN`.

        Args:
            indices (list[int]): Input token IDs.

        Returns:
            list[str]: The token sequence.
        """
        return [self.tokens.get(idx, self.UNK_TOKEN) for idx in indices]

    def __call__(self, line: str) -> list[int]:
        """Tokenize and encode a line.

        Args:
            line (str): An input line.

        Returns:
            list[int]: The token ID sequence.
        """
        return self.encode(self.tokenize(line))


class TokenizerMultiprocessingWrapper:
    """Expose a tokenizer through class-level methods.

    The wrapped tokenizer is stored as a class attribute, so every method
    here is a classmethod that delegates to that shared object. The attribute
    only exists after the wrapper has been instantiated at least once.
    """

    def __init__(self, tokenizer: Tokenizer) -> None:
        """Register the tokenizer on the class itself.

        Args:
            tokenizer (Tokenizer): The tokenizer to delegate to.
        """
        # Deliberately set on the class, not on ``self``: the classmethods
        # below read it back through ``cls``.
        TokenizerMultiprocessingWrapper.tokenizer = tokenizer

    @classmethod
    def tokenize(cls, line: str) -> list[str]:
        """Split a line into tokens with the shared tokenizer.

        Args:
            line (str): An input line.

        Returns:
            list[str]: The tokenized line.
        """
        shared = cls.tokenizer
        pieces = shared.tokenize(line)
        return pieces

    @classmethod
    def encode(cls, tokens: list[str]) -> list[int]:
        """Map tokens to token IDs with the shared tokenizer.

        Args:
            tokens (list[str]): Input tokens.

        Returns:
            list[int]: The token ID sequence.
        """
        shared = cls.tokenizer
        ids = shared.encode(tokens)
        return ids

    @classmethod
    def __call__(cls, line: str) -> list[int]:
        """Run the shared tokenizer on a line, yielding token IDs.

        Args:
            line (str): An input line.

        Returns:
            list[int]: The token ID sequence.
        """
        shared = cls.tokenizer
        ids = shared(line)
        return ids
