from collections import namedtuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from datasets import load_dataset
from persim import plot_diagrams
from prodigyopt import Prodigy
from ripser import ripser
from sentence_transformers import SentenceTransformer
from torch.utils.data import DataLoader

from src.baseline_sep_metrics import (
    balanced_accuracy_index,
    kmeans_ch_score,
    roc_auc_index,
    thornton_separability_index,
)
from src.gpu_utils import accuracy_gpu_safe, memory_stats, roc_auc_gpu_safe
from src.st_classifier import STClassifier


# Result container returned by ``run_exp``. Defined once at module level (rather than
# per call) so that results are picklable and ``isinstance`` checks work across runs.
BinaryExpResults = namedtuple(
    "BinaryExpResults",
    [
        "scores_df",
        "roc_auc_score_df",
        "accuracy_score_df",
        "h0",
        "stacked_embeddings",
        "original_data_to_embed",
    ],
)


def _predict(classifier, dataloader):
    """Concatenate ``classifier``'s outputs over every batch of ``dataloader`` (dim 0)."""
    return torch.cat([classifier(batch["text"]) for batch in dataloader], dim=0)


def _embed(classifier, dataloader):
    """
    Encode every batch of ``dataloader`` with ``classifier.encode`` and move it to CPU.

    The classifier is switched to eval mode first so that dropout does not perturb
    the embeddings; callers that keep training must call ``classifier.train()`` again.
    """
    classifier.eval()
    with torch.no_grad():
        return torch.cat(
            [classifier.encode(batch["text"]) for batch in dataloader],
            dim=0,
        ).to("cpu")


def _report_performance(
    classifier,
    test_dataloader,
    test_labels,
    train_dataloader,
    train_labels,
    **auc_kwargs,
):
    """
    Print test AUC, test accuracy and train accuracy of ``classifier`` (eval mode).

    :param auc_kwargs: extra keyword arguments forwarded to ``roc_auc_gpu_safe``.
    """
    classifier.eval()
    with torch.no_grad():
        test_scores = _predict(classifier, test_dataloader)
        print(
            f"auc on test_dataset = {roc_auc_gpu_safe(test_labels, test_scores, **auc_kwargs):1.2f}"
        )
        print(
            f"accuracy on the test_dataset = {accuracy_gpu_safe(test_labels, test_scores): 1.2f}"
        )

        train_scores = _predict(classifier, train_dataloader)
        print(
            f"accuracy on the train_dataset = {accuracy_gpu_safe(train_labels, train_scores): 1.2f}"
        )


def _finite_h0_deaths(X, normalise_h0, ax=None):
    """
    Return the finite H0 death times of the Vietoris-Rips persistence of ``X``.

    :param X: point cloud passed straight to ``ripser``.
    :param normalise_h0: if True, divide all H0 death times by the largest finite one
        (skipped when there is no positive finite death, which would divide by zero).
    :param ax: optional matplotlib Axes on which to draw the H0 diagram.
    :returns: 1-D numpy array of finite H0 death times.
    """
    diagrams = ripser(X)["dgms"]
    h0_diag = diagrams[0]

    # Mask out the essential class(es) explicitly instead of relying on the infinite
    # point being the last row of ripser's output.
    finite = np.isfinite(h0_diag[:, 1])

    if normalise_h0 and finite.any():
        max_death = h0_diag[finite, 1].max()
        if max_death > 0:
            h0_diag[:, 1] /= max_death  # inf / x stays inf

    if ax is not None:
        plot_diagrams(diagrams, plot_only=[0], ax=ax)

    return h0_diag[finite, 1]


def _plot_scores(scores_df, roc_auc_score_df, accuracy_score_df, n_trials):
    """Plot every per-epoch score on one figure and show all open figures."""
    sns.set_palette("bright")

    scores_df.plot(linestyle="-", marker="o", markersize=5)
    sns.lineplot(
        data=roc_auc_score_df,
        x="epoch",
        y="roc_auc_score",
        marker="o",
        markersize=5,
        label="ROC-AUC-30",
    )
    sns.lineplot(
        data=accuracy_score_df,
        x="epoch",
        y="accuracy",
        marker="o",
        markersize=5,
        label="accuracy-30",
    )

    plt.legend(loc=0, fancybox=True)
    plt.axhline(y=1, linestyle="--", color="k")
    plt.xlabel("epoch")
    plt.ylabel("scores")
    sns.despine()
    plt.xticks(range(n_trials))

    # Lay out the figure that now holds the plots (not whatever was current before).
    plt.tight_layout()
    plt.show()


