import os
import re
from collections import defaultdict
from typing import Optional

import numpy as np

from fairseq import file_utils, utils
from fairseq.data import Dictionary
from fairseq.data.encoders import apply_bpe, register_bpe


@register_bpe('nle_bpe')
class NLE_BPE(object):
    """Language-aware BPE tokenizer, registered in fairseq as 'nle_bpe'.

    Keeps one ``apply_bpe.BPE`` model per language code in ``self.bpe_models``.
    The ``None`` key holds the default model, used for languages that have no
    dedicated model.
    """

    @staticmethod
    def add_args(parser):
        """Register the command-line options read by this BPE."""
        # fmt: off
        parser.add_argument('--bpe-codes', '--bpecodes', nargs='+', help='path to BPE codes generated by learn_bpe.py')
        parser.add_argument('--inline-case', '--inline', action='store_true', default=None, help='split words and their case')
        parser.add_argument('--bpe-vocab', nargs='+', help='path to word piece counts generated by learn_bpe.py')
        parser.add_argument('--bpe-threshold', type=int, help="don't output word pieces whose frequency is lower than this")
        parser.add_argument('--legacy', action='store_true', default=None, help='subword-nmt behavior')
        parser.add_argument('--bpe-dropout', type=float, default=0, help='apply source-side BPE dropout with this rate')
        parser.add_argument('--target-bpe-dropout', type=float, default=0, help='apply target-side BPE dropout with this rate')
        # fmt: on

    def __init__(self, args):
        """
        By default, NLE-BPE looks for files whose names start with 'bpecodes' (or 'bpe-codes') inside the first
        directory of args.data.

        If the filename has an extension (e.g., 'bpecodes.de'), then this extension is interpreted as a language code, and this BPE model will only
        be used to process lines in this language. If it has no extension (e.g., 'bpecodes'), then it can be used for any language.

        The same is done with BPE vocabularies for frequency threshold filtering (files starting with 'bpe-vocab' or 'bpevocab').
        If a language has a BPE vocabulary but no BPE model, then the generic BPE model is used in combination with the provided BPE vocabulary.

        Custom paths can be given with --bpe-codes and --bpe-vocab, in which case NLE-BPE only looks at these files and interprets their file
        extensions as language codes, as explained above. The data directory is then not listed.

        When the BPE threshold is set to a positive value, only wordpieces whose frequency in 'bpe-vocab' is above this threshold are generated
        (the exact comparison is done by apply_bpe.read_vocabulary). If it is set to zero, only wordpieces that appear in the vocabulary are
        produced. If it is not set, the vocabulary is ignored.

        Legacy mode (--legacy) uses the old-fashioned subword-nmt style, with '@@' wordpiece suffixes.
        By default NLE-BPE uses '▁' prefixes that replace whitespaces, like SentencePiece.

        Inline casing lowercases the input text, but adds a tag after each wordpiece to specify its original case: <U> for upper case and <T>
        for title case (i.e., starting with a capital letter). Lowercase wordpieces are not followed by a tag.
        Tokens that do not fall into one of these cases (e.g., 'YouTube') are split before being segmented into wordpieces ('You', 'Tube').

        Legacy mode and inline casing can also be turned on by the header inside a 'bpecodes' file. Each model keeps its own
        setting, and decode() uses the legacy setting of the model selected by meta['lang'].

        Raises:
            ValueError: if no BPE model is found, or if a language has a BPE vocabulary but there is neither a
                model for that language nor a default model.
        """

        def find_files(file_list, regex=None):
            # NLE-BPE interprets file extensions as language codes: returns {extension or None: path}.
            # A later path with the same extension overwrites an earlier one.
            files = {}
            for path in file_list:
                name = os.path.basename(path)
                if regex is None or re.match(regex, name):
                    _, *ext = name.rsplit('.', maxsplit=1)
                    files[ext[0] if ext else None] = path
            return files

        def list_data_dir():
            # Only called when no explicit paths were given, so that --bpe-codes / --bpe-vocab
            # work even if args.data does not point to a readable directory.
            data_path = utils.split_paths(args.data)[0]
            return [os.path.join(data_path, name) for name in os.listdir(data_path)]

        # either the provided files or all files starting with 'bpe-codes' or 'bpecodes'
        if args.bpe_codes:
            bpe_codes_by_lang = find_files(args.bpe_codes)
        else:
            bpe_codes_by_lang = find_files(list_data_dir(), regex=r'bpe-codes|bpecodes')

        # vocabularies are only used when a threshold is set: either the provided files
        # or all files starting with 'bpe-vocab' or 'bpevocab'
        if args.bpe_threshold is None:
            bpe_vocab_by_lang = {}
        elif args.bpe_vocab:
            bpe_vocab_by_lang = find_files(args.bpe_vocab)
        else:
            bpe_vocab_by_lang = find_files(list_data_dir(), regex=r'bpe-vocab|bpevocab')

        self.bpe_models = {}
        # effective legacy flag of each model (CLI flag or bpecodes header); read by decode()
        self.legacy_by_lang = {}
        self.args = args

        for lang in set(bpe_codes_by_lang) | set(bpe_vocab_by_lang):
            # bpe_codes_by_lang[None] is the default model, used for languages that don't have
            # a language-specific model
            bpe_codes = bpe_codes_by_lang.get(lang) or bpe_codes_by_lang.get(None)
            bpe_vocab = bpe_vocab_by_lang.get(lang) or bpe_vocab_by_lang.get(None)
            if bpe_codes is None:
                # only reachable for a lang that has a vocabulary but no codes, with no default codes
                raise ValueError(
                    f'found a BPE vocabulary for this lang ({lang}) but neither a BPE model '
                    f'for this lang nor a default BPE model'
                )

            config, codes = apply_bpe.read_bpecodes(bpe_codes)

            # the CLI flags can only turn these options on, never off
            config['inline_case'] = config.get('inline_case', False) or args.inline_case
            config['legacy'] = config.get('legacy', False) or args.legacy

            vocab = None if bpe_vocab is None else apply_bpe.read_vocabulary(bpe_vocab, args.bpe_threshold)
            self.bpe_models[lang] = apply_bpe.BPE(codes=codes, vocab=vocab, **config)
            self.legacy_by_lang[lang] = bool(config['legacy'])

        if not self.bpe_models:
            raise ValueError('found no BPE model')

    def encode(self, x: str, meta: Optional[dict] = None, **kwargs) -> str:
        """Segment ``x`` into wordpieces with the model for ``meta['lang']``.

        Falls back to the default (``None``) model when the language has no dedicated model.
        BPE dropout is applied only while ``self.train`` is truthy. That attribute is not set in this
        class, so dropout is enabled unless something else sets it to False. Source-side dropout is used
        when ``meta['lang'] == meta['src_lang']``; otherwise target-side dropout is used when
        ``meta['lang'] == meta['tgt_lang']``. Note that with no 'lang' and no 'src_lang' in ``meta``,
        both are None and source-side dropout applies.

        Raises:
            ValueError: if there is no model for this language and no default model.
        """
        meta = meta or {}
        lang = meta.get('lang')
        bpe_dropout = 0
        if getattr(self, 'train', True):
            if lang == meta.get('src_lang'):
                bpe_dropout = self.args.bpe_dropout
            elif lang == meta.get('tgt_lang'):
                bpe_dropout = self.args.target_bpe_dropout

        bpe_model = self.bpe_models.get(lang) or self.bpe_models.get(None)
        if bpe_model is None:
            raise ValueError(f'no BPE model is defined for this lang ({lang}) and there is no default BPE model')
        return bpe_model.segment(x, dropout=bpe_dropout)

    def _is_legacy(self, lang=None) -> bool:
        """Return whether the model that encode() would pick for ``lang`` uses legacy ('@@') segmentation."""
        legacy = self.legacy_by_lang.get(lang)
        if legacy is None:
            legacy = self.legacy_by_lang.get(None)
        if legacy is None:
            # no model for this lang and no default model: treat as legacy if any model is
            # (which is also the case when --legacy was given on the command line)
            legacy = any(self.legacy_by_lang.values())
        return legacy

    def decode(self, x: str, meta: Optional[dict] = None, **kwargs) -> str:
        """Undo inline casing and BPE segmentation of ``x``.

        A case tag (<T> or <U>) title-cases or upper-cases the wordpiece right before it, and the tags
        are then dropped. A tag with no preceding wordpiece is just dropped. ``meta['lang']`` (optional)
        selects which model's legacy setting decides how wordpieces are joined.
        """
        tokens = x.split()
        for i, tok in enumerate(tokens[1:], start=1):
            if tok == '<T>':
                tokens[i - 1] = tokens[i - 1].title()
            elif tok == '<U>':
                tokens[i - 1] = tokens[i - 1].upper()

        x = ' '.join(tok for tok in tokens if tok not in ('<T>', '<U>'))

        if self._is_legacy((meta or {}).get('lang')):
            # subword-nmt style: '@@' marks a wordpiece that continues into the next one
            x = x.replace('@@ ', '').replace('@@', '').strip()
        else:
            # SentencePiece style: '▁' marks the start of a word
            x = x.replace(' ', '').replace('▁', ' ').strip()

        return x
