import os, sys
import torch
from clip import clip
from clip.simple_tokenizer import SimpleTokenizer

# 1. Parse command-line arguments first so a bad invocation fails before the
#    (slow) model load instead of crashing afterwards with an IndexError.
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} TEXT_FILE OUTPUT_PREFIX")
text_file_path = sys.argv[1]  # input text file, read as UTF-8, one record per line
outfile = sys.argv[2]  # output location for the saved embeddings

# 2. Load the CLIP model and its BPE tokenizer.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("ViT-B/16", device=device)  # image preprocess transform unused
model.eval()
tokenizer = SimpleTokenizer()
# Only the input token-embedding table is used. It is a trainable Parameter, so
# detach it: otherwise every lookup/average below records autograd history and
# the saved tensors would still require grad.
# NOTE(review): on CUDA, clip.load may return fp16 weights, so the embeddings
# would be fp16 there — confirm if precision matters downstream.
embedding_matrix = model.token_embedding.weight.detach()  # [vocab_size, embed_dim]

# 3. Load every line of the input file into memory. Trailing newlines are kept
#    here; each line is stripped when it is parsed.
with open(text_file_path, "r", encoding="utf-8") as src:
    lines = list(src)

def _word_embedding(word, tokenizer, embedding_matrix):
    """Return the mean token embedding of ``word``, or None if it has no tokens.

    ``tokenizer.encode`` turns the word into token ids; the word vector is the
    average of those ids' rows in ``embedding_matrix`` (shape ``[embed_dim]``).
    If the tokenizer yields no ids for a word (presumably when its text
    cleaning removes the word entirely — NOTE(review): confirm), averaging an
    empty selection would produce NaNs, so None is returned instead.
    """
    token_ids = tokenizer.encode(word)
    if not token_ids:
        return None
    with torch.no_grad():  # pure table lookup; no autograd history needed
        return embedding_matrix[token_ids].mean(dim=0)


# Per-line results. Each appended entry is a [4, embed_dim] tensor (one row per
# group), so the stacked results are [num_kept_lines, 4, embed_dim].
all_avg_embeddings = []  # mean over all complete word positions of the line
all_top_embeddings = []  # the first complete word position of the line
skipped_lines = 0
for line in lines:
    # A line holds comma-separated groups of whitespace-separated words:
    # [[w1, w2, ...], [w1, w2], ...]; empty groups are dropped.
    token_lists = [
        [tok.lower() for tok in group.split()]
        for group in line.strip().split(",")
    ]
    token_lists = [toks for toks in token_lists if toks]

    if len(token_lists) < 4:
        skipped_lines += 1
        continue  # need at least four non-empty groups
    token_lists = token_lists[:4]  # use only the first four groups

    # zip stops at the shortest group, so it yields exactly the positions i at
    # which all four groups have a word.
    embeddings_by_position = []
    for words_at_i in zip(*token_lists):
        word_embeddings = [
            _word_embedding(word, tokenizer, embedding_matrix) for word in words_at_i
        ]
        if any(emb is None for emb in word_embeddings):
            # A word produced no tokens: treat the position as incomplete, just
            # like a missing word. "Top" is then the first complete position.
            continue
        embeddings_by_position.append(torch.stack(word_embeddings))  # [4, embed_dim]

    if not embeddings_by_position:
        skipped_lines += 1
        continue

    all_top_embeddings.append(embeddings_by_position[0])
    # Average the [4, embed_dim] slices over positions -> [4, embed_dim].
    all_avg_embeddings.append(torch.stack(embeddings_by_position).mean(dim=0))

if skipped_lines:
    # Skipped lines mean output rows no longer line up with input line numbers.
    print(f"Skipped {skipped_lines} of {len(lines)} input lines", file=sys.stderr)

# 4. Stack the per-line [4, embed_dim] tensors into [num_lines, 4, embed_dim].
if not all_avg_embeddings:
    # torch.stack([]) would raise an opaque RuntimeError; fail clearly instead.
    sys.exit(f"No usable lines found in {text_file_path}; nothing to save.")
top_embeddings_tensor = torch.stack(all_top_embeddings)  # [num_lines, 4, embed_dim]
print("Final top embeddings shape:", top_embeddings_tensor.shape)
avg_embeddings_tensor = torch.stack(all_avg_embeddings)  # [num_lines, 4, embed_dim]
print("Final average embeddings shape:", avg_embeddings_tensor.shape)

# 5. Save both tensors in one file. They are detached and moved to CPU so the
#    file carries no autograd state and can be loaded on a machine without a
#    GPU (CUDA tensors would otherwise need torch.load(..., map_location=...)).
#    ".pt" is appended only when the given path does not already end with it,
#    avoiding names like "out.pt.pt".
save_path = outfile if outfile.endswith(".pt") else f"{outfile}.pt"
torch.save({
    "avg_embeddings": avg_embeddings_tensor.detach().cpu(),
    "top_embeddings": top_embeddings_tensor.detach().cpu(),
}, save_path)
