from typing import List, Optional
import os
import json
import numpy as np
from tokenizers import ByteLevelBPETokenizer
from transformers import FSMTTokenizer
import logging

logger = logging.getLogger(__name__)


class BillingualTokenizer:
    """
    Pair of byte-level BPE tokenizers: one for the source language, one for the target.

    Both tokenizers are trained on the same parallel dataset by :meth:`train` and are
    loaded from ``<tok_dir>/<lang>/vocab.json`` and ``<tok_dir>/<lang>/merges.txt``.
    Special-token ids (pad/unk/bos/eos/mask) are looked up in the source vocabulary.
    """

    def __init__(self, src_lang=None, tgt_lang=None, tok_dir=None, **kwargs):
        """
        Load the source and target tokenizers saved under ``tok_dir``.

        Args:
            src_lang: source language code; names the sub-directory of ``tok_dir``.
            tgt_lang: target language code; names the sub-directory of ``tok_dir``.
            tok_dir: root directory the tokenizers were saved to by :meth:`train`.
            **kwargs: accepted for interface compatibility and ignored.
        """
        self.tok_dir = tok_dir
        self.src_tokenizer = self._load_tokenizer(tok_dir, src_lang)
        self.tgt_tokenizer = self._load_tokenizer(tok_dir, tgt_lang)

    @staticmethod
    def _load_tokenizer(tok_dir, lang):
        """Build a ByteLevelBPETokenizer from ``tok_dir/lang/{vocab.json,merges.txt}``."""
        lang_dir = os.path.join(tok_dir, lang)
        return ByteLevelBPETokenizer(
            vocab=os.path.join(lang_dir, "vocab.json"),
            merges=os.path.join(lang_dir, "merges.txt"),
        )

    @staticmethod
    def train_one(dataset, save_path, lang, vocab_size=52_000):
        """
        Train a byte-level BPE tokenizer for one language and save it under ``save_path/lang``.

        Args:
            dataset: parallel corpus. ``dataset[i:j]["translation"]`` must return a list of
                dicts keyed by language code (presumably a Hugging Face ``datasets.Dataset``
                -- confirm against callers).
            save_path: root directory; ``vocab.json`` and ``merges.txt`` go to ``save_path/lang``.
            lang: key selecting this language's text in each translation dict.
            vocab_size: target vocabulary size passed to the trainer.
        """
        batch_size = 1000

        def get_training_corpus():
            # Feed the trainer batches of raw texts for this language.
            for start_idx in range(0, len(dataset), batch_size):
                samples = dataset[start_idx : start_idx + batch_size]["translation"]
                yield [sample[lang] for sample in samples]

        tokenizer = ByteLevelBPETokenizer()
        tokenizer.train_from_iterator(
            iterator=get_training_corpus(),
            vocab_size=vocab_size,
            min_frequency=2,
            special_tokens=[
                "<s>",
                "<pad>",
                "</s>",
                "<unk>",
                "<mask>",
            ],
        )
        lang_dir = os.path.join(save_path, lang)
        os.makedirs(lang_dir, exist_ok=True)
        tokenizer.save_model(lang_dir)

    @staticmethod
    def train(dataset, save_dir, src_lang, tgt_lang):
        """Train and save one tokenizer for ``src_lang`` and one for ``tgt_lang``."""
        BillingualTokenizer.train_one(dataset, save_dir, lang=src_lang)
        BillingualTokenizer.train_one(dataset, save_dir, lang=tgt_lang)

    @property
    def pad_token(self):
        return "<pad>"

    @property
    def unk_token(self):
        return "<unk>"

    @property
    def bos_token(self):
        return "<s>"

    @property
    def eos_token(self):
        return "</s>"

    @property
    def mask_token(self):
        return "<mask>"

    @property
    def pad_token_id(self):
        return self.src_tokenizer.token_to_id(self.pad_token)

    @property
    def unk_token_id(self):
        return self.src_tokenizer.token_to_id(self.unk_token)

    @property
    def bos_token_id(self):
        return self.src_tokenizer.token_to_id(self.bos_token)

    @property
    def eos_token_id(self):
        return self.src_tokenizer.token_to_id(self.eos_token)

    @property
    def mask_token_id(self):
        return self.src_tokenizer.token_to_id(self.mask_token)

    @property
    def src_vocab_size(self):
        # ByteLevelBPETokenizer exposes get_vocab_size(), not a vocab_size attribute.
        return self.src_tokenizer.get_vocab_size()

    @property
    def tgt_vocab_size(self):
        return self.tgt_tokenizer.get_vocab_size()

    def _tgt_pad_token_id(self):
        """Id of ``<pad>`` in the target vocabulary, falling back to the source id if absent."""
        tgt_id = self.tgt_tokenizer.token_to_id(self.pad_token)
        return self.pad_token_id if tgt_id is None else tgt_id

    def __call__(
        self,
        text,
        text_target=None,
        max_length=100,
        padding=False,
        truncation=True,
        add_special_tokens=True,
        return_tensors="np",
        **kwargs,
    ):
        """
        Encode source texts and, optionally, target texts.

        Args:
            text: a string or list of strings in the source language.
            text_target: optional string or list of strings in the target language.
            max_length: truncation length (if ``truncation``) and fixed padded length
                (if ``padding``).
            padding: pad every sequence to exactly ``max_length`` tokens.
            truncation: truncate every sequence to at most ``max_length`` tokens.
            add_special_tokens: forwarded to ``encode_batch``.
            return_tensors: ``"np"`` converts each output list with ``np.array``; any
                other value returns plain Python lists.
            **kwargs: accepted for interface compatibility and ignored.

        Returns:
            dict with ``input_ids`` and ``attention_mask``; when ``text_target`` is given,
            also ``labels`` and ``target_attention_mask``.
        """
        if isinstance(text, str):
            text = [text]
        if isinstance(text_target, str):
            text_target = [text_target]

        src_tok = self.src_tokenizer
        tgt_tok = self.tgt_tokenizer
        output = {}
        try:
            # Truncation/padding are stateful settings on the underlying tokenizers.
            # They are enabled for this call only and reset in ``finally``, so an
            # encoding error cannot leak them into later calls.
            if truncation:
                src_tok.enable_truncation(max_length=max_length)
                tgt_tok.enable_truncation(max_length=max_length)
            if padding:
                src_tok.enable_padding(
                    pad_token=self.pad_token, pad_id=self.pad_token_id, length=max_length
                )
                tgt_tok.enable_padding(
                    pad_token=self.pad_token,
                    pad_id=self._tgt_pad_token_id(),
                    length=max_length,
                )

            src_res = src_tok.encode_batch(
                text, is_pretokenized=False, add_special_tokens=add_special_tokens
            )
            output["input_ids"] = [enc.ids for enc in src_res]
            output["attention_mask"] = [enc.attention_mask for enc in src_res]

            if text_target is not None:
                tgt_res = tgt_tok.encode_batch(
                    text_target,
                    is_pretokenized=False,
                    add_special_tokens=add_special_tokens,
                )
                output["labels"] = [enc.ids for enc in tgt_res]
                output["target_attention_mask"] = [enc.attention_mask for enc in tgt_res]
        finally:
            if truncation:
                src_tok.no_truncation()
                tgt_tok.no_truncation()
            if padding:
                src_tok.no_padding()
                tgt_tok.no_padding()

        if return_tensors == "np":
            # NOTE(review): without padding, unequal-length sequences form a ragged
            # list, which recent NumPy versions refuse to turn into an array.
            output = {key: np.array(value) for key, value in output.items()}
        return output

    def decode_batch(
        self,
        sequences: List[List[int]],
        is_target: Optional[bool] = True,
        skip_special_tokens: Optional[bool] = False,
    ):
        """Decode id sequences with the target tokenizer (default) or the source tokenizer."""
        tokenizer = self.tgt_tokenizer if is_target else self.src_tokenizer
        return tokenizer.decode_batch(sequences, skip_special_tokens)


