# BPE tokenization for subword regularization
# https://github.com/VProv/BPE-Dropout/

import heapq
import numpy as np


def load_subword_nmt_table(path):
    """Load a merge table stored in the subword-nmt "codes" format.

    Every non-header line holds one merge as two tokens separated by a single
    space. Merges are numbered from 1 in file order, so an earlier line gets a
    smaller priority value.

    :param path: path to merge_table with subword-nmt format
    :return: dict [(token_1, token_2)] -> priority (int, starting at 1)
    :raises ValueError: if a non-empty, non-header line does not consist of
        exactly two space-separated tokens
    """
    table = dict()
    cur_priority = 1
    # Read as UTF-8 explicitly so the result does not depend on the
    # platform's default encoding.
    with open(path, encoding='utf-8') as f:
        for line in f:
            line = line.rstrip('\n')
            # Empty lines (typically a trailing one) carry no merge and do
            # not consume a priority.
            if not line:
                continue
            # Only a line that *starts* with the marker is a header; a merge
            # whose tokens merely contain '#version' must not be skipped.
            if line.startswith('#version'):
                continue
            token_1, token_2 = line.split(' ')
            table[(token_1, token_2)] = cur_priority
            cur_priority += 1
    return table


def load_merge_table(path):
    """Load a merge table stored as tab-separated ``token_1, token_2, priority``.

    :param path: path to merge_table
    :return: dict [(token_1, token_2)] -> priority (int)
    :raises ValueError: if a non-empty line does not have exactly three
        tab-separated fields, or its priority is not an integer
    """
    table = dict()
    # Read as UTF-8 explicitly so the result does not depend on the
    # platform's default encoding.
    with open(path, encoding='utf-8') as f:
        for line in f:
            line = line.rstrip('\n')
            # Tolerate empty lines (typically a trailing one) instead of
            # failing on the unpacking below.
            if not line:
                continue
            token_1, token_2, priority = line.split('\t')
            table[(token_1, token_2)] = int(priority)

    return table


def tokenize_word(merge_rules, word, dropout=0.0,
                  random_generator=np.random.RandomState(),
                  sentinels=('^', '$'),
                  regime='begin',
                  bpe_symbol='`',
                  always_merge_sentinels=True):
    """ Tokenize word using bpe merge rules

    The word is split into characters and adjacent tokens are merged greedily:
    among the currently applicable merges the one with the smallest id is
    applied first. Each applicable merge is skipped with probability
    ``dropout``; skipped merges become candidates again after the next merge
    that is actually applied.

    :param merge_rules: dict [(a,b)] -> id, merge table, ids are in increasing order
    :param word: string; an empty word yields an empty list
    :param dropout: float, dropout rate
    :param random_generator: random generator with .rand() method. The default
        instance is created once, when the function is defined, and is shared
        by every call that does not pass its own generator.
    :param sentinels: pair of strings, beginning of word sentinel and end of
        word sentinel (empty string means that no corresponding sentinel is applied)
    :param regime:
        'begin' -- add bpe symbol to the beginning of bpe token
        'end' -- add bpe symbol to the end of bpe token
        any other value leaves the tokens without bpe symbols
    :param bpe_symbol: str, could be one of '`', '@@', '▁'
    :param always_merge_sentinels: bool, if True, sentinels are always concatenated
        to the first and last characters before applying BPE merges (True is
        equivalent to subword-nmt>=0.2, False is equivalent to subword-nmt<0.2)
    :return: list of subword strings with sentinels removed
    """
    # Nothing to tokenize; also avoids indexing into an empty token list below.
    if not word:
        return []

    beg_sentinel, end_sentinel = sentinels[0], sentinels[1]

    # Subword tokens
    sw_tokens = list(word)

    # Add sentinels
    if always_merge_sentinels:
        sw_tokens[0] = beg_sentinel + sw_tokens[0]
        sw_tokens[-1] = sw_tokens[-1] + end_sentinel
    else:
        if beg_sentinel:
            sw_tokens.insert(0, beg_sentinel)
        if end_sentinel:
            sw_tokens.append(end_sentinel)

    # Heap of mutable [priority, position] entries: one per adjacent pair that
    # has a merge rule. Lists (not tuples) because positions are shifted in
    # place after every applied merge.
    merge_heap = []
    for pos in range(len(sw_tokens) - 1):
        cur_nxt_pair = (sw_tokens[pos], sw_tokens[pos + 1])
        if cur_nxt_pair in merge_rules:
            merge_heap.append([merge_rules[cur_nxt_pair], pos])
    heapq.heapify(merge_heap)

    dropped_merges = []

    while merge_heap:
        cur_priority, cur_pos = heapq.heappop(merge_heap)

        # Discard stale entries: the position may no longer exist, or the pair
        # now at that position may differ from the one that was queued.
        if cur_pos > len(sw_tokens) - 2:
            continue
        cur = sw_tokens[cur_pos]
        nxt = sw_tokens[cur_pos + 1]
        if merge_rules.get((cur, nxt), None) != cur_priority:
            continue

        # Apply dropout: remember the merge so it can be retried later.
        if random_generator.rand() < dropout:
            dropped_merges.append([cur_priority, cur_pos])
            continue

        sw_tokens[cur_pos:cur_pos + 2] = [cur + nxt]

        # Everything right of the merge moved one slot to the left.
        for entry in merge_heap:
            if entry[1] > cur_pos:
                entry[1] -= 1

        # Add dropped merges back (with the same position shift).
        for priority, position in dropped_merges:
            if position > cur_pos:
                position -= 1
            heapq.heappush(merge_heap, [priority, position])
        dropped_merges = []

        # The merged token may form new mergeable pairs with its neighbours.
        new_cur = sw_tokens[cur_pos]
        if cur_pos > 0:
            prev = sw_tokens[cur_pos - 1]
            if (prev, new_cur) in merge_rules:
                heapq.heappush(merge_heap, [merge_rules[(prev, new_cur)], cur_pos - 1])

        if cur_pos < len(sw_tokens) - 1:
            new_next = sw_tokens[cur_pos + 1]
            if (new_cur, new_next) in merge_rules:
                heapq.heappush(merge_heap, [merge_rules[(new_cur, new_next)], cur_pos])

    # Remove exactly one leading / trailing sentinel. Using str.replace here
    # would also delete sentinel characters that belong to the word itself.
    if beg_sentinel and sw_tokens[0].startswith(beg_sentinel):
        sw_tokens[0] = sw_tokens[0][len(beg_sentinel):]
    if end_sentinel and sw_tokens[-1].endswith(end_sentinel):
        sw_tokens[-1] = sw_tokens[-1][:-len(end_sentinel)]

    # A sentinel that was never merged into a neighbour is now an empty token.
    # Drop it before adding bpe symbols so that real tokens are never confused
    # with leftovers of a sentinel.
    if sw_tokens and sw_tokens[0] == '':
        del sw_tokens[0]
    if sw_tokens and sw_tokens[-1] == '':
        del sw_tokens[-1]

    if regime == 'begin':
        # Every token except the first one is marked as a continuation.
        sw_tokens[1:] = [bpe_symbol + token for token in sw_tokens[1:]]
    elif regime == 'end':
        # Every token except the last one is marked as "to be continued".
        sw_tokens[:-1] = [token + bpe_symbol for token in sw_tokens[:-1]]

    return sw_tokens


