import json
import os
import sys
import torch
import numpy as np
from tqdm import tqdm
from transformers import AutoConfig
from collections import Counter
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from src.model_load import load_tokenizer, load_model_only
from src.transfer_matrix.common_vocabulary import CommonVocabulary


def compute_tokenizer_based_reverse_mapping(source_tokenizer, target_tokenizer, source_vocab_size, target_vocab_size,
                                            common_vocab_indices_source, common_vocab_indices_target, unused_cols=None,
                                            device="cuda", flag=False):
    """Build a sparse mapping by re-tokenizing source tokens, stored target-major.

    Each selected source token id is decoded to text with ``source_tokenizer`` and
    re-encoded with ``target_tokenizer`` (without special tokens). The source id is
    then linked with weight 1.0 to the FIRST id of that encoding. A pair is skipped when:

    * the encoding is empty, or
    * the stripped source text, or the stripped decoded first target piece, starts
      with a character matched by the mojibake pattern below.

    Entries are stored at ``(target_id, source_id)``. So the matrix has shape
    ``(target_vocab_size, source_vocab_size)``, the transpose orientation of
    ``compute_tokenizer_based_mapping``. If you call it with source/target swapped
    relative to a forward mapping, the result has the same orientation as that
    forward matrix.

    Args:
        source_tokenizer: Tokenizer whose ids are decoded. It must provide ``decode(list[int])``.
        target_tokenizer: Tokenizer used to re-encode. It must provide
            ``encode(str, add_special_tokens=False)`` and ``decode(list[int])``.
        source_vocab_size: Column count of the result.
        target_vocab_size: Row count of the result.
        common_vocab_indices_source: Source ids of shared tokens. These ids are excluded
            from re-tokenization when ``unused_cols`` is None.
        common_vocab_indices_target: Target ids parallel to ``common_vocab_indices_source``.
        unused_cols: Optional explicit list of source ids to process instead of
            "all ids not in common_vocab_indices_source".
        device: Torch device for intermediate tensors.
        flag: If True, the common (source, target) pairs are also added with weight 1.0.

    Returns:
        A CPU sparse COO float32 tensor of shape ``(target_vocab_size, source_vocab_size)``.

    Raises:
        ValueError: If ``flag`` is True and the two common index sequences differ in length.
    """
    # Only the first character of a token is tested.
    # NOTE(review): the first character class also contains ordinary accented Latin
    # letters (à-ÿ, etc.), so tokens that start with them are skipped as well.
    # Confirm this is intended for non-English vocabularies.
    mojibake_re = re.compile(
        r"[�ÃÂÊÐÊÌÍÎÏÐÑÒÓÔÕ×ØÙÚÛÜÝÞßàáâãäåæçèéêëìíîïðñòóôõöøùúûüýþÿ]"
        r"|[\u200B-\u200D\uFEFF]"
        r"|[\uE000-\uF8FF]"
    )

    def starts_with_mojibake(text):
        return bool(text) and mojibake_re.search(text[0]) is not None

    # Normalize list / ndarray / tensor inputs to plain Python ints.
    common_source = [int(i) for i in common_vocab_indices_source]
    common_target = [int(i) for i in common_vocab_indices_target]

    if unused_cols is not None:
        non_common_indices = [int(i) for i in unused_cols]
    else:
        mask = torch.ones(source_vocab_size, dtype=torch.bool, device=device)
        if common_source:
            mask[torch.as_tensor(common_source, dtype=torch.long, device=device)] = False
        non_common_indices = torch.arange(source_vocab_size, device=device)[mask].cpu().tolist()

    print(f"Computing tokenizer-based mapping for {len(non_common_indices)} non-common source tokens...")

    row_indices, col_indices, values = [], [], []
    if flag:
        if len(common_source) != len(common_target):
            raise ValueError(
                f"common index lists differ in length: {len(common_source)} source vs {len(common_target)} target"
            )
        row_indices.extend(common_source)
        col_indices.extend(common_target)
        values.extend([1.0] * len(common_source))

    num_skipped = 0
    for source_idx in tqdm(non_common_indices, desc="Processing non-common tokens"):
        source_token = source_tokenizer.decode([source_idx])
        target_tokenized = target_tokenizer.encode(source_token, add_special_tokens=False)
        if not target_tokenized:
            continue
        # Only the first target piece is used as the mapping target.
        target_idx = int(target_tokenized[0])
        target_token = target_tokenizer.decode([target_idx]).strip()
        if starts_with_mojibake(source_token.strip()) or starts_with_mojibake(target_token):
            num_skipped += 1
            continue
        row_indices.append(source_idx)
        col_indices.append(target_idx)
        values.append(1.0)

    print(f"Skipped {num_skipped} pairs with a mojibake-like leading character")
    print(f"{len(set(col_indices))} distinct target ids received at least one entry")

    # An explicit long dtype keeps an empty mapping valid: a bare torch.tensor([[], []])
    # would be float32, which sparse_coo_tensor rejects.
    sparse_matrix = torch.sparse_coo_tensor(
        indices=torch.tensor([col_indices, row_indices], dtype=torch.long, device=device),
        values=torch.tensor(values, dtype=torch.float32, device=device),
        size=(target_vocab_size, source_vocab_size),
    )
    return sparse_matrix.cpu()

