import os
import pickle
from tqdm import tqdm
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast
import open_clip
import submitit


# OpenCLIP model selection: change these to use a different CLIP model.
# MODEL_NAME and PRETRAINED_DATASET are passed to
# open_clip.create_model_and_transforms() and are also used to build the
# name of the output .npy file.
MODEL_NAME = "ViT-SO400M-14-SigLIP-384"
PRETRAINED_DATASET = "webli"
# Forwarded as the force_quick_gelu argument when the model is created.
force_quick_gelu = False

# Number of DataLoader worker processes; also used as the number of CPUs
# requested per Slurm task when the job is submitted.
NUM_WORKERS = 10

# Root directory containing the cc12m annotation pickle and images; the
# embeddings are written below it as well.
DATA_PATH = "../metadata/"


def main(batch_size=1792):
    """Encode all CC12M images with an OpenCLIP model and save the features.

    The keys of the annotation pickle are used as image names. Each image is
    read from ``{DATA_PATH}/cc12m/images/<name>.jpg``, preprocessed with the
    model's own transform, encoded under float16 autocast and L2-normalized.
    The features are saved as one ``.npy`` array in
    ``{DATA_PATH}/cc12m/clip_embeddings/``; rows follow the key order of the
    annotation dictionary (the loader neither shuffles nor drops batches).

    Args:
        batch_size: Number of images per forward pass.

    Raises:
        RuntimeError: If CUDA is not available, or if no features were
            produced (e.g. the annotation dictionary is empty).
    """
    # Fail fast: check for a GPU before spending time loading the model.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. Please check your setup.")
    device = torch.device("cuda")

    annot_file = f"{DATA_PATH}/cc12m/full_dict_gemma3_eval_clean_4caps.pkl"

    output_folder = f"{DATA_PATH}/cc12m/clip_embeddings/"
    os.makedirs(output_folder, exist_ok=True)

    file_name = f"{MODEL_NAME}_{PRETRAINED_DATASET}_img.npy"
    output_file = os.path.join(output_folder, file_name)

    print("Loading OpenCLIP model...")
    model, _, preprocess_img = open_clip.create_model_and_transforms(
        MODEL_NAME,
        pretrained=PRETRAINED_DATASET,
        force_quick_gelu=force_quick_gelu,
    )
    print("Model loaded.")

    model.to(device)

    class ImageDataset(Dataset):
        """Yield preprocessed images, one per key of the annotation dict."""

        def __init__(self, annot_file, image_preprocess):
            # NOTE: pickle can execute arbitrary code while loading; only
            # point annot_file at a trusted, locally produced file.
            with open(annot_file, 'rb') as f:
                self.image_dict = pickle.load(f)
            self.image_names = list(self.image_dict.keys())
            self.image_preprocess = image_preprocess

        def __len__(self):
            return len(self.image_names)

        def __getitem__(self, idx):
            img_path = (
                f"{DATA_PATH}/cc12m/images/"
                f"{self.image_names[idx]}.jpg"
            )
            # Use a context manager so the file handle is released as soon
            # as the transform has produced its output.
            with Image.open(img_path) as img:
                return self.image_preprocess(img)

    print("Creating Dataset and DataLoader...")
    dataset = ImageDataset(annot_file, preprocess_img)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,  # keep output rows aligned with the key order
        num_workers=NUM_WORKERS,
        pin_memory=True,
        drop_last=False,
    )
    print("DataLoader created.")

    model.eval()

    print("Starting encoding...")
    all_image_features = []
    # torch.amp.autocast requires the device type as its first argument;
    # omitting it raises a TypeError before any batch is processed.
    with torch.no_grad(), autocast(device_type=device.type,
                                   dtype=torch.float16):
        for images in tqdm(dataloader, desc="Encoding Batches"):
            # Defensive guard: skip a batch that contains no elements.
            if images.numel() == 0:
                continue

            # Move the batch to the GPU.
            images = images.to(device, non_blocking=True)

            # Encode images
            image_features = model.encode_image(images)

            # Normalize features to unit L2 norm
            image_features /= image_features.norm(dim=-1, keepdim=True)

            # Collect features on the CPU to keep GPU memory free
            all_image_features.append(image_features.cpu())

    print("Encoding finished.")

    if not all_image_features:
        raise RuntimeError(
            f"No image features were produced; check {annot_file}."
        )

    print("Concatenating features...")
    image_features_np = torch.cat(all_image_features).numpy()

    print(f"Save image features ({image_features_np.shape}) to {output_file}")
    np.save(output_file, image_features_np)


if __name__ == "__main__":
    # Resources requested for the job: one node, one GPU, one task, with as
    # many CPUs as DataLoader workers and a 12-hour time limit.
    # NOTE(review): slurm_partition is an empty string here — presumably it
    # has to be set to a partition of the target cluster; confirm.
    job_options = dict(
        timeout_min=12*60,
        mem_gb=10,
        name="imgclipencode",
        slurm_array_parallelism=1,
        slurm_nodes=1,
        slurm_gpus_per_node=1,
        slurm_tasks_per_node=1,
        slurm_cpus_per_task=NUM_WORKERS,
        slurm_partition="",
    )

    # Job logs are written under this folder.
    executor = submitit.AutoExecutor(folder="../logs/clip_embeddings_logs")
    executor.update_parameters(**job_options)

    # Submit main() with its default arguments inside a batch context.
    with executor.batch():
        executor.submit(main)
