import tiktoken
import re
import transformers

def replace_content(template, content_list):
    """
    Fill every |<content>| placeholder of the template with the next entry of content_list.

    Placeholders are filled from left to right, one entry per placeholder.
    A shallow copy of content_list is consumed, so the list passed in is left unchanged.
    """
    remaining = content_list.copy()

    def take_next(_match):
        # Called by re.sub once per placeholder, in order of appearance.
        return remaining.pop(0)

    return re.sub(r'\|<content>\|', take_next, template)

def target_tokenizer_function_models(model_class, auth_token=None):
    """
    Build the tokenize function of a target model.

    model_class: one of 'gpt-3.5', 'gpt-2', 'opt', 'llama-2'.
    auth_token: Hugging Face access token. Only used by 'llama-2'; it is checked when the returned function is called, not when it is built.
    -------------
    return: a callable which takes a string and returns a list of strings (the text of each token).
    The tokenizer itself is loaded on the first call of the returned function and reused afterwards,
    so calling it once per word does not reload the tokenizer every time.
    raise: NotImplementedError if model_class is not one of the supported names.
    """
    loaded = []  # holds the tokenizer once it has been loaded

    def get_tokenizer(load):
        # Lazy, one-time loading: from_pretrained is far too slow to repeat per call.
        if not loaded:
            loaded.append(load())
        return loaded[0]

    if model_class == 'gpt-3.5':
        def target_tokenizer_function(message):
            tokenizer = get_tokenizer(lambda: tiktoken.encoding_for_model('gpt-3.5-turbo'))
            # NOTE(review): each token is decoded on its own, so a token holding only part of a
            # multi-byte character would raise UnicodeDecodeError here.
            return [t.decode('UTF-8') for t in tokenizer.decode_tokens_bytes(tokenizer.encode(message))]
        return target_tokenizer_function

    if model_class == 'gpt-2':
        def target_tokenizer_function(message):
            tokenizer = get_tokenizer(lambda: transformers.GPT2TokenizerFast.from_pretrained('gpt2', local_files_only=True))
            return [tokenizer.convert_tokens_to_string([t]) for t in tokenizer.tokenize(message)]
        return target_tokenizer_function

    if model_class == 'opt':
        def target_tokenizer_function(message):
            tokenizer = get_tokenizer(lambda: transformers.AutoTokenizer.from_pretrained(pretrained_model_name_or_path='facebook/opt-1.3b'))
            return [tokenizer.convert_tokens_to_string([t]) for t in tokenizer.tokenize(message)]
        return target_tokenizer_function

    if model_class == 'llama-2':
        def target_tokenizer_function(message):
            if auth_token is None:
                raise ValueError('auth_token should be provided for llama-2')
            tokenizer = get_tokenizer(lambda: transformers.LlamaTokenizerFast.from_pretrained('meta-llama/Llama-2-7b-chat-hf', use_auth_token=auth_token))
            # llama-2 adds a whitespace placeholder ('▁') at the beginning of the sentence, we need to remove it.
            # convert_tokens_to_string is not used here: it would also drop the whitespace placeholders inside the sentence.
            tokens = tokenizer.tokenize(message)
            if not tokens:
                # nothing to strip (e.g. empty message)
                return []
            if tokens[0] == '▁':
                tokens = tokens[1:]
            elif tokens[0].startswith('▁'):
                # only drop the leading placeholder, never a real character
                tokens[0] = tokens[0][1:]
            # Map the remaining placeholders back to the characters they stand for.
            return [t.replace('▁', ' ').replace('<0x0A>', '\n') for t in tokens]
        return target_tokenizer_function

    raise NotImplementedError(f'Unsupported model_class: {model_class!r}')

