import os
import pickle
from tqdm import tqdm
import torch
import open_clip
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast
import submitit


# change your preferred clip model here
# OpenCLIP architecture name and pretrained-weights tag, passed to
# open_clip.create_model_and_transforms / open_clip.get_tokenizer in main().
MODEL_NAME = "ViT-SO400M-14-SigLIP-384"
PRETRAINED_DATASET = "webli"
# Forwarded unchanged as force_quick_gelu to create_model_and_transforms.
force_quick_gelu = False

# DataLoader worker processes; the SLURM job also requests this many CPUs.
NUM_WORKERS = 10

# Root folder containing the cc12m annotation pickle; embeddings are written
# under <DATA_PATH>/cc12m/clip_embeddings/.
DATA_PATH = "../metadata/"


def main(complexity, batch_size=6144):
    """Encode one caption column of the CC12M annotations with OpenCLIP.

    Loads the pickled annotation dict, tokenizes caption number ``complexity``
    of every entry, encodes the tokens with ``model.encode_text`` on CUDA under
    float16 autocast, L2-normalizes each embedding and writes the stacked
    result to ``<DATA_PATH>/cc12m/clip_embeddings/`` with ``np.save``.
    Row ``i`` of the saved array corresponds to the ``i``-th key of the
    annotation dict (no shuffling, no dropped batches).

    Args:
        complexity: Index into each entry's ``"caps"`` list selecting which
            caption to encode; also embedded in the output file name.
        batch_size: Number of captions per DataLoader batch.

    Raises:
        RuntimeError: If CUDA is unavailable, or if the annotation file has
            no entries to encode.
    """
    annot_file = f"{DATA_PATH}/cc12m/full_dict_gemma3_eval_clean_4caps.pkl"

    output_folder = f"{DATA_PATH}/cc12m/clip_embeddings/"
    os.makedirs(output_folder, exist_ok=True)

    file_name = f"{MODEL_NAME}_{PRETRAINED_DATASET}_text_c{complexity}.npy"
    output_file = os.path.join(output_folder, file_name)

    # Fail fast: check for a GPU before spending time loading model weights.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. Please check your setup.")
    device = torch.device("cuda")

    print("Loading OpenCLIP model...")
    model, _, _ = open_clip.create_model_and_transforms(
        MODEL_NAME,
        pretrained=PRETRAINED_DATASET,
        force_quick_gelu=force_quick_gelu,
    )
    tokenizer = open_clip.get_tokenizer(MODEL_NAME)
    print("Model loaded.")
    model.to(device)

    class TextDataset(Dataset):
        """Map-style dataset yielding the tokenized caption of each entry."""

        def __init__(self, annot_file, tokenizer, complexity):
            # NOTE: pickle.load can execute arbitrary code; the annotation
            # file must come from a trusted source.
            with open(annot_file, "rb") as f:
                self.image_dict = pickle.load(f)
            # Freeze the key order so dataset index i always maps to the
            # same entry (and thus to row i of the saved array).
            self.image_names = list(self.image_dict.keys())
            self.tokenizer = tokenizer
            self.complexity = complexity

        def __len__(self):
            return len(self.image_names)

        def __getitem__(self, idx):
            img_id = self.image_names[idx]
            text = self.image_dict[img_id]["caps"][self.complexity]
            # The tokenizer returns a batch even for one string (presumably
            # shape (1, context_length)); [0] keeps its single row.
            return self.tokenizer(text)[0]

    print("Creating Dataset and DataLoader...")
    dataset = TextDataset(annot_file, tokenizer, complexity)
    if len(dataset) == 0:
        # Without this, torch.cat([]) below fails with an opaque error.
        raise RuntimeError(f"No annotations found in {annot_file}; nothing to encode.")
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # output rows must follow annotation order
        num_workers=NUM_WORKERS,
        pin_memory=True,
        drop_last=False,  # keep the final partial batch: one row per entry
    )
    print("DataLoader created.")

    all_texts_features = []

    model.eval()

    print("Starting encoding...")
    # torch.amp.autocast requires device_type; omitting it raises TypeError.
    with torch.no_grad(), autocast(device_type=device.type, dtype=torch.float16):
        for texts in tqdm(dataloader, desc="Encoding Batches"):
            texts = texts.to(device, non_blocking=True)

            texts_features = model.encode_text(texts)

            # L2-normalize each embedding. Done in-place so the tensor keeps
            # encode_text's dtype (presumably float16 under autocast) instead
            # of being promoted by a float32 norm.
            texts_features /= texts_features.norm(dim=-1, keepdim=True)

            all_texts_features.append(texts_features.cpu())

    print("Encoding finished.")

    print("Concatenating features...")
    texts_features_np = torch.cat(all_texts_features).numpy()

    print(f"Save text features ({texts_features_np.shape}) to {output_file}")
    np.save(output_file, texts_features_np)

    print("Done.")


if __name__ == "__main__":
    # One job per caption complexity level (0-3), submitted together via
    # executor.batch(); slurm_array_parallelism=1 runs them one at a time.
    slurm_params = dict(
        timeout_min=12 * 60,
        mem_gb=10,
        name="txtclipencode",
        slurm_array_parallelism=1,
        slurm_nodes=1,
        slurm_gpus_per_node=1,
        slurm_tasks_per_node=1,
        # One CPU per DataLoader worker used inside main().
        slurm_cpus_per_task=NUM_WORKERS,
        # Empty placeholder; presumably meant to be set to your cluster's partition.
        slurm_partition="",
    )
    executor = submitit.AutoExecutor(folder="../logs/clip_embeddings_logs/")
    executor.update_parameters(**slurm_params)

    with executor.batch():
        for level in range(4):
            executor.submit(main, complexity=level)