def get_pairs(word):
    """
    Return the set of adjacent symbol pairs in ``word``.

    ``word`` is a sequence of symbols (e.g. a tuple of variable-length strings).
    A word with fewer than two symbols, including an empty one, has no pairs,
    so an empty set is returned.
    """
    # zip stops at the shorter argument, yielding (word[i], word[i + 1]) for every i
    # without indexing word[0], which would raise IndexError on an empty word.
    return set(zip(word, word[1:]))


def convert_fastbpe_json(vocab_file, save_path):
    """
    Convert a fastBPE vocabulary file to the Hugging Face JSON vocab format.

    The output is the format used by FSMTTokenizer. Each non-blank line of
    ``vocab_file`` is read as ``<token> <count>``, and only the token is kept.
    Tokens ending in the continuation marker ``@@`` have it stripped; every other
    token gets the end-of-word suffix ``</w>``. Tokens are numbered from 4 in file
    order. Ids 0-3 are reserved for ``<s>``, ``<pad>``, ``</s>`` and ``<unk>``.

    Args:
        vocab_file: path to the fastBPE vocabulary (read as UTF-8).
        save_path: path of the JSON vocabulary to write (UTF-8, trailing newline).
    """
    special_tokens = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}

    keys = []
    with open(vocab_file, encoding="utf-8") as vocab_handle:
        for line in vocab_handle:
            # Strip the newline first so a token-only line does not keep "\n".
            token = line.rstrip("\n").split(" ")[0]
            if not token:
                # Blank lines would otherwise become a bogus "</w>" entry.
                continue
            if token.endswith("@@"):
                keys.append(token[:-2])
            else:
                keys.append(token + "</w>")

    # Later duplicates overwrite earlier ones, as with dict(zip(...)).
    vocab = {key: idx for idx, key in enumerate(keys, start=len(special_tokens))}
    vocab.update(special_tokens)

    with open(save_path, "w", encoding="utf-8") as f:
        f.write(json.dumps(vocab, indent=2, sort_keys=False, ensure_ascii=False) + "\n")