def test_positions(template, content_list, target_tokenize_function, add_eot=True, split_punctuations=None):
    """
    Split the template into words and locate the first target token of every word.

    In the template, the content should be represented as |<content>| (optionally preceded by one whitespace).
    'target_tokenize_function' should be a callable function, which takes a string as input and returns a list of strings.
    If add_eot, the returned test positions will also contain the end of the sentence.
    Notice that we always omit the first token position because it should never be predicted.
    split_punctuations: characters that should start a new word instead of being attached to the previous one.
    None (the default) or an empty collection means nothing is split. Notice that it could cause bugs for some tokenizers (e.g. GPT-3.5).
    -------------
    output[0]: test_positions: dict{'input': list[int], 'output': list[int]}, the first idx of each word, the 'output' is the 'input' minus 1.
    output[1]: manually labeled t/c token: np.array, 0: content, 1: template (the first word is omitted).
    output[2]: our_tokenized: list[str], words list (placeholders are kept as they appear in the template).
    output[3]: target_tokenized: list[str], the target tokenization of the filled sentence, should be sub-word.
    -------------
    raise: ValueError if the tokens the target tokenizer gives for the whole sentence do not line up with our words.
    raise: AssertionError if the words do not join back into the template, or if a ' |<content>|' content
    does not absorb its leading whitespace into its first GPT-3.5 token.
    """
    import numpy as np

    placeholder_pattern = r'\|<content>\|| \|<content>\|'
    # None is documented as "do not split"; normalise it so the membership test below works.
    split_punctuations = split_punctuations or ()

    # Our basic assumption is that the tokenization is sub-word, which means the tokenizer cannot join parts from two words into one token.
    # Here we assume the tokenizer of GPT-3.5-turbo is the most coarse tokenizer; we start from it and further concatenate its tokens.
    gpt_tokenizer = tiktoken.encoding_for_model('gpt-3.5-turbo')

    def gpt_tokenize(text):
        return [t.decode('UTF-8') for t in gpt_tokenizer.decode_tokens_bytes(gpt_tokenizer.encode(text))]

    part_tokens = []
    for part in re.split(placeholder_pattern, template):
        words = []
        for i, t in enumerate(gpt_tokenize(part)):
            # A token starts a new word if it is the first one, begins with whitespace,
            # or begins with one of the punctuations we were asked to split off.
            if i == 0 or t.startswith(' ') or t[0] in split_punctuations:
                words.append(t)
            else:
                # otherwise it is glued onto the previous word
                words[-1] += t
        part_tokens.append(words)

    content_special_tokens = re.findall(placeholder_pattern, template)
    for i, content in enumerate(content_list):
        if content_special_tokens[i] == ' |<content>|':
            # the whitespace should be merged into the first content token
            assert not gpt_tokenize(' ' + content)[0] == ' '

    # give our tokenized sentence: template words interleaved with the placeholders
    our_tokenized = []
    for i, tokens in enumerate(part_tokens):
        our_tokenized += tokens
        if i != len(part_tokens) - 1:
            our_tokenized.append(content_special_tokens[i])

    # check the tokenization
    assert ''.join(our_tokenized) == template

    # For every word: its t/c label, the text it stands for, and how many target tokens that text takes.
    len_of_tokens = []
    tc = []
    expected_texts = []
    content_id = 0
    for token in our_tokenized:
        if token == ' |<content>|':
            text = ' ' + content_list[content_id]
            content_id += 1
            tc.append(0)
        elif token == '|<content>|':
            text = content_list[content_id]
            content_id += 1
            tc.append(0)
        else:
            text = token
            tc.append(1)
        expected_texts.append(text)
        len_of_tokens.append(len(target_tokenize_function(text)))

    # boundaries[i] is the index of the first target token of word i; the last entry is the total length.
    boundaries = [0] + np.cumsum(len_of_tokens).tolist()

    # check whether the target tokenizer meets our basic assumption, which means the tokenization is sub-word.
    complete_sentence = replace_content(template, content_list)
    target_tokenized = target_tokenize_function(complete_sentence) # the actual tokenization of the target tokenizer.
    for i, expected in enumerate(expected_texts):
        actual = target_tokenized[boundaries[i]: boundaries[i+1]]
        if not ''.join(actual) == expected:
            raise ValueError('The target tokenizer does not meet our basic assumption, which means the tokenization is sub-word.\nThe target tokenization: {}\nour tokenization: {}'.format(actual, expected))

    if not add_eot:
        boundaries = boundaries[:-1]
    else:
        tc += [1] # the |<end_of_text>| token should be considered as a template.
        our_tokenized.append('|<eot>|')
        target_tokenized.append('|<eot>|')
    # The first word is never predicted, so its position and label are dropped.
    positions = {'input': boundaries[1:], 'output': [(i-1) for i in boundaries[1:]]}
    return positions, np.array(tc[1:]), our_tokenized, target_tokenized