def run_exp(
    shuffle_seed: int,
    estimator,
    d_coef: float = 1e-1,
    epochs: int = 6,
    model_name="minilm",
    dataset_name="SetFit/amazon_counterfactual",
    normalize: bool = False,
    produce_plots: bool = True,
    batch_size: int = 32,
    plot_persistence_diagrams: bool = False,
    normalise_h0: bool = True,
):
    """
    Main experiment runner for the binary-classification experiment.

    Fine-tunes a sentence-transformer classifier for ``epochs`` epochs, embeds a fixed
    held-out subset before training and after every epoch, and scores each of those
    embeddings with several separability baselines and a persistence statistic.

    :param shuffle_seed: seed to use for shuffling the data.
    :param estimator: callable applied to the 1-D numpy array of finite H0 death
        times of each embedding; its return value is reported as "persistence_score".
    :param d_coef: prodigy-opt parameter
    :param epochs: int, number of training epochs
    :param model_name: str, shorthand for a sentence-transformer
    :param dataset_name: str, dataset loaded with ``datasets.load_dataset``; its
        "train" split must have "text" and "label" columns and at least 3000 rows.
    :param normalize: bool, whether to normalize the embeddings or not
    :param produce_plots: bool, whether to show plots
    :param batch_size: int, batch size of every dataloader, including training
    :param plot_persistence_diagrams: bool, whether to draw the H0 persistence
        diagram of every tracked embedding
    :param normalise_h0: bool, whether to divide H0 death times by the largest finite
        death time of the same diagram
    :returns: ``BinaryExpResults`` namedtuple with results from run
    """
    dataset = load_dataset(dataset_name, split="train").shuffle(shuffle_seed)

    # split into train (rows 0-999) and held-out (rows 1000-2999) data
    train_dataset = dataset.select(range(1000))
    test_dataset = dataset.select(range(1000, 3000))
    # held-out subset whose embeddings are tracked across epochs
    data_to_embed = test_dataset.shuffle(42).select(range(1000))
    test_dataset = test_dataset.select(range(300))

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    train_dataloader_scores = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=False
    )
    # not shuffled: row k of every epoch's embedding must be the same example
    embed_dataloader = DataLoader(data_to_embed, batch_size=batch_size, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    memory_stats()

    st_model = SentenceTransformer(model_name, device=DEVICE)
    classifier = STClassifier(model_name, st_model, device=DEVICE, normalize=normalize)

    print(f"model name: {model_name}")
    print(f"Dataset used: {dataset_name}")
    print(f"classifier device: {classifier.model.device}")

    _report_performance(
        classifier,
        test_dataloader,
        test_dataset["label"],
        train_dataloader_scores,
        train_dataset["label"],
    )

    loss_fn = nn.BCELoss()
    optimizer = Prodigy(classifier.parameters(), d_coef=d_coef)

    # epoch 0: embeddings of the untrained model
    embeddings = [_embed(classifier, embed_dataloader)]

    # Train the STClassifier on the training set and embed "data_to_embed" after
    # every epoch.
    print("Training ... ")
    for _ in range(epochs):
        classifier.train()  # _embed leaves the model in eval mode
        for batch in train_dataloader:
            labels = batch["label"].to(DEVICE)
            scores = classifier(batch["text"])
            loss = loss_fn(scores, labels.unsqueeze(1).float())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # embed in eval mode so dropout does not perturb the tracked embeddings
        embeddings.append(_embed(classifier, embed_dataloader))

    _report_performance(
        classifier,
        test_dataloader,
        test_dataset["label"],
        train_dataloader_scores,
        train_dataset["label"],
        multi_class="ovo",
    )

    print("DONE")

    stacked_embeddings = torch.stack(embeddings).to("cpu")
    y_true = data_to_embed["label"]
    n_trials = stacked_embeddings.shape[0]  # epochs + 1

    thorntons = []
    kmeans_score = []
    roc_auc_score_list = []
    accuracy_list = []
    h0 = []
    prob_h0 = []

    axs = None
    if plot_persistence_diagrams:
        # squeeze=False keeps axs 2-D even when n_trials == 1 (epochs=0)
        _, axs = plt.subplots(1, n_trials, figsize=(10, 10), squeeze=False)

    for trial, X in enumerate(stacked_embeddings):
        # baselines
        roc_auc_value = roc_auc_index(X, y_true, n_splits=30, max_iter=2000)
        accuracy_value = balanced_accuracy_index(X, y_true, n_splits=30, max_iter=2000)
        roc_auc_score_list.append((trial, roc_auc_value))
        accuracy_list.append((trial, accuracy_value))
        kmeans_score.append(kmeans_ch_score(X, n_clusters=10))
        thorntons.append(thornton_separability_index(X, y_true, n_neighbors=10))

        # persistence homology
        h0_deaths = _finite_h0_deaths(
            X, normalise_h0, ax=None if axs is None else axs[0, trial]
        )
        h0.append(h0_deaths)
        prob_h0.append(estimator(h0_deaths))

    # element-wise, so it works whether kmeans_ch_score returns Python or numpy floats
    max_ch = max(kmeans_score)
    scores_df = pd.DataFrame(
        {
            "Thornton_(supervised)": thorntons,
            "Calinski-Harabasz_(unsupervised)": [s / max_ch for s in kmeans_score],
            "persistence_score": prob_h0,
        }
    )

    roc_auc_score_df = pd.DataFrame(
        roc_auc_score_list, columns=["epoch", "roc_auc_score"]
    ).explode("roc_auc_score")

    accuracy_score_df = pd.DataFrame(
        accuracy_list, columns=["epoch", "accuracy"]
    ).explode("accuracy")

    if produce_plots:
        _plot_scores(scores_df, roc_auc_score_df, accuracy_score_df, n_trials)

    return BinaryExpResults(
        scores_df=scores_df,
        roc_auc_score_df=roc_auc_score_df,
        accuracy_score_df=accuracy_score_df,
        h0=h0,
        stacked_embeddings=stacked_embeddings,
        original_data_to_embed=data_to_embed,
    )
