from typing import Dict, List


class BPETokenizerTrainer(object):
    def __init__(self, training_set: str, max_num_merges: int, max_num_tokens: int) -> None:
        self.max_num_merges = max_num_merges
        self.last_token_id = 0
        self.max_num_tokens = max_num_tokens

        self.training_set_symbolized: List[str] = []
        self.lookup_table: Dict[str, int] = {}
        for char in training_set:
            if len(self.lookup_table) >= self.max_num_tokens:
                break
            self.training_set_symbolized.append(char)
            if char not in self.lookup_table:
                self.lookup_table[char] = self.last_token_id
                self.last_token_id += 1
    
    def merge(self, new_token_text: str) -> None:
        new_symbol = new_token_text
        new_training_set_symbolized: List[str] = []
        i = 1
        while i < len(self.training_set_symbolized):
            pair_text = self.training_set_symbolized[i-1] + self.training_set_symbolized[i]
            if pair_text == new_token_text:
                new_training_set_symbolized.append(new_symbol)
                i += 1
                if i == len(self.training_set_symbolized) - 1:
                    new_training_set_symbolized.append(self.training_set_symbolized[i])
            else:
                new_training_set_symbolized.append(self.training_set_symbolized[i-1])
                if i == len(self.training_set_symbolized) - 1:
                    new_training_set_symbolized.append(self.training_set_symbolized[i])
            i += 1
        self.training_set_symbolized = new_training_set_symbolized
    
    def add_next_pair(self) -> None:
        pair_counts: Dict[str, int] = {}
        i = 1
        while i < len(self.training_set_symbolized):
            pair_text = self.training_set_symbolized[i-1] + self.training_set_symbolized[i]
            if pair_text not in pair_counts:
                pair_counts[pair_text] = 1
            else:
                pair_counts[pair_text] += 1
            i += 1
        
        most_common_pair_text = max(pair_counts, key=pair_counts.get)
        self.lookup_table[most_common_pair_text] = self.last_token_id
        self.last_token_id += 1
        self.merge(new_token_text=most_common_pair_text)
            
    
    def train(self) -> None:
        num_merges = 0
        while num_merges < self.max_num_merges and len(self.training_set_symbolized) > 1 and len(self.lookup_table) < self.max_num_tokens:
            self.add_next_pair()
            num_merges += 1
    
    def get_lookup_table(self) -> Dict[str, int]:
        return self.lookup_table