from transformers import AutoModelForCausalLM, AutoTokenizer
from Efficient_Token_Matcher import OptimizedTokenMatcher
from tqdm import tqdm
from random import shuffle
import re
import json 
import torch

# Target model: only its tokenizer/vocabulary is loaded in this section.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# Proxy model: its output-head weights are the source of the embeddings.
proxy_model_id = "Qwen/Qwen2.5-7B-Instruct"

# Every Hugging Face artefact below is cached in the same local directory.
_cache_dir = '/workspace/CACHE/MODELS'

tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=_cache_dir)
proxy_tokenizer = AutoTokenizer.from_pretrained(proxy_model_id, cache_dir=_cache_dir)

model = AutoModelForCausalLM.from_pretrained(proxy_model_id, cache_dir=_cache_dir)
# Detached view of the proxy model's lm_head weight; it is indexed by proxy
# token id further down in this file.
proxy_lm_head = model.lm_head.weight.detach()

# Build one embedding row per target-vocabulary token from rows of the proxy
# model's lm_head matrix. Rows are produced in ascending target-token-id
# order, so row i belongs to the i-th smallest id; row index equals token id
# only when the ids are contiguous and start at 0.
lm_head = []
vocab = tokenizer.get_vocab()
vocab = sorted(vocab.items(), key=lambda x: x[1])

non_matched = 0  # target tokens the proxy tokenizer splits into more than one piece
non_matched_len = []  # number of proxy pieces for each such token
for tok, tok_id in tqdm(vocab):
    # add_special_tokens=False keeps BOS/EOS ids out of the result, so a proxy
    # tokenizer that inserts them by default cannot make every token look
    # "unmatched" or leak a special-token row into the average.
    # NOTE(review): `tok` is the vocabulary key and is encoded here as plain
    # text. If the keys are byte-level surface forms (e.g. a marker character
    # standing in for a leading space), this is not the same as looking the
    # token up in the proxy vocabulary - confirm that this is intended.
    proxy_tok_id = proxy_tokenizer.encode(tok, add_special_tokens=False)
    if not proxy_tok_id:
        # An empty encoding would yield a (0, dim) slice and corrupt the
        # row-per-token layout; fail at the offending token instead.
        raise ValueError(f"proxy tokenizer produced no ids for token {tok!r} (id {tok_id})")

    # List indexing keeps a leading axis: one row per proxy piece.
    proxy_tok_embedding = proxy_lm_head[proxy_tok_id]
    if len(proxy_tok_id) > 1:
        non_matched += 1
        non_matched_len.append(len(proxy_tok_id))

        # Average the piece embeddings, weighting each piece by its share of
        # the total character length of all pieces.
        proxy_pieces = proxy_tokenizer.convert_ids_to_tokens(proxy_tok_id)
        piece_lens = [len(piece) for piece in proxy_pieces]
        total_len = sum(piece_lens)
        piece_weights = torch.tensor(
            [piece_len / total_len for piece_len in piece_lens],
            dtype=proxy_tok_embedding.dtype,
            device=proxy_tok_embedding.device,
        )
        proxy_tok_embedding = (proxy_tok_embedding * piece_weights.unsqueeze(1)).sum(dim=0, keepdim=True)
    lm_head.append(proxy_tok_embedding)

# Every collected entry has a single leading row, so concatenating along
# dim 0 gives one row per target token.
lm_head = torch.cat(lm_head, dim=0)

# Report how many tokens needed more than one proxy piece. The max/avg lines
# are skipped when there were none, because max() of an empty list raises
# ValueError and the average would divide by zero.
print(f">>> {non_matched} tokens are not matched")
if non_matched_len:
    print(f">>> max token length: {max(non_matched_len)}")
    print(f">>> avg token length: {sum(non_matched_len) / len(non_matched_len)}")

# token frequency file for the training corpus
# Opened in a `with` block so the handle is closed deterministically, and
# decoded as UTF-8 explicitly rather than with the locale default.
with open('/workspace/codes/AlienLMv2/alien_tokenizer/token-freq/result/Meta-Llama-3-8B-Instruct/pro_tok_dict.json', 'r', encoding='utf-8') as freq_file:
    token_freq = json.load(freq_file)

# (token string, token id) pairs for the whole target vocabulary.
tokens = [(tokenizer.convert_ids_to_tokens(idx), idx) for idx in tokenizer.get_vocab().values()]

# Most frequent first; tokens missing from the frequency table count as 0.
# NOTE(review): assumes token_freq maps token strings to numeric counts - confirm
# against the script that produced the file.
tokens = sorted(tokens, key=lambda x: token_freq.get(x[0], 0), reverse=True)

added_vocab = tokenizer.get_added_vocab()
total_special_tokens = set(added_vocab)

# delete special tokens
tokens = [entry for entry in tokens if entry[0] not in total_special_tokens]

# Token ids of the surviving entries, in their current (sorted) order.
_kept_ids = [entry_id for _, entry_id in tokens]

# Gather the matching lm_head rows; this uses each token id as a row index.
embeddings = lm_head[_kept_ids]

# Two-way mapping between token id and row position in `embeddings`.
id_to_idx = {entry_id: row for row, entry_id in enumerate(_kept_ids)}
idx_to_id = dict(enumerate(_kept_ids))

# From here on every token is paired with its row position, not its token id.
tokens = [(text, id_to_idx[entry_id]) for text, entry_id in tokens]

def main():
    """Run the token matcher and write its matches to a tab-separated file.

    Builds an OptimizedTokenMatcher from the module-level ``embeddings``,
    ``tokens``, ``id_to_idx`` and ``idx_to_id``, calls ``find_matches`` and
    writes the first three fields of every returned match, one match per
    line, to ``matches-sim-and-diff.txt`` in the current working directory.
    """
    import faiss

    # Set the number of OpenMP threads FAISS may use.
    faiss.omp_set_num_threads(64)

    matcher = OptimizedTokenMatcher(
        embeddings=embeddings,
        tokens=tokens,
        id_to_idx=id_to_idx,
        idx_to_id=idx_to_id,
        batch_size=100,
        n_neighbors=50,
    )

    matches = matcher.find_matches(lev_weight=1, sim_weight=0.01)

    # save matches
    # The encoding is explicit because the written fields may contain
    # non-ASCII token text, which the locale default cannot always encode.
    with open('matches-sim-and-diff.txt', 'w', encoding='utf-8') as f:
        for match in matches:
            f.write(f"{match[0]}\t{match[1]}\t{match[2]}\n")


if __name__ == "__main__":
    main()