import os
import sys
import argparse
import torch
from safetensors.torch import load_file, save_file

from clip.simple_tokenizer import SimpleTokenizer
from clip import clip

# "ViT-B/16"
# "RN50"
def load_clip_to_cpu_org(backbone_name="ViT-L/14"):
    """Build a CLIP model on the CPU from the downloaded checkpoint.

    The checkpoint URL is looked up in ``clip._MODELS`` and fetched with
    ``clip._download`` (using ``clip._MODEL_CACHE`` as its second argument).
    The file is first opened as a TorchScript archive; if that raises
    ``RuntimeError`` it is read with ``torch.load`` instead.  The resulting
    state dict is handed to ``clip.build_model``.

    Args:
        backbone_name: Key into ``clip._MODELS`` selecting the backbone.

    Returns:
        The model object returned by ``clip.build_model``.
    """
    url = clip._MODELS[backbone_name]
    checkpoint_path = clip._download(url, clip._MODEL_CACHE)

    weights = None
    try:
        # loading JIT archive
        model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
    except RuntimeError:
        # Not a JIT archive: read the file directly.
        weights = torch.load(checkpoint_path, map_location="cpu")

    if not weights:
        # No directly loaded weights, so take them from the JIT module.
        weights = model.state_dict()

    return clip.build_model(weights)

def load_clip_to_cpu(backbone_name="ViT-L/14"):
    """Load a CLIP model onto the CPU via ``clip.load`` and put it in eval mode.

    Args:
        backbone_name: Backbone identifier forwarded to ``clip.load``.

    Returns:
        The first element of the pair returned by ``clip.load``; the second
        element is discarded.
    """
    # Request the CPU device directly; only the model is needed here.
    clip_net, _unused = clip.load(backbone_name, device="cpu")
    clip_net.eval()
    return clip_net

# Command-line interface: the checkpoint path is required; the number of
# nearest words to report is optional and defaults to 10, so invoking the
# script with a single path argument keeps working.
parser = argparse.ArgumentParser(
    description="Show the vocabulary tokens closest to each learned context vector."
)
parser.add_argument("fpath", type=str, help="Path to the learned prompt")
parser.add_argument(
    "topk",
    type=int,
    nargs="?",
    default=10,
    help="Select top-k similar words (default: 10)",
)
args = parser.parse_args()

fpath = args.fpath
topk = args.topk

# Validate explicitly instead of using `assert`, which is stripped under -O.
if topk < 1:
    parser.error(f"topk must be a positive integer, got {topk}")
if not os.path.exists(fpath):
    parser.error(f"checkpoint file not found: {fpath}")

print(f"Return the top-{topk} matched words")

# Build the tokenizer and the CLIP model, then take the token-embedding table.
tokenizer = SimpleTokenizer()
clip_model = load_clip_to_cpu()
token_embedding = clip_model.token_embedding.weight
print(f"Size of token embedding: {token_embedding.shape}")

# Decode every row index of the embedding table to its vocabulary string,
# removing the "</w>" marker from the decoded text.
vocab_size = token_embedding.size(0)
words = [
    tokenizer.decoder[token_id].replace("</w>", "")
    for token_id in range(vocab_size)
]
print(words[49300:49406])

# Pair a slice of the vocabulary with the matching embedding rows.
# NOTE(review): the 512:49406 bounds are hard-coded and their rationale is not
# stated here; the dictionary is only consumed by the commented-out torch.save.
save_data = dict(
    words=words[512:49406],
    embeddings=token_embedding[512:49406, :],
)

# # Save to .pt file
# torch.save(save_data, 'scripts/interpret_prompts/clip_words.pt')
# Load the saved tensors and select the context tensor to analyse.
ckpt = load_file(fpath)
if not ckpt:
    # Fail with a clear message instead of a NameError on `ctx` further down.
    raise ValueError(f"No tensors found in checkpoint: {fpath}")

# When several tensors are stored, the last one is used; say so explicitly so
# the choice is visible in the output.
ctx_key = list(ckpt)[-1]
if len(ckpt) > 1:
    print(f"Checkpoint holds {len(ckpt)} tensors; using the last one: {ctx_key!r}")

ctx = ckpt[ctx_key].float()
print(f"Size of context: {ctx.shape}")

# For each context tensor, report the `topk` vocabulary tokens whose embeddings
# are closest to every context vector, together with the distances.
all_layer_ctx = [ctx]

# Detach the embedding table (it is a model parameter) and make sure it is
# float32 so it matches the float context; no gradients are needed here.
vocab_embedding = token_embedding.detach().float()

for layer_idx, layer_ctx in enumerate(all_layer_ctx, start=1):
    print("SHOWING RESULTS FOR CTX Vectors of Layer: ", layer_idx)
    if layer_ctx.dim() == 2:
        # Generic context
        with torch.no_grad():
            distance = torch.cdist(layer_ctx, vocab_embedding)
        print(f"Size of distance matrix: {distance.shape}")

        # topk with largest=False yields the k smallest distances per row in
        # ascending order along with their indices, so no full sort is needed.
        k = min(topk, distance.shape[1])
        nearest_dist, nearest_idx = torch.topk(distance, k, dim=1, largest=False)

        rows = zip(nearest_idx.tolist(), nearest_dist.tolist())
        for row_no, (token_ids, token_dists) in enumerate(rows, start=1):
            # Local names are distinct from the module-level `words` list so the
            # decoded vocabulary is not overwritten.
            nearest_words = [tokenizer.decoder[token_id] for token_id in token_ids]
            formatted_dist = [f"{value:.4f}" for value in token_dists]
            print(f"{row_no}: {nearest_words} {formatted_dist}")

    elif layer_ctx.dim() == 3:
        # Class-specific context
        raise NotImplementedError

    else:
        # Make the skip visible instead of silently printing nothing.
        print(f"Skipping context with unsupported number of dimensions: {layer_ctx.dim()}")

    print("##############################")
    print("##############################")