from datasets import load_dataset
from typing import List
import tyro
from tqdm import tqdm
from transformers import XLMRobertaTokenizer, XLMRobertaModel
import torch
import h5py
import codecs

# Project-local helpers: text rendering/shuffling, CLIP model setup, and embedding extraction.
from utils.data_utils import render_text, shuffle_text
from utils.model_init import init_subject_model
from utils.embeddings import get_image_embedding, get_text_embedding
from utils.convert_raw_to_embeddings import create_output_filename


# load roberta model
def init_roberta_model(model_name, device="cuda"):
    """Load an XLM-RoBERTa tokenizer and encoder for inference.

    Args:
        model_name: Hugging Face model id or local path passed to
            ``from_pretrained`` for both tokenizer and model.
        device: Device the encoder is moved to (defaults to ``"cuda"``).

    Returns:
        dict with keys ``"tokenizer"`` and ``"model"``; the model is in
        evaluation mode.
    """
    tokenizer = XLMRobertaTokenizer.from_pretrained(model_name)
    model = XLMRobertaModel.from_pretrained(model_name).to(device)
    model.eval()  # disable dropout so embeddings are deterministic
    return {"tokenizer": tokenizer, "model": model}


@torch.no_grad()
def get_roberta_text_embedding(model_dict, text: str):
    """Embed ``text`` as the attention-masked mean of the last hidden states.

    Args:
        model_dict: dict as returned by ``init_roberta_model``, holding a
            ``"tokenizer"`` and a ``"model"``.
        text: Input string (the tokenizer also accepts a list of strings).

    Returns:
        Tensor of shape ``(batch, hidden_size)``; ``(1, hidden_size)`` for a
        single string.
    """
    # truncation=True clips inputs to the tokenizer's maximum length instead
    # of failing inside the model on over-long sentences.
    inputs = model_dict["tokenizer"](
        text, padding=True, truncation=True, return_tensors="pt"
    ).to(model_dict["model"].device)
    last_hidden_state = model_dict["model"](**inputs).last_hidden_state
    # Average only over real tokens: padding positions (mask == 0) must not
    # dilute the embedding when several texts of different lengths are batched.
    mask = inputs["attention_mask"].unsqueeze(-1).to(last_hidden_state.dtype)
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1)
    return summed / counts


def rot_k(text, k):
    """Caesar-rotate the ASCII letters of ``text`` by ``k`` positions.

    Upper- and lower-case ASCII letters are rotated within their own case;
    every other character (digits, punctuation, whitespace, non-ASCII letters
    such as "é") is left unchanged, matching the behavior of the ``rot_13``
    codec. ``k`` may be any integer, including negative values.
    """
    import string

    shift = k % 26
    lower = string.ascii_lowercase
    upper = string.ascii_uppercase
    table = str.maketrans(
        lower + upper,
        lower[shift:] + lower[:shift] + upper[shift:] + upper[:shift],
    )
    return text.translate(table)


def _perturb_text(text, jumble, rot13, rot9):
    """Apply at most one perturbation to ``text``.

    Precedence is jumble > rot13 > rot9: if several flags are set, only the
    first enabled one is applied. With no flag set, ``text`` is returned as-is.
    """
    if jumble:
        return shuffle_text(text)
    if rot13:
        return codecs.encode(text, "rot_13")
    if rot9:
        return rot_k(text, k=9)
    return text


def _create_embedding_dset(h5f, key, num_rows, width):
    """Create a float32 ``(num_rows, width)`` dataset at ``key``, replacing any existing one.

    The file is opened in append mode, so replacing a stale dataset (e.g. from
    an interrupted run) lets the script be rerun instead of failing on
    "name already exists".
    """
    if key in h5f:
        del h5f[key]
    return h5f.create_dataset(
        key,
        shape=(num_rows, width),
        maxshape=(None, width),
        dtype="float32",
    )


def main(
    model_name: str = "openai/clip-vit-large-patch14",
    embedding_size: int = 768,
    font_path: str = "fonts/Roboto-Regular.ttf",
    jumble: bool = False,
    rot13: bool = False,
    rot9: bool = False,
):
    """Embed every SST-2 sentence with CLIP (text + rendered image) and XLM-RoBERTa.

    For each split ("train", "validation", "test"), each sentence is optionally
    perturbed, rendered to an image with ``render_text``, and embedded three
    ways. Results go to an HDF5 file (name printed at start) under per-split
    datasets ``{split}/clip_text_embeddings``, ``{split}/clip_image_embeddings``
    and ``{split}/roberta_text_embeddings``; row ``i`` holds example ``i``.

    Args:
        model_name: CLIP checkpoint passed to ``init_subject_model``.
        embedding_size: Width of the CLIP text/image embeddings (default 768).
        font_path: Font file used by ``render_text``; its stem is part of the
            output file name.
        jumble: Apply ``shuffle_text`` to each sentence.
        rot13: Apply ROT13 to each sentence.
        rot9: Apply a rotation by 9 (``rot_k``) to each sentence.
            Only the first enabled of jumble/rot13/rot9 is applied.
    """
    dataset = load_dataset("sst2")
    model_dict = init_subject_model(model_name, "clip", device="cuda")
    roberta_model_dict = init_roberta_model("FacebookAI/xlm-roberta-base")

    font_name = font_path.split("/")[-1].split(".")[0]
    file_name = f"roberta_base_{model_name}_image_embeddings_{font_name}_{'jumbled' if jumble else ''}_sst2_{'rot13' if rot13 else ''}_{'rot9' if rot9 else ''}.h5".replace(
        "/", "_"
    )
    print(file_name)

    clip_embedding_size = embedding_size
    # Take the RoBERTa width from the loaded encoder rather than hard-coding it.
    roberta_embedding_size = roberta_model_dict["model"].config.hidden_size

    with h5py.File(file_name, "a") as h5f:
        for split in ["train", "validation", "test"]:
            examples = dataset[split]
            num_examples = len(examples)

            # Keys are namespaced by split: un-prefixed keys made the second
            # split's create_dataset collide with the first split's datasets.
            text_embeddings_dset = _create_embedding_dset(
                h5f, f"{split}/clip_text_embeddings", num_examples, clip_embedding_size
            )
            image_embeddings_dset = _create_embedding_dset(
                h5f, f"{split}/clip_image_embeddings", num_examples, clip_embedding_size
            )
            roberta_text_embeddings_dset = _create_embedding_dset(
                h5f,
                f"{split}/roberta_text_embeddings",
                num_examples,
                roberta_embedding_size,
            )

            for idx, example in tqdm(enumerate(examples), total=num_examples):
                text = _perturb_text(example["sentence"], jumble, rot13, rot9)

                image, ok = render_text(text, font_path=font_path)
                if not ok:
                    # Still embed the returned image so that row idx stays
                    # aligned with example idx of the split.
                    print("OVERFLOWS", text, flush=True)

                text_embedding = get_text_embedding(model_dict, text)
                image_embedding = get_image_embedding(model_dict, image)
                roberta_text_embedding = get_roberta_text_embedding(
                    roberta_model_dict, text
                )

                text_embeddings_dset[idx] = text_embedding.cpu().numpy()
                image_embeddings_dset[idx] = image_embedding.cpu().numpy()
                roberta_text_embeddings_dset[idx] = roberta_text_embedding.cpu().numpy()


if __name__ == "__main__":
    # Expose main()'s keyword arguments as command-line flags via tyro.
    tyro.cli(main)