def tokenize_text(rules, line, dropout=0.0, random_generator=np.random.RandomState(), **args):
    """ Tokenize every word of a line and join the result with single spaces

    Words are separated by the space character only; empty pieces produced by
    consecutive, leading or trailing spaces are discarded.

    :param rules: merge table passed to tokenize_word
    :param line: str
    :param dropout: float, dropout rate passed to tokenize_word
    :param random_generator: random generator passed to tokenize_word
    :param args: extra keyword arguments forwarded to tokenize_word
    :return: str, the subwords of all words separated by single spaces
    """
    encoded_words = []
    for word in line.split(' '):
        if len(word) == 0:
            continue
        subwords = tokenize_word(rules, word, dropout, random_generator, **args)
        encoded_words.append(' '.join(subwords))
    return ' '.join(encoded_words)


class BpeOnlineTokenizer:
    """
    Callable that bpe-tokenizes a str line with a fixed merge table,
    dropout rate and its own random generator
    """
    def __init__(self, bpe_dropout_rate, merge_table, random_seed=None):
        """
        :param bpe_dropout_rate: float [0,1)
        :param merge_table: dict [(token_1, token_2)] -> priority
        :param random_seed: seed given to np.random.RandomState
        """
        self.merge_table = merge_table
        self.bpe_dropout_rate = bpe_dropout_rate
        self.random_generator = np.random.RandomState(random_seed)

    def __call__(self, line, **args):
        """
        :param line: str
        :param args: extra keyword arguments forwarded to tokenize_text
        :return: whatever tokenize_text returns for this line
        """
        return tokenize_text(
            self.merge_table,
            line,
            dropout=self.bpe_dropout_rate,
            random_generator=self.random_generator,
            **args
        )
    

class BpeOnlineParallelApplier:
    """
    Apply one bpe tokenizer per element to a group of parallel lines
    """
    def __init__(self, bpe_dropout_rates, merge_tables, random_seed=42):
        """
        :param bpe_dropout_rates: sequence of dropout rates, one per merge table
        :param merge_tables: sequence of merge tables
            (dict [(token_1, token_2)] -> priority); a None entry means the
            corresponding line is returned unchanged
        :param random_seed: seed passed to every BpeOnlineTokenizer that is created
        """
        assert len(bpe_dropout_rates) == len(merge_tables)
        self.bpe_appliers = [
            (lambda x: x) if table is None
            else BpeOnlineTokenizer(rate, table, random_seed)
            for rate, table in zip(bpe_dropout_rates, merge_tables)
        ]

    def __call__(self, lines):
        """
        :param lines: sequence with exactly one line per applier
        :return: tuple with the i-th applier applied to the i-th line
        """
        assert len(self.bpe_appliers) == len(lines)
        processed = []
        for applier, text in zip(self.bpe_appliers, lines):
            processed.append(applier(text))
        return tuple(processed)
