import random
from collections import defaultdict
from typing import List, Optional

import h5py
import numpy as np
import torch
import tyro
from datasets import load_dataset
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.trainer import Trainer
from sklearn.utils import shuffle
from torch.utils.data import Dataset, DataLoader

from utils.probe_models import ProbeModelWordLabelLightning


def name_of_run(
    model_name: str,
    task: str,
    hiddens: Optional[List[int]] = None,
) -> str:
    """Build a run name of the form ``<model>_<task>_<h1>_<h2>...``.

    Args:
        model_name: Model identifier; every ``/`` is replaced by ``_``.
        task: Task label inserted after the model name.
        hiddens: Hidden-layer sizes appended to the name, joined by ``_``.
            ``None`` (the default) is treated as an empty list.

    Returns:
        The assembled run name. With no hidden sizes the name ends in a
        trailing underscore; this is kept so names stay identical to runs
        created earlier.
    """
    # A None sentinel avoids sharing one mutable default list across calls.
    hidden_sizes = [] if hiddens is None else hiddens
    safe_model_name = model_name.replace("/", "_")
    hidden_suffix = "_".join(str(size) for size in hidden_sizes)
    return f"{safe_model_name}_{task}_{hidden_suffix}"


class SST2FontsDataset(Dataset):
    """Map-style dataset pairing each embedding with the label at the same index."""

    def __init__(self, embeddings, labels):
        """Store the two parallel, index-aligned collections.

        Args:
            embeddings: Indexable collection of inputs.
            labels: Indexable, sized collection of targets; its length
                defines the length of the dataset.
        """
        self.labels = labels
        self.embeddings = embeddings

    def __len__(self):
        """Return the number of samples, taken from the labels."""
        n_samples = len(self.labels)
        return n_samples

    def __getitem__(self, idx):
        """Return the ``(embedding, label)`` pair stored at ``idx``."""
        embedding = self.embeddings[idx]
        label = self.labels[idx]
        return embedding, label


def _split_font_embeddings(file_name, train_fraction=0.8, seed=42):
    """Load one file's embeddings and split them into train/validation parts.

    Args:
        file_name: Path of an HDF5 file containing a ``train/embeddings`` dataset.
        train_fraction: Fraction of the shuffled rows assigned to the train part.
        seed: ``random_state`` passed to the shuffle so the split is reproducible.

    Returns:
        A ``(train_embeddings, val_embeddings)`` tuple of arrays.
    """
    with h5py.File(file_name, "r") as h5f:
        # ``[:]`` reads the dataset into memory, so it stays usable once the
        # file has been closed.
        embeddings = h5f["train/embeddings"][:]

    print(len(embeddings))

    shuffled = shuffle(embeddings, random_state=seed)
    split_idx = int(train_fraction * len(shuffled))
    return shuffled[:split_idx], shuffled[split_idx:]


def main(
    model_name: str = "openai/clip-vit-large-patch14",
):
    """Train a probe that predicts which font file an embedding came from.

    Each HDF5 file in the built-in list is treated as one class: its
    embeddings are shuffled, split 80/20 into train and validation parts and
    labelled with the file's position in the list. A
    ``ProbeModelWordLabelLightning`` model without hidden layers is then
    trained for up to 100 epochs on GPU, logging to Weights & Biases.

    Args:
        model_name: Name of the embedding model. It is only used to build the
            W&B run name.
    """
    file_names = [
        "openai_clip-vit-large-patch14_image_embeddings_CedarvilleCursive-Regular__sst2_.h5",
        "openai_clip-vit-large-patch14_image_embeddings_JustAnotherHand-Regular__sst2_.h5",
        "openai_clip-vit-large-patch14_image_embeddings_Roboto-Regular__sst2__.h5",
        "openai_clip-vit-large-patch14_image_embeddings_Tiny5-Regular__sst2_.h5",
    ]

    data = {
        "train": {"embeddings": [], "labels": []},
        "validation": {"embeddings": [], "labels": []},
    }
    for font_idx, file_name in enumerate(file_names):
        # Every file is shuffled with the same seed.
        # NOTE(review): this presumably keeps the split aligned across fonts
        # if all files hold the same rows in the same order; confirm.
        train_embeddings, val_embeddings = _split_font_embeddings(file_name)

        for split, embeddings in (
            ("train", train_embeddings),
            ("validation", val_embeddings),
        ):
            data[split]["embeddings"].append(embeddings)
            # The class label is the index of the font file in ``file_names``.
            data[split]["labels"].append(
                torch.full((len(embeddings),), font_idx, dtype=torch.long)
            )

    # Merge the per-font pieces into one tensor per split.
    for split in ("train", "validation"):
        data[split]["embeddings"] = torch.tensor(
            np.concatenate(data[split]["embeddings"])
        )
        data[split]["labels"] = torch.cat(data[split]["labels"])

        print(data[split]["labels"])

    # Build the train and validation dataloaders.
    train_dataset = SST2FontsDataset(
        data["train"]["embeddings"], data["train"]["labels"]
    )
    train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)

    validation_dataset = SST2FontsDataset(
        data["validation"]["embeddings"], data["validation"]["labels"]
    )
    # Validation batches are only evaluated, so their order is kept fixed.
    validation_dataloader = DataLoader(
        validation_dataset, batch_size=64, shuffle=False
    )

    # Derive the probe dimensions from the data so they cannot drift from the
    # embedding width or from the number of font files.
    input_dim = data["train"]["embeddings"].shape[-1]
    output_dim = len(file_names)
    hidden_dims = []
    non_linearity = "relu"

    model = ProbeModelWordLabelLightning(
        input_dim=input_dim,
        output_dim=output_dim,
        hidden_dims=hidden_dims,
        non_linearity=non_linearity,
    )

    run_name = name_of_run(
        model_name=model_name, task="sst2-openai-large-fonts-classification", hiddens=[]
    )
    wandb_logger = WandbLogger(name=run_name, project="probe_clip_compos_models")

    trainer = Trainer(
        logger=wandb_logger,
        accelerator="gpu",
        max_epochs=100,
    )
    trainer.fit(model, train_dataloader, validation_dataloader)


# Script entry point: run only when executed directly, handing `main` to
# tyro.cli so its parameters can be supplied from the command line.
if __name__ == "__main__":
    tyro.cli(main)
