# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq import file_utils
from fairseq.data.encoders import register_bpe
import regex
import unicodedata
import os

# Case classes assigned to a wordpiece; also used as indices into case_symbols.
lower, title, upper, other = 0, 1, 2, 3

# Marker token written after a wordpiece of the given case class. Lowercase
# and "other" pieces carry no marker, which is represented here by None.
case_symbols = [None] * 4
case_symbols[title] = '<T>'
case_symbols[upper] = '<U>'


@register_bpe("sentencepiece")
class SentencepieceBPE(object):
    """Sentencepiece tokenizer with optional inline casing.

    With ``--inline-case`` the text is lowercased before sentencepiece
    encoding, and the original casing is recorded with marker tokens from
    ``case_symbols`` placed after the affected wordpieces. ``decode``
    re-applies the casing from those markers and removes them.
    """

    @staticmethod
    def add_args(parser):
        """Register this BPE's command-line options on ``parser``."""
        # fmt: off
        parser.add_argument('--sentencepiece-model', type=str,
                            help='path to sentencepiece model')
        parser.add_argument('--inline-case', '--inline-casing', action='store_true',
                            help='put the text in lowercase and use special symbols to '
                                 'indicate the case of the preceding wordpiece')
        # fmt: on

    def __init__(self, args):
        """Locate and load the sentencepiece model.

        Args:
            args: parsed options. Reads ``sentencepiece_model`` and
                ``inline_case``; when no model path is given, also reads
                ``data`` to look for a model file there.

        Raises:
            ValueError: no model path was given and neither ``spm.model`` nor
                ``sentencepiece.model`` exists in ``args.data``.
            ImportError: the ``sentencepiece`` package is not installed.
        """
        if args.sentencepiece_model is not None:
            sentencepiece_model = file_utils.cached_path(args.sentencepiece_model)
        else:
            sentencepiece_model = None
            # NOTE: there is no early exit, so if both files exist the later
            # name in this list ('sentencepiece.model') is the one used.
            for name in 'spm.model', 'sentencepiece.model':
                path = os.path.join(args.data, name)
                if os.path.isfile(path):
                    sentencepiece_model = path
            if sentencepiece_model is None:
                raise ValueError("Could not find any sentencepiece model")
        # Splits a mixed-case wordpiece into runs of: an optional uppercase
        # letter followed by non-uppercase characters, a run of uppercase
        # letters, or a lone '▁'. Raw string so the regex escape for
        # whitespace is not parsed as an (invalid) Python string escape.
        self.mixed_case_regex = regex.compile(r'(▁?[[:upper:]]?[^[:upper:]\s▁]+|▁?[[:upper:]]+|▁)')
        self.inline_case = args.inline_case
        try:
            import sentencepiece as spm
        except ImportError as e:
            raise ImportError(
                "Please install sentencepiece with: pip install sentencepiece"
            ) from e
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(sentencepiece_model)

    def is_beginning_of_word(self, x: str) -> bool:
        """Return True if token ``x`` starts a new word."""
        if x in ("<unk>", "<s>", "</s>", "<pad>"):
            # special elements are always considered beginnings
            # HACK: this logic is already present in fairseq/tasks/masked_lm.py
            # but these special tokens are also contained in the sentencepiece
            # vocabulary which causes duplicate special tokens. This hack makes
            # sure that they are all taken into account.
            return True
        return x.startswith("▁")

    @staticmethod
    def clean(line):
        """Collapse every whitespace run to one space and strip both ends."""
        return regex.sub(r'\s+', ' ', line).strip()

    def get_case(self, s):
        """Classify the casing of ``s`` as ``title``, ``upper``, ``lower`` or ``other``.

        Title case is checked first, so a single uppercase letter is
        ``title``. Strings unchanged by ``lower()`` (digits, punctuation, the
        empty string) count as ``lower``. Anything else, e.g. 'McDonald', is
        ``other``.
        """
        if s.istitle():
            return title
        if s.isupper():
            return upper
        elif s.islower() or s.lower() == s:
            return lower
        else:
            return other

    def _encode(self, x: str) -> str:
        """Encode ``x`` into space-separated sentencepiece pieces.

        Pieces that map to the unknown id are split into their individual
        characters.
        """
        pieces = []
        for piece in self.sp.EncodeAsPieces(x):
            if self.sp.IsUnknown(self.sp.PieceToId(piece)):
                pieces += list(piece)
            else:
                pieces.append(piece)
        return ' '.join(pieces)

    def encode(self, x: str, **kwargs) -> str:
        """Encode ``x`` into a space-separated string of wordpieces.

        Without inline casing this is plain sentencepiece encoding. With it,
        the input is processed in four steps:

        1. NFKC-normalize and whitespace-collapse the input.
        2. Lowercase each word whose length lowercasing preserves.
        3. Encode the result.
        4. Follow each wordpiece with the marker for its original case
           (``<T>`` for title, ``<U>`` for upper, nothing for lower).

        Mixed-case pieces are split with ``mixed_case_regex`` so each run gets
        its own marker.
        """
        if not self.inline_case:
            return self._encode(x)

        orig = self.clean(unicodedata.normalize('NFKC', x))
        # Only lowercase words whose length is not modified by lowercasing, so
        # that character offsets in ``orig_lower`` line up with ``orig``.
        orig_lower = ' '.join(
            lowered if len(word) == len(lowered) else word
            for word, lowered in ((w, w.lower()) for w in orig.split())
        )
        line = self.clean(self._encode(orig_lower))

        output = []
        j = 0  # search cursor in orig_lower: end of the last matched piece
        for wordpiece in line.split():
            if wordpiece == '▁':
                output.append(wordpiece)
                continue

            prefix = ''
            if wordpiece.startswith('▁'):
                prefix = '▁'
                wordpiece = wordpiece[1:]
            i = orig_lower.find(wordpiece, j)
            if i < 0:
                # The piece does not occur verbatim in the lowercased input
                # (presumably sentencepiece's normalizer altered it). Emit it
                # without a case marker and keep the cursor where it is;
                # using -1 as an offset would move the cursor backwards and
                # misalign every later piece.
                output.append(prefix + wordpiece)
                continue

            j = i + len(wordpiece)
            cased = orig[i:j]

            case = self.get_case(cased)
            if len(cased) == len(wordpiece) and case == other:
                # Mixed case: split into runs that each have a single case and
                # mark every run separately. The runs together cover ``cased``,
                # so ``k`` walks the aligned lowercase piece in step.
                k = 0
                for n, s in enumerate(self.mixed_case_regex.findall(cased)):
                    case = self.get_case(s)
                    output += [(prefix if n == 0 else '') + wordpiece[k:k + len(s)], case_symbols[case]]
                    k += len(s)
            else:
                output += [prefix + wordpiece, case_symbols[case]]

        # Drop the None placeholders for cases that carry no marker.
        return ' '.join(w for w in output if w is not None)

    def decode(self, x: str, **kwargs) -> str:
        """Turn a space-separated wordpiece string back into text.

        With inline casing, each ``<T>``/``<U>`` marker title-cases or
        upper-cases the token just before it, and all markers are removed.
        The pieces are then concatenated, and '▁' is turned back into spaces.
        """
        if self.inline_case:
            tokens = x.split()
            for i, w in enumerate(tokens):
                # A marker with no preceding token (e.g. a generated sequence
                # starting with '<U>') is ignored. Otherwise tokens[i - 1]
                # would wrap around and recase the last token.
                if i == 0:
                    continue
                if w == case_symbols[title]:
                    tokens[i - 1] = tokens[i - 1].title()
                elif w == case_symbols[upper]:
                    tokens[i - 1] = tokens[i - 1].upper()

            x = ' '.join(w for w in tokens if w not in case_symbols)
        return x.replace(' ', '').replace('▁', ' ').strip()
