from typing import List, Optional

import h5py
import torch
import tyro
from datasets import load_dataset
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.trainer import Trainer
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from torch.utils.data import Dataset, DataLoader

from utils.probe_models import ProbeModelWordLabelLightning


def name_of_run(
    model_name: str,
    task: str,
    hiddens: Optional[List[int]] = None,
):
    """Build a run name of the form ``<model>_<task>_<h1>_<h2>...``.

    Args:
        model_name: Model identifier; every ``/`` is replaced with ``_``.
        task: Task description, inserted verbatim.
        hiddens: Hidden-layer sizes appended to the name, joined by ``_``.
            ``None`` (the default) is treated as an empty list.

    Returns:
        The assembled name. With no hidden sizes the name ends in a
        trailing ``_`` because the hidden-size suffix is the empty string.
    """
    # A None sentinel replaces the former mutable `[]` default.
    hidden_sizes = [] if hiddens is None else hiddens
    safe_model_name = model_name.replace("/", "_")
    hidden_suffix = "_".join(str(size) for size in hidden_sizes)
    return f"{safe_model_name}_{task}_{hidden_suffix}"


class SST2Dataset(Dataset):
    """Map-style dataset pairing precomputed embeddings with their labels.

    Item ``i`` is the tuple ``(embeddings[i], labels[i])``. Despite the name,
    nothing in this class is specific to SST-2; it works for any two
    index-aligned sequences.
    """

    def __init__(self, embeddings, labels):
        """Store the two aligned sequences.

        Args:
            embeddings: Indexable, sized collection of feature vectors.
            labels: Indexable, sized collection of labels, one per embedding.

        Raises:
            ValueError: If ``embeddings`` and ``labels`` differ in length.
        """
        # Fail fast: __len__ is driven by the labels alone, so a mismatch
        # would otherwise surface as an IndexError mid-epoch or silently
        # drop samples.
        if len(embeddings) != len(labels):
            raise ValueError(
                "embeddings and labels must have the same length, got "
                f"{len(embeddings)} and {len(labels)}"
            )
        self.embeddings = embeddings
        self.labels = labels

    def __len__(self):
        """Return the number of samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(embedding, label)`` pair stored at ``idx``."""
        return self.embeddings[idx], self.labels[idx]


def _load_split_embeddings(file_name, dataset):
    """Read each split's embeddings from an HDF5 file and attach dataset labels.

    Args:
        file_name: Path to an HDF5 file whose top-level groups are split
            names, each holding an ``embeddings`` dataset.
        dataset: Mapping from split name to a table with a ``"label"`` column.

    Returns:
        Dict mapping split name to ``{"embeddings": ..., "labels": ...}``.
    """
    data = {}
    with h5py.File(file_name, "r") as h5f:
        print(h5f.keys())
        for split in h5f:
            # `[:]` copies the data into memory so it outlives the file handle.
            embeddings = h5f[f"{split}/embeddings"][:]
            print(split, embeddings.shape)
            data[split] = {
                "embeddings": embeddings,
                "labels": dataset[split]["label"],
            }
    return data


def main(
    train_file_name: str,
    test_file_name: str,
    model_name: str = "openai/clip-vit-large-patch14",
):
    """Train a probe on one embedding file and evaluate it on another.

    The "test" split of ``train_file_name`` is shuffled with a fixed seed and
    divided 80/20 into training and validation sets. The probe with the best
    ``val_f1`` is then evaluated on the full "test" split of
    ``test_file_name``, and the resulting scores are printed.

    Args:
        train_file_name: HDF5 file with the embeddings used for training and
            validation.
        test_file_name: HDF5 file with the embeddings used for the final test.
        model_name: Name of the embedding model; only used in the run name.
    """
    dataset = load_dataset("mattymchen/mr")

    # Load the train dataset (used for training and validation)
    train_data = _load_split_embeddings(train_file_name, dataset)

    # Only the "test" split is used, so train and validation sets are
    # carved out of it.
    shuffled_embeddings, shuffled_labels = shuffle(
        train_data["test"]["embeddings"],
        train_data["test"]["labels"],
        random_state=42,
    )

    split_idx = int(0.8 * len(shuffled_embeddings))
    train_embeddings = shuffled_embeddings[:split_idx]
    train_labels = shuffled_labels[:split_idx]
    val_embeddings = shuffled_embeddings[split_idx:]
    val_labels = shuffled_labels[split_idx:]

    print("Train Set:", len(train_embeddings), len(train_labels))
    print("Validation Set:", len(val_embeddings), len(val_labels))

    # Load the test dataset (used only for final testing)
    # NOTE(review): labels for both files come from the same `dataset` split,
    # so the test set covers the same examples as train + validation (encoded
    # by a different file). Confirm this overlap is intended.
    test_data = _load_split_embeddings(test_file_name, dataset)
    test_embeddings = test_data["test"]["embeddings"]
    test_labels = test_data["test"]["labels"]

    print("Test Set:", len(test_embeddings), len(test_labels))

    # Create DataLoaders for train, validation, and test sets.
    # Only the training loader shuffles; evaluation order is kept fixed so
    # validation and test runs are deterministic.
    train_dataloader = DataLoader(
        SST2Dataset(train_embeddings, train_labels), batch_size=64, shuffle=True
    )
    validation_dataloader = DataLoader(
        SST2Dataset(val_embeddings, val_labels), batch_size=64, shuffle=False
    )
    test_dataloader = DataLoader(
        SST2Dataset(test_embeddings, test_labels), batch_size=64, shuffle=False
    )

    # Train the model
    # The feature width is read from the data instead of being hard-coded,
    # so embedding files of other widths work as well.
    input_dim = int(train_embeddings.shape[-1])
    output_dim = 2
    hidden_dims = []  # no hidden layers
    non_linearity = "relu"

    model = ProbeModelWordLabelLightning(
        input_dim=input_dim,
        output_dim=output_dim,
        hidden_dims=hidden_dims,
        non_linearity=non_linearity,
    )

    # Keep the checkpoint with the highest validation F1, plus the last one.
    checkpoint_callback = ModelCheckpoint(
        monitor="val_f1",
        mode="max",
        save_top_k=1,
        save_last=True,
    )

    run_name = name_of_run(
        model_name=model_name,
        task=f"mr-openclip-large-{train_file_name}",
        hiddens=hidden_dims,
    )
    wandb_logger = WandbLogger(name=run_name, project="probing-rendered-text")

    trainer = Trainer(
        logger=wandb_logger,
        accelerator="gpu",
        devices=1,
        max_epochs=10,
        callbacks=[checkpoint_callback],
        enable_progress_bar=False,
    )
    trainer.fit(model, train_dataloader, validation_dataloader)

    print("Best Model Path:", checkpoint_callback.best_model_path)

    # Evaluate the model on the test set.
    # ckpt_path="best" makes the trainer restore the best checkpoint before
    # testing, so a separate load_from_checkpoint call is not needed.
    metrics = trainer.test(model, test_dataloader, ckpt_path="best")
    # NOTE(review): these keys assume the model logs its test-loop metrics
    # under a `val_` prefix -- confirm against ProbeModelWordLabelLightning.
    result = metrics[0]
    print(">" * 50)
    print("Trained On:", train_file_name)
    print("Tested On:", test_file_name)
    print(
        f"{result['val_f1']:.3f},{result['val_precision']:.3f},{result['val_recall']:.3f}"
    )
    print("<" * 50)

if __name__ == "__main__":
    # Hand `main` to tyro.cli so it can be run as a command-line script.
    tyro.cli(main)