from collections import defaultdict
from ftfy import fix_text
import unicodedata,re

def compute_tokenizer_based_mapping(source_tokenizer, target_tokenizer, source_vocab_size, target_vocab_size,
                                    common_vocab_indices_source, common_vocab_indices_target, device="cuda", flag=True):
    """Build a sparse source->target token mapping by re-tokenizing decoded source tokens.

    Every source id not listed in ``common_vocab_indices_source`` is decoded to text
    with ``source_tokenizer`` and re-encoded with ``target_tokenizer`` (without special
    tokens). The source id is then linked with weight 1.0 to the FIRST id of that
    encoding. A pair is skipped when:

    * the encoding is empty, or
    * the stripped source text, or the stripped decoded first target piece, starts
      with a character matched by the mojibake pattern below.

    Args:
        source_tokenizer: Tokenizer whose ids are decoded. It must provide ``decode(list[int])``.
        target_tokenizer: Tokenizer used to re-encode. It must provide
            ``encode(str, add_special_tokens=False)`` and ``decode(list[int])``.
        source_vocab_size: Row count of the result.
        target_vocab_size: Column count of the result.
        common_vocab_indices_source: Source ids of shared tokens. They are not re-tokenized.
        common_vocab_indices_target: Target ids parallel to ``common_vocab_indices_source``.
        device: Torch device for intermediate tensors.
        flag: If True, the common (source, target) pairs are also added with weight 1.0.

    Returns:
        A tuple ``(sparse_matrix, unused_cols)``:

        * ``sparse_matrix`` is a CPU sparse COO float32 tensor of shape
          ``(source_vocab_size, target_vocab_size)``.
        * ``unused_cols`` is the ascending list of target ids in
          ``range(target_vocab_size)`` that received no entry.

    Raises:
        ValueError: If ``flag`` is True and the two common index sequences differ in length.
    """
    # Only the first character of a token is tested.
    # NOTE(review): the first character class also contains ordinary accented Latin
    # letters (à-ÿ, etc.), so tokens that start with them are skipped as well.
    # Confirm this is intended for non-English vocabularies.
    mojibake_re = re.compile(
        r"[�ÃÂÊÐÊÌÍÎÏÐÑÒÓÔÕ×ØÙÚÛÜÝÞßàáâãäåæçèéêëìíîïðñòóôõöøùúûüýþÿ]"
        r"|[\u200B-\u200D\uFEFF]"
        r"|[\uE000-\uF8FF]"
    )

    def starts_with_mojibake(text):
        return bool(text) and mojibake_re.search(text[0]) is not None

    # Normalize list / ndarray / tensor inputs to plain Python ints.
    common_source = [int(i) for i in common_vocab_indices_source]
    common_target = [int(i) for i in common_vocab_indices_target]

    mask = torch.ones(source_vocab_size, dtype=torch.bool, device=device)
    if common_source:
        mask[torch.as_tensor(common_source, dtype=torch.long, device=device)] = False
    non_common_indices = torch.arange(source_vocab_size, device=device)[mask].cpu().tolist()

    print(f"Computing tokenizer-based mapping for {len(non_common_indices)} non-common source tokens...")

    row_indices, col_indices, values = [], [], []
    if flag:
        if len(common_source) != len(common_target):
            raise ValueError(
                f"common index lists differ in length: {len(common_source)} source vs {len(common_target)} target"
            )
        row_indices.extend(common_source)
        col_indices.extend(common_target)
        values.extend([1.0] * len(common_source))

    num_skipped = 0
    for source_idx in tqdm(non_common_indices, desc="Processing non-common tokens"):
        source_token = source_tokenizer.decode([source_idx])
        target_tokenized = target_tokenizer.encode(source_token, add_special_tokens=False)
        if not target_tokenized:
            continue
        # Only the first target piece is used as the mapping target.
        target_idx = int(target_tokenized[0])
        target_token = target_tokenizer.decode([target_idx]).strip()
        if starts_with_mojibake(source_token.strip()) or starts_with_mojibake(target_token):
            num_skipped += 1
            continue
        row_indices.append(source_idx)
        col_indices.append(target_idx)
        values.append(1.0)

    used_cols = set(col_indices)
    print(f"Skipped {num_skipped} pairs with a mojibake-like leading character")
    print(f"{len(used_cols)} distinct target columns received at least one entry")
    # Ascending order, unlike the arbitrary order of list(set(...)).
    unused_cols = [col for col in range(target_vocab_size) if col not in used_cols]

    # An explicit long dtype keeps an empty mapping valid: a bare torch.tensor([[], []])
    # would be float32, which sparse_coo_tensor rejects.
    sparse_matrix = torch.sparse_coo_tensor(
        indices=torch.tensor([row_indices, col_indices], dtype=torch.long, device=device),
        values=torch.tensor(values, dtype=torch.float32, device=device),
        size=(source_vocab_size, target_vocab_size),
    )
    return sparse_matrix.cpu(), unused_cols

