"""Vocabulary for protein-bert models.

Since we are dealing with proteins, the underlying alphabet is fixed
to be 1-character amino acid abbreviations along with some special
characters.

Tokenization is character-level: each character of a sequence becomes one token.
"""
from typing import Dict, Iterable, List, Union

from proteinbert import tokenization as og_tokenization
from transformers import BertTokenizer

###############################################################################


class ProteinBertTokenizer:
    """Tokenizer class for Protein-BERT.

    Maps each character of a protein sequence to a vocabulary id and wraps
    the result in start/end token ids. Characters that are not in the
    vocabulary map to the unknown-token id.

    TODO: Make more compatible with the HuggingFace style (padding,
    batching, ``__call__``).

    Args:
        token_to_id: Mapping from token string to integer id. It is copied,
            so later changes to the caller's dict do not affect the tokenizer.
        pad_token: Padding token; must be a key of ``token_to_id``.
        sos_token: Start-of-sequence token; must be a key of ``token_to_id``.
        eos_token: End-of-sequence token; must be a key of ``token_to_id``.
        unk_token: Token used for characters missing from the vocabulary;
            must be a key of ``token_to_id``.

    Raises:
        KeyError: If any special token is missing from ``token_to_id``.
    """

    def __init__(
        self,
        token_to_id: Dict[str, int],
        pad_token: str = '<PAD>',
        sos_token: str = '<START>',
        eos_token: str = '<END>',
        unk_token: str = '<OTHER>',
    ):
        self._token_to_id = {**token_to_id}
        # Inverse mapping used by ``decode``. If two tokens share an id, the
        # one appearing later in ``token_to_id`` wins.
        self._id_to_token = {index: token for token, index in self._token_to_id.items()}

        self.pad_token = pad_token
        self.pad_token_id = self._lookup_special(pad_token, 'pad_token')

        self.sos_token = sos_token
        self.sos_token_id = self._lookup_special(sos_token, 'sos_token')

        self.eos_token = eos_token
        self.eos_token_id = self._lookup_special(eos_token, 'eos_token')

        self.unk_token = unk_token
        self.unk_token_id = self._lookup_special(unk_token, 'unk_token')

        self.vocab_size = len(self._token_to_id)

    def _lookup_special(self, token: str, role: str) -> int:
        """Return the id of special ``token``; raise a descriptive KeyError if absent."""
        try:
            return self._token_to_id[token]
        except KeyError:
            raise KeyError(f'{role} {token!r} is not in the vocabulary') from None

    def encode(self, s: Union[str, bytes]) -> List[int]:
        """Encode a protein sequence into token ids.

        Lookup is case-insensitive and strictly one token per input
        character. Each character is upper-cased on its own, so a character
        whose upper-case form is several characters (e.g. ``'ß'`` -> ``'SS'``)
        maps to the unknown token rather than expanding into several residues.

        Args:
            s: The sequence. ``bytes`` input is decoded as UTF-8 first.

        Returns:
            ``[sos_token_id, *residue_ids, eos_token_id]``. That is
            ``len(s) + 2`` ids for ``str`` input.
        """
        if isinstance(s, bytes):
            # Iterating raw bytes yields ints, which never match a vocabulary
            # key and would silently encode every residue as unknown.
            s = s.decode('utf-8')
        lookup = self._token_to_id.get
        unk_id = self.unk_token_id
        return [
            self.sos_token_id,
            *(lookup(ch.upper(), unk_id) for ch in s),
            self.eos_token_id,
        ]

    def decode(self, token_ids: Iterable[int], skip_special_tokens: bool = False) -> str:
        """Convert token ids back into a string.

        Args:
            token_ids: Ids, e.g. as produced by ``encode``.
            skip_special_tokens: If true, drop pad/start/end tokens. The
                unknown token is kept because it stands in for a residue
                position of the original sequence.

        Returns:
            The concatenated token strings. Ids not present in the vocabulary
            are rendered as ``unk_token``.
        """
        skipped = (
            {self.pad_token_id, self.sos_token_id, self.eos_token_id}
            if skip_special_tokens
            else set()
        )
        return ''.join(
            self._id_to_token.get(token_id, self.unk_token)
            for token_id in token_ids
            if token_id not in skipped
        )

    # TODO: Stuff like padding


###############################################################################

# Shared tokenizer built from the vocabulary exported by the original
# ``proteinbert`` package, with the default special-token names.
DEFAULT_TOKENIZER = ProteinBertTokenizer(
    token_to_id=og_tokenization.token_to_index,
)
