import re
from data.tokenizer import Tokenizer

class WordleTokenizer(Tokenizer):
    """Character-level tokenizer with a small set of multi-character special tokens.

    The vocabulary is the 26 lowercase ASCII letters (ids 0-25) followed by the
    entries of ``special_vocab`` (ids 26 onwards), in the order listed there.
    """

    def __init__(self):
        """Build the vocabulary tables and register the special-token ids with the base class."""
        self.special_vocab = ['<g>', '<b>', '<y>', '<|pad|>', '</a>', '</s>', '<s>', '<a>', '</eod>']
        self.vocab = list('abcdefghijklmnopqrstuvwxyz') + self.special_vocab
        self.t2i = {w: i for i, w in enumerate(self.vocab)}
        # A single alternation over all special tokens, compiled once, lets
        # encode() find them in one left-to-right scan. Longest tokens come
        # first so a token can never be shadowed by a shorter one that happens
        # to be its prefix.
        self._special_pattern = re.compile(
            '|'.join(
                re.escape(tok)
                for tok in sorted(self.special_vocab, key=len, reverse=True)
            )
        )
        # NOTE(review): the positional order below is dictated by
        # Tokenizer.__init__, which is not visible in this file -- confirm
        # against the base class before reordering.
        super().__init__(self.token_to_id('<|pad|>'),
                         self.token_to_id('</s>'),
                         self.token_to_id('</a>'),
                         self.token_to_id('<s>'),
                         self.token_to_id('<a>'),
                         self.token_to_id('</eod>'))

    def encode(self, str_, **kwargs):
        """Convert a string, or a list of strings, into token ids and a padding mask.

        Args:
            str_: A single string, or a list of strings to encode as a batch.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            For a string: ``(tokens, mask)`` where ``tokens`` is a list of ids
            and ``mask[i]`` is 0 if ``tokens[i]`` is the pad id, else 1.
            For a list: ``(batch_tokens, batch_mask)``, each a list of lists
            right-padded to the longest item (pad id in tokens, 0 in the mask).
            An empty list yields ``([], [])``.

        Raises:
            KeyError: If the text contains a character outside the vocabulary.
            ValueError: If ``str_`` is neither a string nor a list.
        """
        if isinstance(str_, str):
            tokens = []
            curr = 0
            for match in self._special_pattern.finditer(str_):
                # Plain characters between the previous special token and this one.
                tokens.extend(self.token_to_id(c) for c in str_[curr:match.start()])
                tokens.append(self.token_to_id(match.group()))
                curr = match.end()
            # Trailing plain characters after the last special token.
            tokens.extend(self.token_to_id(c) for c in str_[curr:])
            # pad_token_id is presumably set by Tokenizer.__init__ -- TODO confirm.
            return tokens, [int(t != self.pad_token_id) for t in tokens]
        if isinstance(str_, list):
            if not str_:
                # zip(*[]) cannot be unpacked and max() of nothing raises, so
                # an empty batch is answered directly.
                return [], []
            tokens, pads = zip(*[self.encode(item) for item in str_])
            max_len = max(map(len, tokens))
            padded_tokens = [
                list(item) + [self.pad_token_id] * (max_len - len(item))
                for item in tokens
            ]
            padded_masks = [
                list(item) + [0] * (max_len - len(item))
                for item in pads
            ]
            return padded_tokens, padded_masks
        raise ValueError('str_ must be a string or a list of strings')

    def decode(self, tokens, **kwargs):
        """Convert token ids back into text.

        Args:
            tokens: A flat sequence of ids, or a sequence whose first element
                is a list, in which case every element is decoded separately.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            A string for a flat sequence (``''`` when it is empty), or a list
            of strings for a batch.
        """
        # len() rather than truthiness: callers may pass sequence types whose
        # truth value is not simply "non-empty" (assumption -- the original
        # code used len() as well).
        if len(tokens) == 0:
            return ''
        if isinstance(tokens[0], list):
            return [self.decode(item) for item in tokens]
        return ''.join([self.id_to_token(item) for item in tokens])

    def num_tokens(self):
        """Return the vocabulary size (letters plus special tokens)."""
        return len(self.vocab)

    def id_to_token(self, id_):
        """Return the token string stored at index ``id_`` of the vocabulary."""
        return self.vocab[id_]

    def token_to_id(self, token):
        """Return the id of ``token``; raises ``KeyError`` if it is not in the vocabulary."""
        return self.t2i[token]

    def get_vocab(self):
        """Return the vocabulary list itself (not a copy), ordered by token id."""
        return self.vocab