if __name__ == "__main__":
    # Usage: <script> OUTPUT_ROOT TARGET_MODEL [SOURCE_MODEL ...]
    # The first model path is the target. Each later path is a source, and gets three
    # matrices saved under OUTPUT_ROOT/vis_tokenizer_<target name>/:
    #   - forward:  source -> target
    #   - reverse
    #   - combined: weighted sum of forward and reverse
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} OUTPUT_ROOT TARGET_MODEL [SOURCE_MODEL ...]")

    model_paths = sys.argv[2:]
    # normpath strips trailing separators so basename never returns "".
    model_names = [os.path.basename(os.path.normpath(model_path)) for model_path in model_paths]
    target_name = model_names[0]

    probability_transfer_matrix_save_path = os.path.join(sys.argv[1], "vis_tokenizer_" + target_name)
    print("probability_transfer_matrix_save_path:", probability_transfer_matrix_save_path)

    tokenizers = [load_tokenizer(model_path) for model_path in model_paths]
    # Matrix sizes follow each model config's vocab_size, which may differ from the
    # tokenizer's own vocabulary length.
    vocab_lengths = [
        AutoConfig.from_pretrained(model_path, trust_remote_code=True).vocab_size
        for model_path in model_paths
    ]
    print("Vocab sizes:", vocab_lengths)

    os.makedirs(probability_transfer_matrix_save_path, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Mixing weights for the forward and reverse matrices. A direction adds the
    # common-vocabulary pairs only when its weight is exactly 1.0.
    w_1, w_2 = 1.0, 0.5
    forward_flag = w_1 == 1.0
    reverse_flag = w_2 == 1.0

    target_tokenizer = tokenizers[0]
    target_vocab_size = vocab_lengths[0]
    target_vocab = target_tokenizer.get_vocab()

    for index, model_path in enumerate(model_paths[1:], start=1):
        source_tokenizer = tokenizers[index]
        source_vocab_size = vocab_lengths[index]
        source_name = model_names[index]

        # NOTE(review): the original script used source/target vocab sizes and
        # common_vocab_indices_* without ever defining them, so it always failed
        # with NameError.
        # Here the common pairs are the ids of token strings present verbatim in both
        # vocabularies, restricted to ids inside the matrix bounds.
        # `CommonVocabulary` is imported at the top of the file and may be the
        # intended source; confirm.
        source_vocab = source_tokenizer.get_vocab()
        common_pairs = [
            (source_vocab[token], target_vocab[token])
            for token in sorted(source_vocab.keys() & target_vocab.keys())
            if source_vocab[token] < source_vocab_size and target_vocab[token] < target_vocab_size
        ]
        common_vocab_indices_source = [src_id for src_id, _ in common_pairs]
        common_vocab_indices_target = [tgt_id for _, tgt_id in common_pairs]
        print(f"Common vocabulary size: {len(common_pairs)}")

        tokenizer_based_sparse_matrix, unused_cols = compute_tokenizer_based_mapping(
            source_tokenizer, target_tokenizer, source_vocab_size, target_vocab_size,
            common_vocab_indices_source, common_vocab_indices_target, device=device, flag=forward_flag,
        )
        # Reverse pass: the roles are swapped, so target-model tokens are re-tokenized by
        # the source tokenizer. The reverse function stores entries target-major, so
        # with swapped arguments the result is already
        # (source_vocab_size, target_vocab_size) and can be added to the forward matrix.
        # `unused_cols` is not passed, matching the original script, so the reverse
        # pass covers every non-common target token.
        tokenizer_based_reverse_sparse_matrix = compute_tokenizer_based_reverse_mapping(
            target_tokenizer, source_tokenizer, target_vocab_size, source_vocab_size,
            common_vocab_indices_target, common_vocab_indices_source, device=device, flag=reverse_flag,
        )

        tokenizer_based_sparse_matrix_final = (
            w_1 * tokenizer_based_sparse_matrix + w_2 * tokenizer_based_reverse_sparse_matrix
        )

        file_prefix = f"{source_name}_to_{target_name}_tokenizer"
        torch.save(tokenizer_based_sparse_matrix,
                   os.path.join(probability_transfer_matrix_save_path, f"{file_prefix}_forward_mapping.pth"))
        torch.save(tokenizer_based_reverse_sparse_matrix,
                   os.path.join(probability_transfer_matrix_save_path, f"{file_prefix}_reverse_mapping.pth"))
        torch.save(tokenizer_based_sparse_matrix_final,
                   os.path.join(probability_transfer_matrix_save_path,
                                f"{file_prefix}_combined_{w_1}_{w_2}_mapping.pth"))

        print(f"✅ Tokenizer-based mapping computed and saved for {model_paths[index]} -> {model_paths[0]}")

    print("🎉 All tokenizer-based mappings saved successfully!")
