# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq import file_utils
from fairseq.data.encoders import register_bpe

from .gpt2_bpe_utils import get_encoder


DEFAULT_ENCODER_JSON = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"
DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"


@register_bpe("gpt2")
class GPT2BPE(object):
    @staticmethod
    def add_args(parser):
        # fmt: off
        parser.add_argument('--gpt2-encoder-json', type=str,
                            default=DEFAULT_ENCODER_JSON,
                            help='path to encoder.json')
        parser.add_argument('--gpt2-vocab-bpe', type=str,
                            default=DEFAULT_VOCAB_BPE,
                            help='path to vocab.bpe')
        # fmt: on

    def __init__(self, args):
        encoder_json = file_utils.cached_path(
            getattr(args, "gpt2_encoder_json", DEFAULT_ENCODER_JSON)
        )
        vocab_bpe = file_utils.cached_path(
            getattr(args, "gpt2_vocab_bpe", DEFAULT_VOCAB_BPE)
        )
        self.bpe = get_encoder(encoder_json, vocab_bpe)

    def encode(self, x: str, **kwargs) -> str:
        return " ".join(map(str, self.bpe.encode(x)))

    def decode(self, x: str) -> str:
        return self.bpe.decode(
            [int(tok) if tok not in {"<unk>", "<mask>"} else tok for tok in x.split()]
        )

    def is_beginning_of_word(self, x: str) -> bool:
        return self.decode(x).startswith(" ")
