# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


class SubwordSplitter(object):
    """Interface sketch for subword splitters.

    Both methods are placeholders: calling either one raises
    ``NotImplementedError`` until a subclass overrides it.
    """

    def process_line(self, string):
        """Process one line of text; must be overridden by a subclass."""
        raise NotImplementedError()

    def split(self, string):
        """Split ``string`` into pieces; must be overridden by a subclass."""
        raise NotImplementedError()


class NoneWordSplitter(object):
    """Pass-through splitter: every input string is treated as one unit.

    No segmentation model is used; the ``model`` constructor argument is
    accepted and ignored.
    """

    def __init__(self, model):
        """Accept ``model`` for signature compatibility; nothing is stored."""

    # -- splitting ---------------------------------------------------------

    def process_line(self, string):
        """Return ``string`` wrapped in a one-element list."""
        pieces = [string]
        return pieces

    def split(self, string):
        """Return ``string`` wrapped in a one-element list."""
        pieces = [string]
        return pieces

    # -- merging -----------------------------------------------------------

    def merge(self, list_of_string):
        """Concatenate the given strings with no separator."""
        separator = ""
        return separator.join(list_of_string)

    # -- word-boundary queries ---------------------------------------------

    def finished_word(self, string):
        """Always report the word as finished."""
        return True

    def end_idx_last_full_word(self, tokens):
        """Return ``len(tokens)``: every token counts as a full word."""
        count = len(tokens)
        return count

    def last_full_word_step(self, tokens, step):
        """Return ``len(tokens)``; ``step`` is not consulted."""
        count = len(tokens)
        return count


class BPEWordSplitter(object):
    """Subword splitter backed by a ``subword_nmt`` BPE model.

    Every piece of a word except its last one carries a trailing ``@@``
    continuation marker, e.g. ``"hello"`` -> ``["he@@", "llo"]``.
    """

    # TODO: look back here

    # Suffix marking a piece whose word continues in the next piece.
    _CONTINUATION = "@@"

    def __init__(self, model_path):
        """Load the BPE codes stored at ``model_path``.

        Args:
            model_path: path of the BPE codes file handed to
                ``subword_nmt.apply_bpe.BPE``.
        """
        super().__init__()
        # Imported here so that importing this module does not require
        # subword_nmt to be installed.
        from subword_nmt.apply_bpe import BPE
        # Read the codes as UTF-8 explicitly instead of relying on the
        # platform's default encoding.
        with open(model_path, encoding="utf-8") as f:
            self.model = BPE(f)

    def split(self, string):
        """Apply BPE to ``string`` and return the whitespace-separated pieces."""
        return self.model.process_line(string).split()

    def end_idx_last_full_word(self, tokens):
        """Return the index at which the last (possibly unfinished) word starts.

        ``tokens[:result]`` therefore contains only complete words.  Returns
        0 when ``tokens`` holds at most one word, i.e. when no word is known
        to be complete yet.

        Args:
            tokens: list of BPE pieces.
        """
        marker = self._CONTINUATION
        # Position j > 0 starts a word exactly when the piece *before* it
        # does not end with the continuation marker.  Position 0 always
        # starts a word.
        bow_indices = [0] + [
            i + 1
            for i, t in enumerate(tokens[:-1])
            if not t.endswith(marker)
        ]

        if len(bow_indices) < 2:
            return 0
        return bow_indices[-1]

    def merge(self, list_of_string):
        """Join BPE pieces back into space-separated words.

        Pieces ending in ``@@`` are glued to the piece that follows them;
        a dangling marker on the very last piece is dropped.

        Args:
            list_of_string: list of BPE pieces.

        Returns:
            The detokenized string, e.g. ``["he@@", "llo", "you"]`` ->
            ``"hello you"``.
        """
        marker = self._CONTINUATION
        # Removing "@@ " fuses a continued piece with its successor while
        # keeping the spaces that separate distinct words.
        merged = " ".join(list_of_string).replace(marker + " ", "")
        if merged.endswith(marker):
            merged = merged[: -len(marker)]
        return merged


class SentencePieceModelWordSplitter(object):
    """Subword splitter backed by a SentencePiece model.

    Word starts are detected by a leading U+2581 character on a piece.
    """

    def __init__(self, model_path):
        """Load the SentencePiece model stored at ``model_path``.

        Args:
            model_path: path passed to ``SentencePieceProcessor.Load``.
        """
        super().__init__()
        # Imported here so that importing this module does not require
        # sentencepiece to be installed.
        import sentencepiece as spm
        self.model = spm.SentencePieceProcessor()
        self.model.Load(model_path)

    def split(self, string):
        """Return the SentencePiece pieces of ``string``."""
        return self.model.EncodeAsPieces(string)

    def end_idx_last_full_word(self, tokens):
        """Return the index at which the last (possibly unfinished) word starts.

        ``tokens[:result]`` ends just before the last piece that begins with
        U+2581.  Returns 0 when fewer than two such pieces are present.

        Args:
            tokens: list of SentencePiece pieces; empty strings are allowed
                and are never treated as word starts.
        """
        # startswith() rather than t[0] so an empty piece cannot raise
        # IndexError.
        bow_indices = [
            i for i, t in enumerate(tokens) if t.startswith('\u2581')
        ]

        if len(bow_indices) < 2:
            return 0
        return bow_indices[-1]

    def merge(self, list_of_string):
        """Decode a list of pieces back into text with the loaded model."""
        return self.model.DecodePieces(list_of_string)


# Registry mapping a splitter-type key to the class implementing it.
# ``None`` selects the pass-through splitter.  Each class takes a single
# constructor argument (a model / model path).
# NOTE(review): presumably the key comes from a caller-side option — confirm
# against the call sites.
SPLITTER_DICT = {
    None: NoneWordSplitter,
    "BPE": BPEWordSplitter,
    "SentencePieceModel": SentencePieceModelWordSplitter,
}
