# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import regex

from fairseq.data.encoders import register_tagger


@register_tagger('tagger')
class Tagger(object):
    """Add or remove inline tags (language codes, corpus tags) on text lines.

    ``encode`` builds tags of the form ``<{prefix}{value}>`` from the
    command-line options and the per-line ``meta`` dict. It joins them with the
    text using single spaces, placing them before the text by default or after
    it with ``--append-codes``. ``decode`` optionally removes such tags again
    (``--strip-tags``).
    """

    @staticmethod
    def add_args(parser):
        """Register this tagger's command-line options on ``parser``."""
        # fmt: off
        parser.add_argument('--lang-code', action='store_true', default=None,
                            help="prefix lines with the language code (e.g., '<lang:de>')")
        parser.add_argument('--target-lang-code', action='store_true', default=None,
                            help="prefix lines with the target language code (e.g., '<lang:en>')")
        parser.add_argument('--corpus-tag', nargs='?', const=True, default=False,
                            help="prefix lines with a corpus tag (e.g., '<corpus:europarl>'). "
                                 "When no value is provided the tag is inferred from the corpus metadata "
                                 "(at training time only)")
        parser.add_argument('--lang-code-prefix', default='lang:',
                            help="which prefix to use for source language codes (default: 'lang:')")
        parser.add_argument('--target-lang-code-prefix', default='lang:',
                            help="which prefix to use for target language codes (default: 'lang:')")
        parser.add_argument('--corpus-tag-prefix', default='corpus:',
                            help="which prefix to use for corpus tags (default: 'corpus:')")
        parser.add_argument('--append-codes', action='store_true', help="put codes at the end")
        parser.add_argument('--strip-tags', action='store_true',
                            help="remove lang codes and corpus tags from output")
        # fmt: on

    def __init__(self, args):
        self.args = args

        # Tags produced by ``encode`` look like ``<{prefix}{value}>``. Build one
        # pattern that matches a run of such tags, using any configured prefix.
        # Each tag may be preceded by whitespace or '▁' characters (presumably
        # the subword word-boundary marker — confirm with the tokenizer in use).
        prefixes = [args.lang_code_prefix, args.target_lang_code_prefix, args.corpus_tag_prefix]
        alternatives = '|'.join(regex.escape(prefix) for prefix in prefixes)
        self.strip_tags_regex = regex.compile(rf"([▁\s]*<({alternatives}).*?>)+")

    def encode(self, x: str, meta: "dict | None" = None, **kwargs) -> str:
        """Add the configured tags to the text sequence ``x``.

        Args:
            x: the text line to tag.
            meta: per-line metadata. It must contain ``'lang'`` when
                ``--lang-code`` is set and ``'tgt_lang'`` when
                ``--target-lang-code`` is set. ``'corpus_tag'`` is read when
                ``--corpus-tag`` is given without a value. ``None`` is treated
                as an empty dict.
            **kwargs: accepted and ignored.

        Returns:
            The tags and ``x`` joined by single spaces. Tags come first unless
            ``--append-codes`` is set, in which case they follow ``x``.

        Raises:
            KeyError: if a required ``meta`` key (``'lang'`` / ``'tgt_lang'``)
                is missing.
            ValueError: if ``--corpus-tag`` was given without a value and
                ``meta`` has no ``'corpus_tag'``.
        """
        # None instead of a shared mutable ``{}`` default argument.
        if meta is None:
            meta = {}
        tags = []

        if self.args.lang_code:
            tags.append(f"<{self.args.lang_code_prefix}{meta['lang']}>")

        if self.args.target_lang_code:
            tags.append(f"<{self.args.target_lang_code_prefix}{meta['tgt_lang']}>")

        if self.args.corpus_tag:
            # A bare ``--corpus-tag`` stores True (argparse ``const``), meaning
            # "take the tag from the metadata". Otherwise use the given value.
            if self.args.corpus_tag is True:
                corpus_tag = meta.get('corpus_tag')
            else:
                corpus_tag = self.args.corpus_tag
            if corpus_tag is None:
                # ValueError subclasses Exception, so existing handlers still catch it.
                raise ValueError('missing value for --corpus-tag')
            tags.append(f"<{self.args.corpus_tag_prefix}{corpus_tag}>")

        if self.args.append_codes:
            return ' '.join([x] + tags)
        return ' '.join(tags + [x])

    def decode(self, x: str) -> str:
        """Return ``x``, with tags removed when ``--strip-tags`` is set."""
        if self.args.strip_tags:
            x = self.strip_tags(x)
        return x

    def strip_tags(self, x: str) -> str:
        """Remove language codes and corpus tags from ``x``.

        Every run of tags that uses one of the configured prefixes is removed,
        together with any whitespace and '▁' characters directly preceding each
        tag. This happens wherever the run occurs in ``x``, so codes appended
        with ``--append-codes`` are removed as well. Leading and trailing
        whitespace of the result is then stripped.
        """
        return self.strip_tags_regex.sub('', x).strip()
