from datasets import load_dataset
from typing import List
import tyro
from tqdm import tqdm
import torch
import h5py

# Assume the utility functions are defined elsewhere or imported correctly
from utils.data_utils import render_text
from utils.model_init import init_subject_model
from utils.embeddings import get_image_embedding, get_text_embedding
from scripts.utils.convert_raw_to_embeddings import create_output_filename


def _open_growable_dataset(h5f, key, row_shape, dtype):
    """Return the dataset stored at ``key``, creating it empty and growable if absent.

    Args:
        h5f: Open, writable ``h5py.File`` to look ``key`` up in.
        key: Path of the dataset inside the file, e.g. ``"train/embeddings"``.
        row_shape: Shape of a single row; ``()`` for a 1-D dataset of scalars.
        dtype: Element type used when the dataset has to be created.

    Returns:
        The existing dataset if ``key`` is already present, otherwise a new
        dataset with zero rows and an unlimited first axis.
    """
    if key in h5f:
        return h5f[key]
    return h5f.create_dataset(
        key,
        shape=(0, *row_shape),
        maxshape=(None, *row_shape),
        dtype=dtype,
    )


def main(model_name: str = "openai/clip-vit-base-patch32", embedding_size: int = 768):
    """Embed every ``sst2`` sentence and append the vectors to an HDF5 file.

    For each of the ``train``, ``validation`` and ``test`` splits, the
    ``"sentence"`` field of every example is passed to ``get_text_embedding``
    and the result is appended to ``<split>/embeddings``; the example's
    position within the split is appended to ``<split>/indices``.

    The output file is ``./text_<model_name with "/" replaced by "_">.h5``.
    It is opened in append mode and existing datasets are reused, so running
    the script again for the same model appends a second copy of every row.

    Args:
        model_name: Model identifier handed to ``init_subject_model``; also
            determines the output file name.
        embedding_size: Width of each stored embedding row. Only used when the
            embeddings dataset is first created.
            NOTE(review): nothing here checks this against the size of the
            vectors the model actually returns - confirm it matches
            ``model_name`` before running.
    """
    dataset = load_dataset("sst2")
    model_dict = init_subject_model(model_name, "clip")

    # Name the file after the model that produced the embeddings. This used to
    # be hard-coded to the large-patch14 name regardless of ``model_name``.
    file_name = f"./text_{model_name.replace('/', '_')}.h5"
    print(file_name)

    with h5py.File(file_name, "a") as h5f:
        for split in ("train", "validation", "test"):
            embeddings_dset = _open_growable_dataset(
                h5f, f"{split}/embeddings", (embedding_size,), "float32"
            )
            indices_dset = _open_growable_dataset(
                h5f, f"{split}/indices", (), "i8"
            )

            examples = dataset[split]
            for idx, example in tqdm(enumerate(examples), total=len(examples)):
                text_embedding = get_text_embedding(model_dict, example["sentence"])

                # Grow one row at a time so an interrupted run never leaves
                # unfilled rows behind in the file.
                embeddings_dset.resize(
                    (embeddings_dset.shape[0] + 1, embeddings_dset.shape[1])
                )
                embeddings_dset[-1] = text_embedding

                indices_dset.resize((indices_dset.shape[0] + 1,))
                indices_dset[-1] = idx

if __name__ == "__main__":
    # Only when run as a script: hand `main` to tyro so that its parameters
    # (model_name, embedding_size) can be supplied on the command line.
    tyro.cli(main)
