import os
import sys
import argparse
import torch
from safetensors.torch import load_file, save_file

from clip.simple_tokenizer import SimpleTokenizer
from clip import clip

# Alternative backbone names that can be passed instead of the default
# "ViT-L/14": "ViT-B/16", "RN50".
def load_clip_to_cpu_org(backbone_name="ViT-L/14"):
    """Build a CLIP model on the CPU from the checkpoint registered for *backbone_name*.

    The checkpoint URL is looked up in ``clip._MODELS`` and fetched with
    ``clip._download`` into ``clip._MODEL_CACHE``. The file is first opened
    with ``torch.jit.load``. If that raises ``RuntimeError``, it is read with
    ``torch.load`` instead. The resulting weights are passed to
    ``clip.build_model``, whose result is returned.
    """
    checkpoint_path = clip._download(clip._MODELS[backbone_name], clip._MODEL_CACHE)

    try:
        # TorchScript archive: keep the scripted module so its weights can be
        # extracted below.
        jit_model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
        weights = None
    except RuntimeError:
        # torch.jit.load could not open the file; read it as a plain checkpoint.
        weights = torch.load(checkpoint_path, map_location="cpu")

    # Fall back to the scripted module's weights when no state dict was loaded.
    if not weights:
        weights = jit_model.state_dict()
    return clip.build_model(weights)

def load_clip_to_cpu(backbone_name="ViT-L/14"):
    """Load *backbone_name* with ``clip.load`` on the CPU and return it in eval mode.

    The second value returned by ``clip.load`` (presumably the image
    preprocessing transform) is discarded.
    """
    clip_model, _preprocess = clip.load(backbone_name, device="cpu")
    clip_model.eval()
    return clip_model

# Settings are hard-coded; the script reads no command-line arguments.
# Number of nearest vocabulary tokens reported for each learned vector.
topk = 10

print(f"Return the top-{topk} matched words")

tokenizer = SimpleTokenizer()
clip_model = load_clip_to_cpu()

# CLIP's token-embedding table; row i is the embedding of token id i.
token_embedding = clip_model.token_embedding.weight
print(f"Size of token embedding: {token_embedding.shape}")

# Vocabulary aligned with the rows of token_embedding, with the "</w>"
# end-of-word marker stripped for readability.
words = [
    tokenizer.decoder[token_id].replace("</w>", "")
    for token_id in range(len(token_embedding))
]

# Vocabulary entries 512..49405 paired with their embedding rows.
# NOTE(review): the bounds presumably drop the byte-level tokens at the start
# and the special tokens at the end of CLIP's vocabulary — confirm.
save_data = dict(
    words=words[512:49406],
    embeddings=token_embedding[512:49406],
)

# The vocabulary table above can be cached for later reuse with:
# torch.save(save_data, 'scripts/interpret_prompts/clip_words.pt')

# Learned-embedding files compared for each identity index: one from the
# `variations_identity_slider` run and one from the `sd_prompts` run.
DIFF_EMBEDS_TEMPLATE = (
    "/data/user/diffusers/examples/textual_inversion/variations_identity_slider/"
    "identity-{idx}-prompts/diff_learned_embeds.safetensors"
)
SD_EMBEDS_TEMPLATE = (
    "/data/user/diffusers/examples/textual_inversion/sd_prompts/"
    "identity-{idx}-prompts/learned_embeds.safetensors"
)


def _load_learned_embedding(fpath):
    """Load the learned-embedding tensor stored in a safetensors file, as float32.

    If the file holds several tensors, the last one in iteration order is used.

    Raises:
        FileNotFoundError: if ``fpath`` does not exist.
        ValueError: if the file contains no tensors.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not os.path.exists(fpath):
        raise FileNotFoundError(f"Learned embedding file not found: {fpath}")
    ckpt = load_file(fpath)
    if not ckpt:
        # Fail loudly rather than silently reusing a tensor from another file.
        raise ValueError(f"No tensors found in {fpath}")
    return list(ckpt.values())[-1].float()


def _nearest_tokens(ctx, embedding, k):
    """Find the *k* vocabulary rows closest (Euclidean) to each context vector.

    Args:
        ctx: learned context, either one vector (1-D) or one vector per row (2-D).
        embedding: vocabulary embedding table, one row per token id.
        k: number of neighbours kept per context vector.

    Returns:
        ``(sorted_idxs, distance)``: the token ids of the k nearest rows for
        each context vector (nearest first), and the full distance matrix.

    Raises:
        NotImplementedError: for 3-D (class-specific) or any other-rank context.
    """
    if ctx.dim() == 1:
        # A single learned vector: treat it as one row.
        ctx = ctx.unsqueeze(0)
    if ctx.dim() != 2:
        raise NotImplementedError(
            f"Unsupported context rank {ctx.dim()}; expected a 1-D or 2-D tensor"
        )
    # The table is a module weight that may require grad; this is a pure
    # lookup, so skip building an autograd graph.
    with torch.no_grad():
        distance = torch.cdist(ctx, embedding)
    sorted_idxs = torch.argsort(distance, dim=1)[:, :k]
    return sorted_idxs, distance


def _top_words_for_file(fpath, embedding, k, decoder):
    """Print and return the *k* nearest vocabulary tokens for each vector in *fpath*.

    Returns:
        One list per learned vector, each holding k decoded token strings,
        nearest first.
    """
    sorted_idxs, distance = _nearest_tokens(_load_learned_embedding(fpath), embedding, k)
    rows = []
    for m, idxs in enumerate(sorted_idxs):
        row_words = [decoder[i.item()] for i in idxs]
        dist = [f"{distance[m, i].item():.4f}" for i in idxs]
        rows.append(row_words)
        print(f"{m+1}: {row_words} {dist}")
    print("##############################")
    print("##############################")
    return rows


total_common_top5 = 0
total_common_top10 = 0
num_indices = 66
start_idx = 3001
for identity_idx in range(start_idx, start_idx + num_indices):
    diff_rows = _top_words_for_file(
        DIFF_EMBEDS_TEMPLATE.format(idx=identity_idx), token_embedding, topk, tokenizer.decoder
    )
    sd_rows = _top_words_for_file(
        SD_EMBEDS_TEMPLATE.format(idx=identity_idx), token_embedding, topk, tokenizer.decoder
    )

    # Keep each file's rows separate so that the first learned vector of one
    # file is compared with the first learned vector of the other, even when
    # a file holds several vectors.
    diff_first, sd_first = diff_rows[0], sd_rows[0]
    total_common_top5 += len(set(diff_first[:5]) & set(sd_first[:5]))
    total_common_top10 += len(set(diff_first[:10]) & set(sd_first[:10]))

# Average overlap per identity index.
avg_common_top5 = total_common_top5 / num_indices
avg_common_top10 = total_common_top10 / num_indices

print(f"Average common words in top-5: {avg_common_top5:.2f}")
print(f"Average common words in top-10: {avg_common_top10:.2f}")