from datasets import load_dataset
from typing import List
import tyro
from tqdm import tqdm
import torch
import h5py
import codecs
from huggingface_hub import login

import json

# open_clip supplies the CLIP text/image encoders; transformers (imported below) is still used for LLaMA.
import open_clip

# Assume the utility functions are defined elsewhere or imported correctly
from utils.data_utils import render_text, shuffle_text

# Project-local helpers: text rendering/shuffling, model init, embedding extraction, output naming.
from utils.data_utils import render_text, shuffle_text
from utils.model_init import init_subject_model
from utils.embeddings import get_image_embedding, get_text_embedding
from utils.convert_raw_to_embeddings import create_output_filename
from transformers import AutoTokenizer, AutoModelForCausalLM


def init_llama_model(model_name):
    """Load a Hugging Face causal LM plus its tokenizer, ready for inference.

    Args:
        model_name: Hub id or local path passed to ``from_pretrained``.

    Returns:
        A dict with keys ``"tokenizer"`` and ``"model"``. The model is put in
        eval mode; no device placement is done here (the CUDA move is disabled).
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    # get_llama_text_embedding tokenizes with padding=True, which needs a pad
    # token; reuse EOS for that purpose.
    tok.pad_token = tok.eos_token

    llm = AutoModelForCausalLM.from_pretrained(model_name)
    llm.eval()  # inference only: disables dropout etc.

    return dict(tokenizer=tok, model=llm)


@torch.no_grad()
def get_llama_text_embedding(model_dict, text: str):
    """Embed *text* as the last-layer hidden state of its final token.

    Args:
        model_dict: Dict with ``"tokenizer"`` and ``"model"`` entries, as built
            by ``init_llama_model``.
        text: A single string to embed.

    Returns:
        A tensor taken from ``hidden_states[-1][:, -1, :]`` with the batch
        dimension squeezed away, i.e. one vector of the model's hidden width
        for a single input string.
    """
    tokenizer = model_dict["tokenizer"]
    model = model_dict["model"]

    # Put the token tensors on whatever device the model lives on.
    encoded = tokenizer(text, padding=True, return_tensors="pt")
    encoded = encoded.to(model.device)

    result = model(**encoded, output_hidden_states=True)
    final_layer = result.hidden_states[-1]

    # Last-token pooling (no averaging over positions).
    return final_layer[:, -1, :].squeeze(0)


def init_subject_model_open_clip(model_name: str, pretrained: str = "openai"):
    """Build the open_clip ViT-bigG-14 model with laion2b_s39b_b160k weights.

    NOTE(review): ``model_name`` and ``pretrained`` are currently ignored; the
    architecture and checkpoint are hard-coded below. ``main`` passes the
    checkpoint tag as ``model_name``, so honoring these parameters needs a
    coordinated change at the call site.

    Returns:
        A dict with keys ``"model"``, ``"preprocess"`` (the evaluation image
        transform, i.e. the third value from ``create_model_and_transforms``)
        and ``"tokenizer"``.
    """
    arch = "ViT-bigG-14"
    clip_model, _train_transform, eval_transform = open_clip.create_model_and_transforms(
        arch, pretrained="laion2b_s39b_b160k"
    )
    clip_tokenizer = open_clip.get_tokenizer(arch)

    return {
        "model": clip_model,
        "preprocess": eval_transform,
        "tokenizer": clip_tokenizer,
    }


def create_output_filename(dataset_name: str, directory: str, model_name: str) -> str:
    """Return ``<directory>/<model_name with '/' -> '_'>_embeddings.h5``.

    NOTE(review): this local definition shadows the ``create_output_filename``
    imported from ``utils.convert_raw_to_embeddings``; ``dataset_name`` is
    accepted for signature compatibility but not used.
    """
    safe_model = "_".join(model_name.split("/"))
    return "{}/{}_embeddings.h5".format(directory, safe_model)


def rot_k(text, k):
    """Rotate the ASCII letters of *text* by *k* positions (Caesar cipher).

    Upper- and lowercase ASCII letters are shifted within their own alphabet,
    wrapping around, so ``k`` may be negative or >= 26. Every other character,
    including non-ASCII letters such as "é", is left unchanged. This matches
    ``codecs.encode(text, "rot_13")`` when ``k == 13``.

    Args:
        text: String to encode.
        k: Integer shift; applied modulo 26.

    Returns:
        The rotated string.
    """
    import string

    shift = k % 26
    lower, upper = string.ascii_lowercase, string.ascii_uppercase
    # One translation table replaces the per-character loop and the quadratic
    # string concatenation. Only ASCII letters are mapped, so non-ASCII letters
    # (for which str.isalpha() is also true) are no longer garbled into
    # unrelated ASCII letters.
    table = str.maketrans(
        lower + upper,
        lower[shift:] + lower[:shift] + upper[shift:] + upper[:shift],
    )
    return text.translate(table)


def _transform_text(text: str, *, jumble: bool, rot13: bool, rot9: bool) -> str:
    """Apply at most one text perturbation; precedence is jumble > rot13 > rot9."""
    if jumble:
        return shuffle_text(text)
    if rot13:
        return codecs.encode(text, "rot_13")
    if rot9:
        return rot_k(text, k=9)
    return text


def _create_growable_dataset(h5f, name: str, width: int):
    """Create an empty float32 dataset of shape (0, width) that can grow along axis 0."""
    return h5f.create_dataset(
        name,
        shape=(0, width),
        maxshape=(None, width),
        dtype="float32",
    )


def _append_row(dset, vector) -> None:
    """Grow *dset* by one row and write the tensor *vector* into that new last row."""
    row = dset.shape[0]
    dset.resize((row + 1, dset.shape[1]))
    # Flatten so a batch-of-one (1, D) tensor fits a (D,) row; float() converts
    # half/bfloat16 outputs (bfloat16 has no numpy dtype) to the dataset's float32.
    dset[row] = vector.detach().float().cpu().reshape(-1).numpy()


def main(
    model_name: str = "laion2b_s39b_b160k",
    embedding_size: int = 1280,  # CLIP embedding width; must match the open_clip model
    font_path: str = "fonts/Roboto-Regular.ttf",
    jumble: bool = False,
    rot13: bool = False,
    rot9: bool = False,
    max_captions: int = 1000,
):
    """Embed MSCOCO val2017 captions with CLIP (text + rendered image) and LLaMA.

    For each of the first ``max_captions`` captions in
    ``data/mscoco_captions_val2017.json``, the caption is optionally perturbed
    (jumble / rot13 / rot9; the first enabled flag wins) and rendered to an
    image with ``render_text``. Three rows are then appended to an HDF5 file:

    * ``clip_text_embeddings``: CLIP text embedding of the caption
    * ``clip_image_embeddings``: CLIP image embedding of the rendered caption
    * ``llama_text_embeddings``: LLaMA last-token hidden state of the caption

    Captions whose rendering overflows (``render_text`` returns ``ok=False``)
    are skipped. Every dataset is appended at its own current length, so all
    three stay the same length and row *i* of each refers to the same caption.

    Args:
        model_name: Forwarded to ``init_subject_model_open_clip`` and embedded
            in the output file name.
        embedding_size: Width of the CLIP text/image embedding datasets.
        font_path: Font used by ``render_text``; its stem goes in the file name.
        jumble: Shuffle caption text before embedding.
        rot13: ROT13-encode caption text (ignored if ``jumble``).
        rot9: Rotate ASCII letters by 9 (ignored if ``jumble`` or ``rot13``).
        max_captions: Number of leading captions to process.
    """
    with open("data/mscoco_captions_val2017.json", "r", encoding="utf-8") as f:
        data = json.load(f)

    all_captions = [annotation["caption"] for annotation in data["annotations"]]
    captions = all_captions[:max_captions]

    llama_model_dict = init_llama_model("meta-llama/Meta-Llama-3-8B")
    model_dict = init_subject_model_open_clip(model_name)
    model_dict["model"] = model_dict["model"].to("cuda")

    font_name = font_path.split("/")[-1].split(".")[0]
    file_name = f"new_laion2b_mscoco_llama_{model_name}_image_embeddings_{font_name}_{'jumbled' if jumble else ''}_sst2_{'rot13' if rot13 else ''}_{'rot9' if rot9 else ''}.h5".replace(
        "/", "_"
    )
    print(file_name)

    clip_embedding_size = embedding_size
    # Read the LLaMA width from its config instead of hard-coding it.
    llama_embedding_size = llama_model_dict["model"].config.hidden_size

    with h5py.File(file_name, "w") as h5f:
        text_embeddings_dset = _create_growable_dataset(
            h5f, "clip_text_embeddings", clip_embedding_size
        )
        image_embeddings_dset = _create_growable_dataset(
            h5f, "clip_image_embeddings", clip_embedding_size
        )
        llama_text_embeddings_dset = _create_growable_dataset(
            h5f, "llama_text_embeddings", llama_embedding_size
        )

        for caption in tqdm(captions, total=len(captions)):
            text = _transform_text(caption, jumble=jumble, rot13=rot13, rot9=rot9)

            image, ok = render_text(text, font_path=font_path)
            if not ok:
                print("OVERFLOWS", text, flush=True)
                continue

            text_embedding = get_text_embedding(model_dict, text)
            image_embedding = get_image_embedding(model_dict, image)
            llama_text_embedding = get_llama_text_embedding(llama_model_dict, text)

            # Append at each dataset's current length rather than at the loop
            # index, which runs ahead of the row count after any skipped caption.
            _append_row(text_embeddings_dset, text_embedding)
            _append_row(image_embeddings_dset, image_embedding)
            _append_row(llama_text_embeddings_dset, llama_text_embedding)


if __name__ == "__main__":
    # Let tyro build the command-line interface from main()'s signature, then run it.
    cli_main = main
    tyro.cli(cli_main)
