import os
from tqdm import tqdm
import pickle
import numpy as np
from PIL import Image
import torch
from transformers import AutoModel
import submitit
import shutil
import zipfile
from get_gen_embeddings import compute_embeddings

# Directory (relative to the current working directory) that holds the caption
# pickle full_cc12m_eval_clean.pkl, the eval_imgs_<i>.zip shards, and the
# data_embeddings/ output folder written by main().
DATA_PATH = "../../metadata/cc12m/"

# Batch size passed to compute_embeddings (2**10 == 1024).
batch_size = 2 ** 10


def copydata_eval(targeted_folder, num_shards=9, data_path=None,
                  keep_archives=False):
    """Stage the eval image shards into ``targeted_folder`` and unpack them.

    Each ``eval_imgs_<i>.zip`` (``i`` in ``range(num_shards)``) is copied from
    ``data_path`` into ``targeted_folder`` and extracted there. Unless
    ``keep_archives`` is true, the local zip copy is deleted right after it
    is extracted, so the staging directory ends up holding only the images.

    Args:
        targeted_folder: Existing directory to copy the shards into and
            extract them into.
        num_shards: Number of ``eval_imgs_<i>.zip`` shards to process
            (defaults to the previously hard-coded 9).
        data_path: Directory holding the shards; ``None`` means ``DATA_PATH``.
        keep_archives: If true, keep the copied zip files after extraction.

    Raises:
        FileNotFoundError: If a shard is missing from ``data_path``.
        zipfile.BadZipFile: If a shard is not a valid zip archive.
    """
    src_dir = DATA_PATH if data_path is None else data_path
    for i in range(num_shards):
        archive_name = f"eval_imgs_{i}.zip"
        # Copy to the local staging dir first so extraction reads from local
        # disk; shutil.copy returns the path of the new copy.
        local_archive = shutil.copy(
            os.path.join(src_dir, archive_name), targeted_folder
        )
        with zipfile.ZipFile(local_archive, "r") as zip_ref:
            zip_ref.extractall(targeted_folder)
        if not keep_archives:
            # The archive is no longer needed once extracted; drop it so the
            # scratch space is not used twice for the same data.
            os.remove(local_archive)


def _build_extractor(extractor_name):
    """Return the feature-extractor model to hand to ``compute_embeddings``.

    ``"dino"`` loads ``facebook/dinov2-base`` through ``AutoModel``;
    ``"inception"`` returns ``None`` (presumably ``compute_embeddings`` builds
    its own Inception model in that case — TODO confirm in
    get_gen_embeddings). Any other name raises ``NotImplementedError``.
    """
    if extractor_name == "dino":
        return AutoModel.from_pretrained('facebook/dinov2-base')
    if extractor_name == "inception":
        return None
    raise NotImplementedError(f"Unknown extractor: {extractor_name!r}")


def _load_chunk_images(folder, names):
    """Load ``{folder}/{name}.jpg`` for each name as a 256x256 RGB image.

    Images that fail to open or decode are reported and skipped, so the
    returned list can be shorter than ``names``.
    """
    images = []
    skipped = 0
    for name in tqdm(names):
        img_path = f"{folder}/{name}.jpg"
        try:
            # convert() loads the pixel data, so the returned image stays
            # valid after the file handle is closed by the with-block.
            with Image.open(img_path) as raw:
                img = raw.convert("RGB")
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            skipped += 1
            continue
        images.append(img.resize((256, 256)))
    if skipped:
        print(f"Skipped {skipped} of {len(names)} images")
    return images


def main(extractor_name, chunk, chunk_size):
    """Embed one chunk of the CC12M eval images and save them as ``.npy``.

    The caption pickle's key order defines the global image index. This call
    handles indices ``[chunk * chunk_size, (chunk + 1) * chunk_size)``,
    truncated at the end of the key list. The eval zips are staged into the
    per-job scratch directory ``/tmp/job{SLURM_JOB_ID}/``. The chunk's images
    are loaded, converted to RGB and resized to 256x256. The scratch
    directory is then removed. Finally ``compute_embeddings`` embeds the
    images on CUDA with ``normalize=True`` and the result is saved to
    ``{DATA_PATH}/data_embeddings/embedding_imgs_{extractor_name}_{chunk}.npy``.

    Images that fail to load are skipped, so the saved rows are NOT
    guaranteed to line up one-to-one with the chunk's indices.

    Args:
        extractor_name: ``"dino"`` or ``"inception"``.
        chunk: Zero-based chunk number.
        chunk_size: Number of images per chunk.

    Raises:
        NotImplementedError: For an unknown ``extractor_name``.
        KeyError: If ``SLURM_JOB_ID`` is not set in the environment.
    """
    feature_extractor = _build_extractor(extractor_name)

    # Project-internal pickle; only ever load trusted files with pickle.
    with open(f"{DATA_PATH}/full_cc12m_eval_clean.pkl", "rb") as f:
        caption_dict = pickle.load(f)
    image_names = list(caption_dict.keys())
    # Slicing clamps the last (partial) chunk instead of relying on an
    # IndexError to stop the loop.
    chunk_names = image_names[chunk * chunk_size:(chunk + 1) * chunk_size]

    targeted_folder = f"/tmp/job{os.environ['SLURM_JOB_ID']}/"
    os.makedirs(targeted_folder, exist_ok=True)
    try:
        copydata_eval(targeted_folder)
        images = _load_chunk_images(targeted_folder, chunk_names)
    finally:
        # The whole extracted eval set lives here; images are already in
        # memory, so free the scratch space even if staging failed.
        shutil.rmtree(targeted_folder, ignore_errors=True)

    loaded_embeddings = compute_embeddings(
        images,
        model=feature_extractor,
        transform=None,
        batch_size=batch_size,
        device=torch.device("cuda"),
        extractor_name=extractor_name,
        cache=None,
        normalize=True,
    )

    # Presumably (num_loaded_images, embedding_dim) — verify in
    # get_gen_embeddings.
    print(loaded_embeddings.shape)

    os.makedirs(f"{DATA_PATH}/data_embeddings", exist_ok=True)
    np.save(f"{DATA_PATH}/data_embeddings/"
            f"embedding_imgs_{extractor_name}_{chunk}.npy",
            loaded_embeddings)


if __name__ == "__main__":
    # Launch one SLURM array task per (extractor, chunk) pair through submitit.
    executor = submitit.AutoExecutor(folder="../../outputs/vendi_logs/")
    executor.update_parameters(
        timeout_min=60 * 12,
        mem_gb=40,
        name="embedding",
        slurm_array_parallelism=10,
        slurm_nodes=1,
        slurm_gpus_per_node=1,
        slurm_tasks_per_node=1,
        slurm_cpus_per_task=10,
        # NOTE(review): empty partition — set this to the target cluster's
        # partition before launching.
        slurm_partition="",
    )
    # The pickle is read here only to count images; main() re-reads the same
    # file, so both sides agree on the key order and hence on chunk bounds.
    with open(f"{DATA_PATH}/full_cc12m_eval_clean.pkl",
              "rb") as f:
        num_images = len(pickle.load(f))

    chunk_size = 50000
    # Ceiling division: number of chunk_size-sized chunks (0 if no images).
    num_chunks = (num_images - 1) // chunk_size + 1
    jobs = []
    with executor.batch():
        for extractor in ["dino", "inception"]:
            for chunk in range(num_chunks):
                jobs.append(executor.submit(main, extractor, chunk, chunk_size))
    # Jobs created inside executor.batch() are only submitted when the
    # context exits, so report them (with their job ids) afterwards.
    for job in jobs:
        print(job)
