import clip
import torch
import numpy as np
import pandas as pd
from models import CLIPModel

# Prompt templates, each a one-argument callable mapping a noun to a sentence
# that is later tokenized and encoded by CLIP. `get_prompt` selects one by index.
# NOTE(review): these appear to be the 7-template ImageNet subset from OpenAI's
# CLIP prompt-engineering notebook; wording (including "a origami") is kept
# verbatim so the resulting embeddings stay comparable — confirm before editing.
SIMPLE_IMAGENET_TEMPLATES = (
    lambda c: "itap of a {}.".format(c),
    lambda c: "a bad photo of the {}.".format(c),
    lambda c: "a origami {}.".format(c),
    lambda c: "a photo of the large {}.".format(c),
    lambda c: "a {} in a video game.".format(c),
    lambda c: "art of the {}.".format(c),
    lambda c: "a photo of the small {}.".format(c),
)

def get_prompt(words, index, device="cuda"):
    """Render `words` through one prompt template and tokenize them for CLIP.

    Args:
        words: Iterable of nouns; each one is passed to the chosen template.
        index: Position of the template inside ``SIMPLE_IMAGENET_TEMPLATES``.
        device: Target device for the returned token tensor.

    Returns:
        ``clip.tokenize(prompts, truncate=True)`` moved to ``device``; with
        ``truncate=True`` over-long prompts are cut rather than rejected.
    """
    sentences = [
        SIMPLE_IMAGENET_TEMPLATES[index](noun)
        for noun in words
    ]
    tokens = clip.tokenize(sentences, truncate=True)
    return tokens.to(device)

def _load_selected_nouns(dataset):
    """Return the WordNet nouns whose row indices are listed for ``dataset``.

    Reads the first column of ``./data/WordNetNouns.csv`` and keeps the rows
    whose integer indices appear in ``./data/<dataset>_nouns_list.txt``.
    """
    nouns = pd.read_csv("./data/WordNetNouns.csv", header=0).values[:, 0]
    # ndmin=1: a list file holding a single index otherwise loads as a 0-d
    # array, and indexing with it yields one string instead of an array, so
    # len() would count characters and batches would slice characters.
    selected_idx = np.loadtxt(
        f"./data/{dataset}_nouns_list.txt", dtype=int, ndmin=1
    )
    return nouns[selected_idx]


def _encode_nouns(model, nouns, template_idx, batch_size, device):
    """Encode ``nouns`` with one prompt template in batches of ``batch_size``.

    Returns the per-batch ``model.encode_text`` outputs (moved to CPU as NumPy)
    stacked row-wise. NOTE(review): presumably one row per noun with the
    model's embedding width, and the dtype follows whatever CLIPModel emits
    on ``device`` — confirm against CLIPModel.
    """
    nouns_num = len(nouns)
    num_batches = (nouns_num + batch_size - 1) // batch_size  # ceil division
    features = []
    # Inference only: no autograd graph is needed for any batch.
    with torch.no_grad():
        for i in range(num_batches):
            start = i * batch_size
            end = min(start + batch_size, nouns_num)
            text_tokens = get_prompt(nouns[start:end], template_idx, device=device)
            features.append(model.encode_text(text_tokens).cpu().numpy())
            if (i + 1) % 10 == 0 or end == nouns_num:
                print(f"  Processed {end}/{nouns_num}")
    return np.vstack(features)


def main(dataset="CIFAR-10", batch_size=2048, model_name="ViT-B/32"):
    """Compute and save CLIP text embeddings of the selected nouns per template.

    For every template in ``SIMPLE_IMAGENET_TEMPLATES`` this writes
    ``./data/<dataset>_nouns_embedding_prompt_<idx>_selected.npy``.

    Args:
        dataset: Prefix of the noun-index list and of the output files.
        batch_size: Number of prompts encoded per forward pass (>= 1).
        model_name: Passed to ``CLIPModel(model_name=...)``.

    Raises:
        ValueError: If ``batch_size`` < 1 or the selection yields no nouns.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    nouns = _load_selected_nouns(dataset)
    nouns_num = len(nouns)
    print(f"Loaded {nouns_num} selected nouns.")
    if nouns_num == 0:
        # Without this, np.vstack([]) would fail later with an opaque message.
        raise ValueError(
            f"No nouns selected: ./data/{dataset}_nouns_list.txt contains no indices."
        )

    model = CLIPModel(model_name=model_name).to(device)
    model.eval()

    for idx in range(len(SIMPLE_IMAGENET_TEMPLATES)):
        print(f"Inferring features for prompt #{idx}")
        features = _encode_nouns(model, nouns, idx, batch_size, device)
        print(f"  Prompt {idx} feature shape: {features.shape}")
        out_file = f"./data/{dataset}_nouns_embedding_prompt_{idx}_selected.npy"
        np.save(out_file, features)
        print(f"  Saved to {out_file}")

if __name__ == "__main__":
    # Run the embedding export only when executed as a script, not on import.
    main()