def show_test_positions(output_of_test_positions):
    """
    Print one line per test position: its index, the position, the word with its T/C label,
    and the target tokens lying between this position and the next one.

    output_of_test_positions: the 4-tuple returned by test_positions.
    """
    starts = output_of_test_positions[0]['input']
    labels = output_of_test_positions[1]
    words = output_of_test_positions[2]
    target_tokens = output_of_test_positions[3]
    # A new list is built here, so the position list of the caller stays unchanged.
    bounds = starts + [len(target_tokens)]
    for idx, (begin, end) in enumerate(zip(bounds, bounds[1:])):
        print(f'Test position {idx:>3d}: {begin:>3d}. ', end='')
        # words is shifted by one because the first word has no test position.
        word_repr = repr(words[idx + 1])
        kind = 'T' if labels[idx] == 1 else 'C'
        print(f'The word is: {word_repr:>15} ({kind}). ', end='')
        print(f'The target tokenization is: {target_tokens[begin:end]}. ')

def detect_prompt_tc(output_of_test_positions, template, len_of_prompt):
    """
    Label the words covered by the first len_of_prompt characters of the template as template (1) or content (0).
    --------------
    output_of_test_positions: the output of the test_positions function; only its word list (index 2) is read.
    template: the template of the sentence, i.e., dataset['template'], with content special tokens |<content>|.
    len_of_prompt: the length of the prompt string, counted in characters of the template.
    --------------
    return: list[int], one label per word of the prompt, 0: content, 1: template.
    raise: ValueError if the prompt ends in the middle of a word, or if a word does not match the template at its position.
    """
    content_token = '|<content>|'
    spaced_content_token = ' |<content>|'

    prompt_tc = []
    pos = 0  # number of template characters consumed so far
    for token in output_of_test_positions[2]:
        if pos >= len_of_prompt:
            break
        if template[pos:pos+len(content_token)] == content_token:
            prompt_tc.append(0)
            pos += len(content_token)
        elif template[pos:pos+len(spaced_content_token)] == spaced_content_token:
            prompt_tc.append(0)
            pos += len(spaced_content_token)
        elif template[pos:pos+len(token)] == token:
            prompt_tc.append(1)
            pos += len(token)
        else:
            raise ValueError('The tokenization of the template and our tokens do not match.')
    # Checked after the loop so that an overshoot caused by the very last word is also detected.
    if pos > len_of_prompt:
        raise ValueError('The prompt string contains incomplete tokens.')
    return prompt_tc
    
    

if __name__ == '__main__':

    # Smoke test: run test_positions on the first content list of a dataset for several tokenizers.
    # test_positions raises if a tokenizer does not meet our basic assumption (sub-word tokenization).
    import json

    with open('concat_dataset/dataset_0.json') as dataset_file:
        dataset = json.load(dataset_file)

    with open('huggingface_auth_token') as token_file:
        auth_token = token_file.read().strip()

    template = dataset['template']
    first_content_list = dataset['content_list'][0]

    for model_name in ['gpt-3.5', 'gpt-2', 'llama-2']:
        tokenize = target_tokenizer_function_models(model_name, auth_token=auth_token)
        test_positions(template, first_content_list, tokenize)
        print(f'Congratulations! The tokenizer of model {model_name} satisfys our assumption!')