class FSBilingualTokenizer(FSMTTokenizer):
    """
    Bilingual tokenizer built on the existing fairseq (FSMT) tokenizer in transformers.

    Assumes the BPE codes and vocabularies were learned with fastBPE.

    Advantages over BillingualTokenizer:
        - fairseq is common in the NMT literature, so results are easier to compare
        - reuses FSMT's existing text pre-cleaning

    Changes relative to transformers/models/fsmt/tokenization_fsmt.py:
        - every encoded sequence starts with ``bos_token_id``
        - NOTE(review): the original notes also mention static methods for training
          the tokenizer with fastBPE; none are defined in this section -- confirm.
    """

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Insert special tokens around one sequence or a pair of sequences.

        Layout:

        - single sequence: `<s> X </s>`
        - pair of sequences: `<s> A </s> B </s>`

        Here `<s>` stands for ``bos_token_id`` and `</s>` for ``sep_token_id``.

        Args:
            token_ids_0 (`List[int]`):
                Ids of the first sequence.
            token_ids_1 (`List[int]`, *optional*):
                Ids of the second sequence, if encoding a pair.

        Returns:
            `List[int]`: The [input IDs](../glossary#input-ids) with special tokens inserted.
        """
        sep_ids = [self.sep_token_id]
        input_ids = [self.bos_token_id] + token_ids_0 + sep_ids
        if token_ids_1 is not None:
            input_ids = input_ids + token_ids_1 + sep_ids
        return input_ids

    def get_special_tokens_mask(
        self,
        token_ids_0: List[int],
        token_ids_1: Optional[List[int]] = None,
        already_has_special_tokens: bool = False,
    ) -> List[int]:
        """
        Mark which positions of the encoded input hold special tokens.

        Without ``already_has_special_tokens``, the mask matches the layout produced by
        ``build_inputs_with_special_tokens``. With it, the parent FSMTTokenizer
        implementation is used.

        Args:
            token_ids_0 (`List[int]`):
                Ids of the first sequence.
            token_ids_1 (`List[int]`, *optional*):
                Ids of the second sequence, if encoding a pair.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether the ids already include the special tokens.

        Returns:
            `List[int]`: 1 for a special-token position, 0 for a sequence token.
        """
        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0,
                token_ids_1=token_ids_1,
                already_has_special_tokens=True,
            )
        # bos, first sequence, sep
        mask = [1] + [0] * len(token_ids_0) + [1]
        if token_ids_1 is not None:
            # second sequence, closing sep
            mask += [0] * len(token_ids_1) + [1]
        return mask

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build segment (token type) ids for one sequence or a pair of sequences.

        The first segment covers `<s> A </s>`; for a pair, the second segment covers `B </s>`:

        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```

        If `token_ids_1` is `None`, only the first segment (all 0s) is returned.

        Args:
            token_ids_0 (`List[int]`):
                Ids of the first sequence.
            token_ids_1 (`List[int]`, *optional*):
                Ids of the second sequence, if encoding a pair.

        Returns:
            `List[int]`: [Token type IDs](../glossary#token-type-ids) for the given sequence(s).
        """
        # bos + first sequence + sep -> segment 0
        type_ids = [0] * (len(token_ids_0) + 2)
        if token_ids_1 is None:
            return type_ids
        # second sequence + closing sep -> segment 1
        return type_ids + [1] * (len(token_ids_1) + 1)
