from transformers import AutoTokenizer


def load_tokenizer(model_id, alien_tokenizer_path):
    original_tokenizer = AutoTokenizer.from_pretrained(model_id)
    alien_tokenizer = AutoTokenizer.from_pretrained(alien_tokenizer_path)

    return original_tokenizer, alien_tokenizer


def build_translator(original_tokenizer, alien_tokenizer):
    class Translator(object):
        def __init__(self, original_tokenizer, alien_tokenizer):
            self.original_tokenizer = original_tokenizer
            self.alien_tokenizer = alien_tokenizer

        def encode(self, text):
            original_token_ids = self.original_tokenizer.encode(text)
            alien_text = self.alien_tokenizer.decode(original_token_ids)
            return alien_text
        
        def decode(self, text):
            alien_token_ids = self.alien_tokenizer.encode(text)
            original_text = self.original_tokenizer.decode(alien_token_ids)
            return original_text

    return Translator(original_tokenizer, alien_tokenizer)

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--original_tokenizer_path', type=str, default='meta-llama/Meta-Llama-3-8B-Instruct')
    parser.add_argument('--alien_tokenizer_path', type=str, default='/workspace/codes/AlienLM/alien_tokenizer/alien/full')
    args = parser.parse_args()

    original_tokenizer, alien_tokenizer = load_tokenizer(args.original_tokenizer_path, args.alien_tokenizer_path)
    translator = build_translator(original_tokenizer, alien_tokenizer)

    encoded_text = translator.encode("Hello, world!")
    print(f"encoded_text: {encoded_text}")
    print(f"decoded_text: {translator.decode(encoded_text)}")
    