import random
from typing import List, Optional

import h5py
import torch
import tyro
from datasets import load_dataset
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.trainer import Trainer
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from torch.utils.data import Dataset, DataLoader

from utils.probe_models import ProbeModelWordLabelLightning


def name_of_run(
    model_name: str,
    task: str,
    hiddens: Optional[List[int]] = None,
):
    """Build a run name from the model name, the task and the hidden sizes.

    Args:
        model_name: Model identifier; every ``/`` is replaced by ``_`` so the
            result contains no path separators coming from the model name.
        task: Task label inserted verbatim after the model name.
        hiddens: Hidden-layer sizes appended as an ``_``-joined suffix.
            ``None`` (the default) is treated like an empty list.

    Returns:
        ``"<model>_<task>_<h1>_<h2>..."``. With no hidden sizes the name ends
        in a trailing underscore (``"<model>_<task>_"``).
    """
    # A ``None`` default avoids sharing one mutable list between calls.
    sizes = [] if hiddens is None else hiddens
    safe_model_name = model_name.replace("/", "_")
    suffix = "_".join(str(size) for size in sizes)
    return f"{safe_model_name}_{task}_{suffix}"


class SST2Dataset(Dataset):
    """Dataset that pairs an indexable of embeddings with an indexable of labels.

    Both containers are stored exactly as passed in and are indexed with the
    same position, so item ``i`` is ``(embeddings[i], labels[i])``.
    """

    def __init__(self, embeddings, labels):
        """Keep references to the embeddings and their labels (no copy is made)."""
        self.embeddings, self.labels = embeddings, labels

    def __len__(self):
        """Return the number of samples, taken from the label container."""
        sample_count = len(self.labels)
        return sample_count

    def __getitem__(self, idx):
        """Return the ``(embedding, label)`` pair stored at position ``idx``."""
        embedding = self.embeddings[idx]
        label = self.labels[idx]
        return embedding, label


def main(
    file_name: str,
    model_name: str = "openai/clip-vit-large-patch14",
):
    """Train and evaluate a probe on precomputed embeddings for ``mattymchen/mr``.

    Steps:
      1. Load the ``mattymchen/mr`` dataset and, for every top-level group in
         the HDF5 file, the ``<split>/embeddings`` array plus that split's
         ``label`` column.
      2. Shuffle the ``test`` split with a fixed seed and cut it 80/10/10 into
         train / validation / test.
      3. Fit a ``ProbeModelWordLabelLightning`` probe for up to 10 epochs on a
         GPU, logging to Weights & Biases and checkpointing on ``val_f1``.
      4. Test the best checkpoint and print F1, precision and recall.

    Args:
        file_name: Path of the HDF5 file holding the embeddings. It is also
            embedded verbatim in the run name.
        model_name: Name of the embedding model; only used for the run name.
    """
    dataset = load_dataset("mattymchen/mr")

    # Pair each embedding array in the HDF5 file with the labels of the
    # dataset split that has the same name.
    data = {}
    with h5py.File(file_name, "r") as h5f:
        print(h5f.keys())
        for split in h5f.keys():
            embeddings = h5f[f"{split}/embeddings"][:]
            print(split, embeddings.shape)
            data[split] = {
                "embeddings": embeddings,
                "labels": dataset[split]["label"],
            }

    # Only the "test" split is used; it is re-split into train/val/test below.
    all_embeddings = data["test"]["embeddings"]
    all_labels = data["test"]["labels"]
    print(len(all_embeddings), len(all_labels))

    # Fixed seed so the split is reproducible across runs.
    shuffled_embeddings, shuffled_labels = shuffle(
        all_embeddings, all_labels, random_state=42
    )

    # First 80% -> train; the remaining 20% is halved into validation and test.
    split_idx = int(0.8 * len(shuffled_embeddings))
    train_embeddings = shuffled_embeddings[:split_idx]
    train_labels = shuffled_labels[:split_idx]

    holdout_embeddings = shuffled_embeddings[split_idx:]
    holdout_labels = shuffled_labels[split_idx:]

    split_idx2 = int(0.5 * len(holdout_embeddings))
    val_embeddings = holdout_embeddings[:split_idx2]
    val_labels = holdout_labels[:split_idx2]
    test_embeddings = holdout_embeddings[split_idx2:]
    test_labels = holdout_labels[split_idx2:]

    print(len(train_embeddings), len(train_labels))
    print(len(val_embeddings), len(val_labels))
    print(len(test_embeddings), len(test_labels))

    # Only the training loader shuffles. Evaluation loaders keep a fixed order
    # so validation/test batching is deterministic (Lightning warns when a
    # val/test dataloader has shuffling enabled).
    train_dataset = SST2Dataset(train_embeddings, train_labels)
    train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)

    validation_dataset = SST2Dataset(val_embeddings, val_labels)
    validation_dataloader = DataLoader(
        validation_dataset, batch_size=64, shuffle=False
    )

    test_dataset = SST2Dataset(test_embeddings, test_labels)
    test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    # Probe configuration: no hidden layers, two output classes.
    # NOTE(review): input_dim is hard-coded; presumably it must equal the
    # width of the embeddings stored in `file_name` — confirm.
    input_dim = 1280
    output_dim = 2
    hidden_dims = []
    non_linearity = "relu"

    model = ProbeModelWordLabelLightning(
        input_dim=input_dim,
        output_dim=output_dim,
        hidden_dims=hidden_dims,
        non_linearity=non_linearity,
    )

    # Keep the checkpoint with the highest validation F1, plus the last one.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_f1",
        mode="max",
        save_top_k=1,
        save_last=True,
    )

    run_name = name_of_run(
        model_name=model_name,
        task=f"sst2-openclip-large-sourcecodepro-{file_name}",
        hiddens=[],
    )
    wandb_logger = WandbLogger(name=run_name, project="probing-rendered-text")

    trainer = Trainer(
        logger=wandb_logger,
        accelerator="gpu",
        max_epochs=10,
        callbacks=[checkpoint_callback],
    )
    trainer.fit(model, train_dataloader, validation_dataloader)

    # Reload the best checkpoint recorded by the checkpoint callback.
    best_model_path = checkpoint_callback.best_model_path
    print(best_model_path)
    model = ProbeModelWordLabelLightning.load_from_checkpoint(best_model_path)

    # Evaluate on the held-out test portion using the best checkpoint.
    metrics = trainer.test(model, test_dataloader, ckpt_path="best")
    print(metrics)
    # NOTE(review): reads `val_*` keys from the test results; this assumes the
    # probe logs its test metrics under those names — confirm.
    print(
        f"{metrics[0]['val_f1']:.3f},{metrics[0]['val_precision']:.3f},{metrics[0]['val_recall']:.3f}"
    )


# Script entry point: hand `main` to tyro.cli so its parameters
# (`file_name`, `model_name`) are supplied from the command line.
if __name__ == "__main__":
    tyro.cli(main)
