#!/usr/bin/env python3

import sys
import os
import argparse
import regex
import json
import itertools
import unicodedata
import multiprocessing
import numpy as np
import collections


class BPE(object):
    """Segment text into subword units with a list of learned BPE merges.

    Two output conventions are supported:

    * legacy mode: the input is split on whitespace and every non-final
      subword of a token is suffixed with the '@@' separator;
    * default mode: whitespace is replaced by the meta symbol '▁', which is
      attached to the beginning of the following token, and subwords are
      simply separated by spaces.
    """

    def __init__(self, codes, merges=-1, vocab=None,
                 lowercase=False, legacy=False,
                 inline_case=False, protect_regex=None, nfkc=False, **kwargs):
        """Build the merge tables and compile the tokenization patterns.

        Args:
            codes: iterable of merge operations (pairs of subword units),
                ordered by priority (the first one is applied first).
            merges: only keep the first `merges` operations (-1: keep all).
            vocab: optional collection of allowed subwords, handed to
                `encode` to revert merges that produce other units.
            lowercase: lowercase the tokens before segmenting them.
            legacy: use the '@@' separator convention described above.
            inline_case: segment the lowercased text and emit case marker
                tokens ('<U>', '<T>') after upper/title-cased subwords.
                Not supported together with `legacy`.
            protect_regex: optional pattern; its matches are hidden behind a
                placeholder symbol during segmentation and restored after.
            nfkc: apply Unicode NFKC normalization to each sentence.
            **kwargs: accepted and ignored.

        Exits the process with status 1 if an entry of `codes` does not
        contain exactly two units.
        """
        self.bpe_codes = [tuple(item) for n, item in enumerate(codes) if n < merges or merges == -1]

        self.lowercase = lowercase
        self.legacy = legacy
        self.nfkc = nfkc
        self.inline_case = inline_case
        self.protect_regex = None if protect_regex is None else regex.compile(protect_regex)

        for i, item in enumerate(self.bpe_codes):
            if len(item) != 2:
                print('Error: invalid line {} in BPE codes file: {}'.format(i + 1, ' '.join(item)), file=sys.stderr)
                print('The line should consist of exactly two subword units, separated by whitespace', file=sys.stderr)
                sys.exit(1)

        # some hacking to deal with duplicates (only consider first instance):
        # iterating in reverse lets the earliest occurrence overwrite later ones
        self.bpe_codes = {code: i for i, code in reversed(list(enumerate(self.bpe_codes)))}
        # merged string -> the pair it was built from (used to undo merges)
        self.bpe_codes_reverse = {a + b: (a, b) for a, b in self.bpe_codes}
        self.separator = '@@' if self.legacy else ''
        self.vocab = vocab

        self.meta_symbol = '▁'
        self.protect_symbol = '╳'
        assert not (self.legacy and self.inline_case), 'not supported'

        self.whitespace_regex = regex.compile(r'\s+')
        # Raw strings: '\s' is not a valid Python string escape and triggers
        # a warning on recent interpreters when written in a normal literal.
        self.no_mixed_case_regex = regex.compile(
            r'({0}?[[:upper:]]?[^[:upper:]\s{0}{1}]+|{0}?[[:upper:]]+|{0}|{1})'.format(
                regex.escape(self.meta_symbol),
                regex.escape(self.protect_symbol)
            )
        )
        self.sentencepiece_regex = regex.compile(
            r'({0}?[^\s{0}{1}]+|{0}|{1})'.format(
                regex.escape(self.meta_symbol),
                regex.escape(self.protect_symbol)
            )
        )
        self.upper_code, self.title_code, self.lower_code = range(3)
        # indexed by the case codes above; lowercase subwords get no marker
        self.case_symbols = ['<U>', '<T>', None]
        self.cache = collections.OrderedDict()  # must be ordered to work as an LRU cache

    def segment(self, sentence, dropout=0, spell_out=0):
        """segment single sentence (whitespace-tokenized string) with BPE encoding

        Args:
            sentence: the line to segment; surrounding whitespace is stripped.
            dropout: BPE dropout probability forwarded to `segment_tokens`.
            spell_out: probability of splitting a token into single
                characters, forwarded to `segment_tokens`.

        Returns:
            The segmented sentence as a single space-separated string (the
            empty string if `sentence` is blank).
        """
        sentence = sentence.strip()

        if not sentence:
            return sentence

        if self.nfkc:
            sentence = unicodedata.normalize('NFKC', sentence)

        if self.protect_regex is not None:
            # the placeholder must not occur in the input, otherwise it could
            # not be told apart from a protected span when restoring them
            sentence = sentence.replace(self.protect_symbol, ' ')
            protected_tokens = [
                m.group(0) for m in self.protect_regex.finditer(sentence)
            ]
            sentence = self.protect_regex.sub(self.protect_symbol, sentence)

        if self.legacy:
            tokens = sentence.split()
            if self.lowercase:
                tokens = [token.lower() for token in tokens]
        else:
            if self.inline_case:
                # case markers are reserved symbols: remove them from the input
                for symbol in self.case_symbols:
                    if symbol is not None:
                        sentence = sentence.replace(symbol, ' ')

            # the meta symbol is reserved too; afterwards it stands for whitespace
            sentence = sentence.replace(self.meta_symbol, ' ')
            sentence = self.meta_symbol + self.whitespace_regex.sub(self.meta_symbol, sentence)
            if self.inline_case:
                tokens = self.no_mixed_case_regex.findall(sentence)
            else:
                tokens = self.sentencepiece_regex.findall(sentence)

            if self.lowercase or self.inline_case:
                # keep the cased version to recover the case of each subword
                cased_tokens = tokens
                tokens = [token.lower() for token in tokens]

        segments = self.segment_tokens(tokens, dropout=dropout, spell_out=spell_out)

        if not self.legacy and self.inline_case:
            segments_ = []
            for cased_token, segment in zip(cased_tokens, segments):
                i = 0
                segment_ = []
                for out in segment:
                    # slice of the original token covered by this subword
                    x = cased_token[i:i + len(out)]
                    i += len(out)
                    if x.isupper():
                        segment_.append((out, self.upper_code))
                    elif x.istitle():
                        segment_.append((out, self.title_code))
                    else:
                        segment_.append((out, self.lower_code))
                segments_.append(segment_)

            segments = [
                ' '.join(self.add_factor(token, case, i == len(segment) - 1) for i, (token, case) in enumerate(segment))
                for segment in segments_
            ]
        else:
            sep = self.separator + ' '
            segments = [sep.join(segment) for segment in segments]

        sentence = ' '.join(segments)

        if self.protect_regex is not None:
            sentence = sentence.replace(
                self.protect_symbol + ' ' + self.meta_symbol + ' ',
                self.protect_symbol + ' '
            )   # FIXME
            # placeholders are restored left to right, in matching order
            for token in protected_tokens:
                sentence = sentence.replace(self.protect_symbol, token, 1)
            sentence = self.whitespace_regex.sub(' ', sentence)

        if not self.legacy and sentence.startswith(self.meta_symbol + ' '):
            # a lone meta symbol at the beginning of a sentence serves no purpose
            sentence = sentence[len(self.meta_symbol + ' '):]

        return sentence

    def add_factor(self, token, case, last=False):
        """Format one subword: append the separator unless it is the last
        subword of its token and, in inline-case mode, the case marker
        associated with `case` (if any) as an extra space-separated token."""
        if not last:
            token += self.separator
        if self.inline_case:
            case_symbol = self.case_symbols[case]
            if case_symbol is not None:
                token += ' ' + case_symbol
        return token

    def segment_token(self, word, dropout=0):
        """Segment a single token into a list of subwords (no caching)."""
        return encode(
            word, self.bpe_codes, self.bpe_codes_reverse,
            self.vocab, self.separator, self.legacy, dropout
        )

    def segment_token_cached(self, word, dropout=0):
        """Like `segment_token`, but memoizes results per token.

        The cache key is the token only, so this must not be mixed with
        calls using a non-zero `dropout` (`segment_tokens` takes care of it).
        """
        # simple LRU cache implementation (functools.lru_cache is not pickable)
        if word in self.cache:
            new_word = self.cache.pop(word)
            self.cache[word] = new_word   # put this entry last in the cache
            return new_word
        else:
            new_word = self.segment_token(word, dropout)
            self.cache[word] = new_word
            if len(self.cache) > 2**20:
                word = next(iter(self.cache.keys()))   # delete first (oldest) entry
                self.cache.pop(word)
            return new_word

    def segment_tokens(self, tokens, dropout=0, spell_out=0):
        """segment a sequence of tokens with BPE encoding

        Returns one list of subwords per input token. With probability
        `spell_out` a token is split into single characters instead; results
        are only cached when `dropout` is disabled, since dropout makes the
        segmentation random.
        """
        output = []
        for word in tokens:
            # eliminate double spaces
            if not word:
                output.append([])
                continue

            if spell_out and np.random.random() < spell_out:
                new_word = list(word)
            elif dropout:
                new_word = self.segment_token(word, dropout)
            else:
                new_word = self.segment_token_cached(word, dropout)

            output.append(new_word)

        return output

    @staticmethod
    def process_(buffer):
        """Segment a batch of lines with the BPE instance and dropout value
        stored in the module globals by `process_file`; each output line is
        terminated by a newline."""
        global bpe_, dropout_
        output = []
        for line in buffer:
            output.append(bpe_.segment(line, dropout=dropout_) + '\n')
        return output

    def _process_buffer(self, buffer, pool):
        """Segment the buffered lines, in the worker pool when one is given."""
        if pool is None:
            return self.process_(buffer)

        # NOTE: `_processes` is a private attribute of multiprocessing.Pool
        n = 1 + len(buffer) // (pool._processes * 8)
        batches = [buffer[i:i + n] for i in range(0, len(buffer), n)]
        output = pool.map(self.process_, batches, chunksize=8)
        return [line for lines in output for line in lines]

    def process_file(self, input_, threads, buffer_size, dropout=0):
        """Lazily segment all the lines of `input_`.

        Args:
            input_: iterable of lines (e.g., an open text file).
            threads: number of worker processes; None lets multiprocessing
                choose, and values of 1 or less disable multiprocessing.
            buffer_size: number of lines gathered before they are processed.
            dropout: BPE dropout probability.

        Yields:
            The segmented lines, newline-terminated, in input order.
        """
        # set `self` as global: much faster for multiprocessing than giving
        # `self` as a parameter to each process
        global bpe_, dropout_
        bpe_ = self
        dropout_ = dropout

        if threads is None or threads > 1:
            pool = multiprocessing.Pool(processes=threads)
        else:
            pool = None

        try:
            buffer = []
            for line in input_:
                buffer.append(line)
                if len(buffer) >= buffer_size:
                    for out in self._process_buffer(buffer, pool):
                        yield out
                    buffer = []

            if buffer:
                for out in self._process_buffer(buffer, pool):
                    yield out
        finally:
            # Shut the workers down even when the consumer stops early or an
            # error occurs; otherwise every call leaks a pool of processes.
            if pool is not None:
                pool.terminate()
                pool.join()


