import random
from typing import List, Optional

import h5py
import torch
import tyro
from datasets import load_dataset
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.trainer import Trainer
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from torch.utils.data import Dataset, DataLoader

from utils.probe_models import ProbeModelWordLabelLightning


def name_of_run(
    model_name: str,
    task: str,
    hiddens: Optional[List[int]] = None,
):
    """Build a run name of the form ``<model>_<task>_<h1>_<h2>...``.

    Args:
        model_name: Model identifier; every ``/`` is replaced by ``_``.
        task: Task label inserted verbatim after the model name.
        hiddens: Hidden-layer sizes appended to the name, joined by ``_``.
            ``None`` (the default) is treated as an empty list.

    Returns:
        The assembled run name. When no hidden sizes are given the name
        ends with a trailing underscore (``<model>_<task>_``); this is kept
        so names stay identical to those produced by earlier runs.
    """
    # A ``None`` sentinel replaces the former ``[]`` default so that no list
    # object is shared between calls.
    if hiddens is None:
        hiddens = []
    sanitized_model = model_name.replace("/", "_")
    hidden_suffix = "_".join(str(size) for size in hiddens)
    return f"{sanitized_model}_{task}_{hidden_suffix}"


class SST2Dataset(Dataset):
    """Map-style dataset that pairs each embedding with its label by position."""

    def __init__(self, embeddings, labels):
        """Keep references to two index-aligned, indexable collections.

        Args:
            embeddings: Indexable collection of per-example embeddings.
            labels: Indexable, sized collection of per-example labels.
        """
        self.labels = labels
        self.embeddings = embeddings

    def __len__(self):
        """Return the number of examples, taken from the label collection."""
        n_examples = len(self.labels)
        return n_examples

    def __getitem__(self, idx):
        """Return the ``(embedding, label)`` pair stored at position ``idx``."""
        embedding = self.embeddings[idx]
        label = self.labels[idx]
        return embedding, label


def _load_split_data(h5_path, dataset):
    """Read every split's embeddings from an HDF5 file and attach SST-2 labels.

    Args:
        h5_path: Path of an HDF5 file whose top-level keys are split names,
            each holding an ``embeddings`` dataset.
        dataset: Mapping from split name to a split exposing a ``"label"``
            column (here, the object returned by ``load_dataset("sst2")``).

    Returns:
        Dict mapping each split name to
        ``{"embeddings": <array>, "labels": <label column>}``.
    """
    split_data = {}
    with h5py.File(h5_path, "r") as h5f:
        print(h5f.keys())
        for split in h5f.keys():
            # ``[:]`` reads the whole dataset into memory so it stays usable
            # after the file is closed.
            embeddings = h5f[f"{split}/embeddings"][:]
            print(split, embeddings.shape)
            split_data[split] = {
                "embeddings": embeddings,
                "labels": dataset[split]["label"],
            }
    return split_data


def _halve_validation(validation):
    """Split a validation entry 50/50 with a fixed seed.

    Returns:
        ``(first_embeddings, second_embeddings, first_labels, second_labels)``
        exactly as produced by ``train_test_split``.
    """
    return train_test_split(
        validation["embeddings"],
        validation["labels"],
        test_size=0.5,
        random_state=42,
    )


def main(
    file_name: str,
    test_file_name: str,
    model_name: str = "openai/clip-vit-large-patch14",
):
    """Train a probe on SST-2 embeddings from one file and test it on another.

    The training file's ``train`` split is used for fitting. Its
    ``validation`` split is cut in half with a fixed seed; the first half is
    used for validation. The test set is the second half of the
    ``validation`` split of ``test_file_name``, cut with the same seed and
    ratio.

    Args:
        file_name: HDF5 file providing the training and validation embeddings.
        test_file_name: HDF5 file providing the test embeddings.
        model_name: Model identifier; only used to build the W&B run name.
    """
    dataset = load_dataset("sst2")

    # Training / validation data.
    data = _load_split_data(file_name, dataset)
    train_embeddings = data["train"]["embeddings"]
    train_labels = data["train"]["labels"]

    val_embeddings, heldout_embeddings, val_labels, heldout_labels = _halve_validation(
        data["validation"]
    )
    print(len(heldout_embeddings), len(heldout_labels))

    # Test data: the held-out half of the *test file's* validation split.
    # NOTE(review): the same seed and ratio select the same row positions only
    # if both files store the validation examples in the same order and
    # number — confirm against how the files were produced.
    test_data = _load_split_data(test_file_name, dataset)
    _, test_embeddings, _, test_labels = _halve_validation(test_data["validation"])
    print(len(test_embeddings), len(test_labels))

    # Only the training loader is shuffled; the evaluation loaders iterate in
    # a fixed order.
    train_dataloader = DataLoader(
        SST2Dataset(train_embeddings, train_labels), batch_size=64, shuffle=True
    )
    validation_dataloader = DataLoader(
        SST2Dataset(val_embeddings, val_labels), batch_size=64, shuffle=False
    )
    test_dataloader = DataLoader(
        SST2Dataset(test_embeddings, test_labels), batch_size=64, shuffle=False
    )

    # Probe configuration: an empty ``hidden_dims`` list is passed to the model.
    # NOTE(review): input_dim presumably has to match the embedding width
    # stored in the HDF5 files — confirm.
    input_dim = 768
    output_dim = 2
    hidden_dims = []
    non_linearity = "relu"

    model = ProbeModelWordLabelLightning(
        input_dim=input_dim,
        output_dim=output_dim,
        hidden_dims=hidden_dims,
        non_linearity=non_linearity,
    )

    # Keep the checkpoint with the highest ``val_f1`` plus the last one.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_f1",
        mode="max",
        save_top_k=1,
        save_last=True,
    )

    run_name = name_of_run(
        model_name=model_name,
        task=f"sst2-openclip-large-sourcecodepro-{file_name}",
        hiddens=hidden_dims,
    )
    wandb_logger = WandbLogger(name=run_name, project="probing-rendered-text")

    trainer = Trainer(
        logger=wandb_logger,
        accelerator="gpu",
        devices=1,
        max_epochs=10,
        callbacks=[checkpoint_callback],
        enable_progress_bar=False,
    )
    trainer.fit(model, train_dataloader, validation_dataloader)

    best_model_path = checkpoint_callback.best_model_path
    print(best_model_path)
    model = ProbeModelWordLabelLightning.load_from_checkpoint(best_model_path)

    # Evaluate the best checkpoint on the test loader.
    metrics = trainer.test(model, test_dataloader, ckpt_path="best")
    print(">" * 50)
    print("Trained On:", file_name)
    print("Tested On:", test_file_name)
    # The test results are read under the ``val_*`` keys.
    print(
        f"{metrics[0]['val_f1']:.3f},{metrics[0]['val_precision']:.3f},{metrics[0]['val_recall']:.3f}"
    )
    print("<" * 50)


# Script entry point: hand ``main`` to tyro, which runs it from the command line.
if __name__ == "__main__":
    tyro.cli(main)