def create_parser():
    """Build the argparse parser describing the command-line interface.

    The BPE codes file can be given either positionally (stored as `codes`)
    or with --codes/-c (stored as `codes_named`).
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # files and model size
    add('--input', '-i', nargs='+', metavar='PATH', default=[None],
        help='Input file (default: standard input).')
    add('--codes', '-c', dest='codes_named', metavar='PATH',
        help="File with BPE codes (created by learn_bpe.py).")
    add('codes', nargs='?', metavar='BPE_CODES',
        help='File with BPE codes (created by learn_bpe.py).')
    add('--merges', '-m', type=int, metavar='INT', default=-1,
        help='Use this many BPE operations (<= number of learned symbols) '
             'default: Apply all the learned merge operations')
    add('--output', '-o', nargs='+', metavar='PATH', default=[None],
        help='Output file (default: standard output)')

    # vocabulary filtering
    add('--vocabulary', '--vocab', '-v', nargs='+', metavar='PATH',
        help='Vocabulary file (built with get_vocab.py). If provided, this script '
             'reverts any merge operations that produce an OOV.')
    add('--vocabulary-threshold', '--threshold', '-t', type=int, metavar='INT',
        help='Vocabulary threshold. If vocabulary is provided, any word with '
             'frequency < threshold will be treated as OOV')

    # boolean switches controlling tokenization and casing
    switches = (
        ('--no-mixed-case',
         'Subwords cannot be mixed-case (they must be either upper, lower or title case)'),
        ('--lowercase',
         'Lowercase all text before applying the BPE model'),
        ('--legacy',
         'Legacy mode: like subwords-nmt "@@" is used to mark prefixes '
         '(default: SentencePiece-like behavior)'),
        ('--factor-case',
         'Output case information as a separate feature '
         '(e.g., "MacDonald\'s" -> "mac|T don|T ald|L \'s|L")'),
        ('--inline-case',
         'Output case information as a separate token '
         '(e.g., "MacDonald\'s" -> "mac <T> don <T> ald \'s")'),
    )
    for flag, description in switches:
        add(flag, action='store_true', help=description)

    add('--protect-regex',
        help='The strings matching this regular expression will not be affected by the BPE')
    add('--nfkc', action='store_true', help='Perform Unicode NFKC normalization')

    # runtime behavior
    add('--threads', type=int,
        help='Spawn that many Python processes (default: CPU count)')
    add('--buffer-size', type=int, default=10000,
        help='Process this many lines at once (necessary for multi-threading, default: 10000)')
    add('-d', '--dropout', '--bpe-dropout', type=float, default=0,
        help='Apply BPE dropout with this probability')
    add('--seed', '--random-seed', type=int, help='Random seed for BPE dropout')

    return parser


def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, legacy, dropout):
    """Encode word based on list of BPE merge operations, which are applied consecutively

    Args:
        orig: the token to segment.
        bpe_codes: dict mapping a pair of units to its rank (the pair with
            the lowest rank is merged first).
        bpe_codes_reverse: dict mapping a merged unit to the pair it was
            built from; only used to revert merges when `vocab` is given.
        vocab: optional collection of allowed subwords; when non-empty, the
            result is post-processed by `check_vocab_and_split`.
        separator: subword separator, forwarded to the vocabulary check.
        legacy: append the end-of-word marker '</w>' to the last character
            before merging (and strip it from the result).
        dropout: if non-zero, each candidate pair is discarded with this
            probability at every merge step (BPE dropout).

    Returns:
        The list of subwords; an empty list for an empty token.
    """
    word = list(orig)
    if not word:
        # nothing to segment (and no last character to attach '</w>' to)
        return word

    if legacy:
        word[-1] += '</w>'

    while len(word) > 1:
        # using dict instead of set, because set has a non-deterministic order
        pairs = {pair: None for pair in zip(word, word[1:]) if pair in bpe_codes}

        if dropout:
            pairs = {pair: None for pair in pairs if np.random.random() > dropout}

        if not pairs:
            break

        # every candidate is a known merge, so its rank can be looked up directly
        first, second = min(pairs, key=bpe_codes.get)

        # merge all the non-overlapping occurrences of the pair, left to right
        new_word = []
        i = 0
        while i < len(word):
            if i + 1 < len(word) and word[i] == first and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1

        word = new_word

    # don't print end-of-word symbols
    if legacy:
        if word[-1] == '</w>':
            word.pop(-1)
        elif word[-1].endswith('</w>'):
            # only strip the marker itself, not a literal '</w>' in the token
            word[-1] = word[-1][:-len('</w>')]

    if vocab:
        word = check_vocab_and_split(word, bpe_codes_reverse, vocab, separator, legacy)

    return word


def recursive_split(segment, bpe_codes, vocab, separator, final=False, legacy=False):
    """Recursively split segment into smaller units (by reversing BPE merges)
    until all units are either in-vocabulary, or cannot be split further.

    Args:
        segment: the subword to split.
        bpe_codes: dict mapping a merged unit to the pair it was built from.
        vocab: collection of allowed subwords; non-final units are looked up
            with `separator` appended.
        separator: suffix marking non-final subwords in `vocab`.
        final: whether `segment` is the last subword of its token.
        legacy: whether final units are stored in `bpe_codes` with the
            end-of-word marker '</w>'.

    Yields:
        The resulting units, from left to right.
    """
    try:
        if final and legacy:
            left, right = bpe_codes[segment + '</w>']
            right = right[:-4]   # strip the '</w>' marker
        else:
            left, right = bpe_codes[segment]
    except KeyError:
        # not the product of a known merge: cannot split it further
        yield segment
        return

    if left + separator in vocab:
        yield left
    else:
        yield from recursive_split(left, bpe_codes, vocab, separator, False, legacy)

    if final and right in vocab or not final and right + separator in vocab:
        yield right
    else:
        yield from recursive_split(right, bpe_codes, vocab, separator, final, legacy)


def check_vocab_and_split(orig, bpe_codes, vocab, separator, legacy=False):
    """Replace every out-of-vocabulary subword of a segmented word by smaller
    units, obtained by undoing the BPE merges that produced it.

    Non-final subwords are looked up in `vocab` with `separator` appended,
    the final subword is looked up as is. Returns the new list of subwords.
    """
    result = []
    final_index = len(orig) - 1

    for index, piece in enumerate(orig):
        is_final = index == final_index
        lookup_key = piece if is_final else piece + separator

        if lookup_key in vocab:
            result.append(piece)
            continue

        # OOV: fall back to the units this subword was merged from
        result.extend(recursive_split(piece, bpe_codes, vocab, separator, is_final, legacy))

    return result


def read_vocabulary(filename, threshold):
    """read vocabulary file produced by get_vocab.py, and filter according to frequency threshold.

    Each line holds a word and its frequency, separated by a single space or
    tab (the word itself may contain spaces: the frequency is whatever
    follows the last separator).

    Args:
        filename: path to the vocabulary file.
        threshold: minimum frequency for a word to be kept; None keeps all.

    Returns:
        The set of words whose frequency is at least `threshold`.

    Raises:
        ValueError: if a non-blank line does not end with a separator
            followed by an integer frequency.
    """
    vocabulary = set()
    with open(filename) as vocab_file:
        for line_number, line in enumerate(vocab_file, 1):
            entry = line[:-1] if line.endswith('\n') else line
            if not entry.strip():
                # tolerate blank lines (e.g., at the end of the file)
                continue

            split_at = max(entry.rfind(' '), entry.rfind('\t'))
            freq = entry[split_at + 1:]
            if split_at < 0 or not freq.isdecimal():
                raise ValueError(
                    'invalid line {} in vocabulary file {}: {!r}'.format(line_number, filename, entry)
                )

            if threshold is None or int(freq) >= threshold:
                vocabulary.add(entry[:split_at])
    return vocabulary


def read_bpecodes(filename):
    """Read a BPE codes file.

    The first line may be a header starting with '#'. If the rest of that
    line is a JSON object, it is returned as the configuration; any other
    header is treated as a plain comment and skipped. Every remaining line
    is split at its last space into a tuple of subword units.

    Args:
        filename: path to the codes file.

    Returns:
        A `(config, bpecodes)` pair: the header configuration (an empty dict
        when there is none) and the list of merge operations in file order.
        An empty file gives `({}, [])`.
    """
    config = {}
    with open(filename) as f:
        first_line = next(f, None)
        if first_line is None:
            # empty file: no header and no merge operations
            return config, []

        if first_line.startswith('#'):
            try:
                header = json.loads(first_line.strip('# \n\r'))
            except ValueError:
                header = None   # not JSON: an ordinary comment line
            # only a JSON object can serve as a key/value configuration
            if isinstance(header, dict):
                config = header
            lines = f
        else:
            # not a header: put the line back in front of the others
            lines = itertools.chain([first_line], f)

        bpecodes = [tuple(line.rstrip('\r\n').rsplit(' ', maxsplit=1)) for line in lines]
    return config, bpecodes


if __name__ == '__main__':
    # Command-line entry point: segment every input file with the BPE model
    # and write the result to the matching output file.
    parser = create_parser()
    args = parser.parse_args()

    np.random.seed(args.seed)

    # the codes file can be given positionally or with --codes
    if args.codes is None:
        args.codes = args.codes_named

    # a single output receives the (concatenated) result of all the inputs
    if len(args.output) == 1:
        args.output = args.output * len(args.input)

    if args.codes is None:
        parser.print_usage(sys.stderr)
        print('apply_bpe.py: error: the following arguments are required: BPE_CODES', file=sys.stderr)
        sys.exit(1)
    if len(args.input) != len(args.output):
        parser.print_usage(sys.stderr)
        print('apply_bpe.py: error: --input must have the same number of arguments as --output or one', file=sys.stderr)
        sys.exit(1)
    if args.vocabulary is not None and len(args.vocabulary) != len(args.input) and len(args.vocabulary) != 1:
        parser.print_usage(sys.stderr)
        print('apply_bpe.py: error: --vocabulary must have the same number of arguments as --input or one', file=sys.stderr)
        sys.exit(1)
    if any(input_ is not None and input_ in args.output for input_ in args.input):
        parser.print_usage(sys.stderr)
        print('apply_bpe.py: error: --output must be different than --input', file=sys.stderr)
        sys.exit(1)

    config, bpecodes = read_bpecodes(args.codes)

    # options stored in the codes file are used unless set on the command line
    for k, v in config.items():
        setattr(args, k, getattr(args, k, None) or v)

    bpe = BPE(bpecodes, args.merges, None,
              args.lowercase, args.legacy, args.inline_case,
              args.protect_regex, args.nfkc)

    # outputs are opened in append mode below (several inputs may share one
    # output), so start by removing the files left by a previous run
    for output_ in args.output:
        if output_ is not None:
            try:
                os.unlink(output_)
            except FileNotFoundError:
                pass

    for i, (input_, output_) in enumerate(zip(args.input, args.output)):
        # The vocabulary applies even without a threshold (read_vocabulary
        # then keeps every word); a single vocabulary is shared by all inputs.
        if args.vocabulary and (i == 0 or len(args.vocabulary) > 1):
            vocabulary = read_vocabulary(args.vocabulary[i], args.vocabulary_threshold)
            bpe.vocab = vocabulary
            # cached segmentations depend on the vocabulary
            bpe.cache = collections.OrderedDict()

        in_file = sys.stdin if input_ is None else open(input_)
        try:
            out_file = sys.stdout if output_ is None else open(output_, 'a')
            try:
                for line in bpe.process_file(in_file, threads=args.threads, buffer_size=args.buffer_size, dropout=args.dropout):
                    out_file.write(line)
            finally:
                if out_file is not sys.stdout:
                    out_file.close()
        finally:
            if in_file is not sys.stdin:
                in_file.close()
