from collections import namedtuple
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from datasets import load_dataset
from persim import plot_diagrams
from prodigyopt import Prodigy
from ripser import ripser
from sentence_transformers import SentenceTransformer
from sklearn.metrics import confusion_matrix
from torch.utils.data import DataLoader

from src.baseline_sep_metrics import (
    balanced_accuracy_index,
    kmeans_ch_score,
    roc_auc_index,
    thornton_separability_index,
)
from src.gpu_utils import accuracy_gpu_safe, memory_stats, roc_auc_gpu_safe
from src.st_classifier import STClassifier


# Result container returned by ``run_exp``. Defined once at module level so that
# every run returns the same type instead of a freshly created class per call.
MultiClassExpResults = namedtuple(
    "MultiClassExpResults",
    [
        "scores_df",
        "roc_auc_score_df",
        "accuracy_score_df",
        "h0",
        "stacked_embeddings",
        "original_data_to_embed",
    ],
)


def _collect_outputs(forward, dataloader):
    """Apply ``forward`` to the "text" field of every batch and concatenate on dim 0."""
    return torch.cat([forward(batch["text"]) for batch in dataloader], dim=0)


def _embed_snapshot(classifier, embed_dataloader):
    """Encode every batch of ``embed_dataloader`` without gradients; result lives on CPU."""
    with torch.no_grad():
        return _collect_outputs(classifier.encode, embed_dataloader).to("cpu")


def _report_classifier_metrics(
    classifier,
    test_dataset,
    test_dataloader,
    train_dataset,
    train_dataloader_scores,
    show_confusion_matrix=False,
):
    """Print test AUC/accuracy (optionally a confusion matrix) and train accuracy.

    The caller is responsible for putting ``classifier`` in eval mode first.
    """
    with torch.no_grad():
        scores = _collect_outputs(classifier, test_dataloader)
        print(
            f"auc on test_dataset = {roc_auc_gpu_safe(test_dataset['label'], scores, multi_class='ovo'):1.2f}"
        )

        print(
            f"accuracy on the test_dataset = {accuracy_gpu_safe(test_dataset['label'], scores): 1.2f}"
        )

        if show_confusion_matrix:
            y_labels = scores.argmax(dim=1).to("cpu")
            print(confusion_matrix(test_dataset["label"], y_labels))

        scores = _collect_outputs(classifier, train_dataloader_scores)
        print(
            f"accuracy on the train_dataset = {accuracy_gpu_safe(train_dataset['label'], scores): 1.2f}"
        )


def run_exp(
    shuffle_seed,
    estimator,
    dataset_name,
    model_name="minilm",
    norm=False,
    produce_plots=False,
    epochs: int = 6,
    d_coef: float = 1e-1,
    batch_size: int = 32,
    plot_persistence_diagrams: bool = False,
    normalise_h0: bool = True,
):
    """
    Main experiment runner for the multiclass-classification experiment.

    The dataset is split into a 2000-example training set, a 1000-example set
    that is re-embedded before training and after every epoch, and a test set
    capped at 500 examples. For every embedding snapshot the baseline
    separability metrics and an H0-persistence statistic are computed.

    :param shuffle_seed: seed to use for shuffling the data.
    :param estimator: callable applied to the H0 death values (the final,
        infinite entry removed) of each snapshot; its result is stored in the
        ``persistence_score`` column.
    :param dataset_name: str, dataset identifier passed to ``load_and_prep_dataset``
    :param model_name: str, shorthand for a sentence-transformer
    :param norm: bool, forwarded to ``STClassifier`` (whether to normalize the
        embeddings or not)
    :param produce_plots: bool, whether to show plots
    :param epochs: int, number of training epochs
    :param d_coef: prodigy-opt parameter
    :param batch_size: int, batch size of every dataloader
    :param plot_persistence_diagrams: bool, whether to draw one H0 persistence
        diagram per snapshot
    :param normalise_h0: bool, whether to divide the H0 death values by the
        largest finite one
    :returns: ``MultiClassExpResults`` namedtuple with results from run
    """

    device = "cuda" if torch.cuda.is_available() else "cpu"

    dataset = load_and_prep_dataset(dataset_name)

    print(f"dataset size = {len(dataset)}")

    #### split: train / data to embed / test

    dataset = dataset.shuffle(shuffle_seed)
    dataset = dataset.train_test_split(
        train_size=2000, stratify_by_column="label", seed=42
    )
    train_dataset = dataset["train"]

    remainder = dataset["test"].train_test_split(
        train_size=1000, stratify_by_column="label", seed=42
    )
    data_to_embed = remainder["train"]
    test_dataset = remainder["test"]

    print(train_dataset.shape)
    print(data_to_embed.shape)

    n_class = len(set(train_dataset["label"]))
    print(f"n_class = {n_class}")

    test_dataset = test_dataset.select(range(min(len(test_dataset), 500)))

    (
        train_dataloader,
        train_dataloader_scores,
        embed_dataloader,
        test_dataloader,
    ) = prep_dataloaders(batch_size, train_dataset, data_to_embed, test_dataset)

    #### setup classifier

    st_model = SentenceTransformer(model_name, device="cpu")
    classifier = STClassifier(
        model_name, st_model, device=device, n_class=n_class, norm=norm
    )

    print(f"model name: {model_name}")
    print(f"Dataset used: {dataset_name}")
    print(f"classifier device: {classifier.model.device}")

    # Evaluate the untrained classifier in eval mode, exactly like the
    # post-training evaluation below, so that both reports are comparable.
    classifier.eval()
    _report_classifier_metrics(
        classifier,
        test_dataset,
        test_dataloader,
        train_dataset,
        train_dataloader_scores,
    )

    #### balance out the loss according to class

    # inverse-frequency weights, one per distinct label found in the train set
    _, class_counts = torch.unique(
        torch.tensor(train_dataset["label"]), return_counts=True
    )
    class_weights = (1.0 / class_counts.float()).to(device)
    loss_fn = nn.CrossEntropyLoss(weight=class_weights)

    optimizer = Prodigy(classifier.parameters(), d_coef=d_coef)

    # snapshot 0: embeddings before any training step
    embeddings = [_embed_snapshot(classifier, embed_dataloader)]

    classifier.train()

    memory_stats()

    # Train the STClassifier on the training set and embed the "data_to_embed" set at
    # every epoch
    loss_iter = []
    print("Training ... ")
    for _ in range(epochs):
        for batch in train_dataloader:
            labels = batch["label"].to(device)

            scores = classifier(batch["text"])
            loss = loss_fn(scores, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # only the loss of the last batch of the epoch is recorded
        loss_iter.append(loss.item())

        classifier.eval()
        embeddings.append(_embed_snapshot(classifier, embed_dataloader))
        classifier.train()

    classifier.eval()
    _report_classifier_metrics(
        classifier,
        test_dataset,
        test_dataloader,
        train_dataset,
        train_dataloader_scores,
        show_confusion_matrix=True,
    )

    if produce_plots:
        plt.figure()
        sns.lineplot(x=range(epochs), y=loss_iter)
        plt.ylabel("cross-entropy loss")
        plt.show()

    print("Done with training. Now plots and metrics. ")

    stacked_embeddings = torch.stack(embeddings)
    y_true = data_to_embed["label"]

    # one snapshot before training plus one per epoch
    n_snapshots = stacked_embeddings.shape[0]

    thorntons = []
    kmeans_score = []
    roc_auc_score_list = []
    accuracy_list = []

    h0 = []
    prob_h0 = []

    if plot_persistence_diagrams:
        # squeeze=False keeps ``axs`` two-dimensional even for a single snapshot
        _, axs = plt.subplots(1, n_snapshots, figsize=(10, 10), squeeze=False)

    for i in range(n_snapshots):
        X = stacked_embeddings[i, :, :]

        # baselines
        roc_auc_value = roc_auc_index(X, y_true, n_splits=5, max_iter=1000)
        accuracy_score = balanced_accuracy_index(X, y_true, n_splits=5, max_iter=1000)
        accuracy_list.append((i, accuracy_score))

        roc_auc_score_list.append((i, roc_auc_value))
        kmeans_score.append(kmeans_ch_score(X, n_clusters=5))
        thorntons.append(thornton_separability_index(X, y_true))

        # ripser calculations
        diagrams = ripser(X)["dgms"]

        if normalise_h0:
            # skip the -1 element as it is always equal to np.inf
            diagrams[0][:, 1] /= diagrams[0][-2, 1]

        h0_diag_data = diagrams[0]

        if plot_persistence_diagrams:
            plot_diagrams(diagrams, plot_only=[0], ax=axs[0, i])

        h0_without_inf = h0_diag_data[:-1, 1]
        h0.append(h0_without_inf)

        prob_h0.append(estimator(h0_without_inf))

    # Scale the Calinski-Harabasz scores by their maximum. Done on a Series so
    # that the division does not depend on the scalar type the metric returns
    # (a plain ``list / float`` raises TypeError).
    ch_scores = pd.Series([float(score) for score in kmeans_score], dtype="float64")

    scores = pd.DataFrame(
        {
            "Thornton_(supervised)": thorntons,
            "Calinski-Harabasz_(unsupervised)": (
                ch_scores / ch_scores.max()
            ).to_numpy(),
            "persistence_score": prob_h0,
        }
    )

    # each row holds (epoch, per-split values); explode to one row per value
    roc_auc_score_df = pd.DataFrame(
        roc_auc_score_list, columns=["epoch", "roc_auc_score"]
    ).explode("roc_auc_score")

    accuracy_score_df = pd.DataFrame(
        accuracy_list, columns=["epoch", "accuracy"]
    ).explode("accuracy")

    if produce_plots:
        sns.set_palette("bright")

        # NOTE(review): this acts on whichever figure is current here (the
        # persistence-diagram figure when one was created), not on the score
        # plot drawn below -- confirm that is the intent.
        plt.tight_layout()

        scores.plot(linestyle="-", marker="o", markersize=5)
        sns.lineplot(
            roc_auc_score_df,
            x="epoch",
            y="roc_auc_score",
            marker="o",
            markersize=5,
            label="ROC-AUC-5",
        )

        sns.lineplot(
            accuracy_score_df,
            x="epoch",
            y="accuracy",
            marker="o",
            markersize=5,
            label="accuracy-5",
        )

        plt.legend(loc=0, fancybox=True)
        plt.axhline(y=1, linestyle="--", color="k")
        plt.xlabel("epoch")
        plt.ylabel("scores")
        sns.despine()

        plt.xticks(range(n_snapshots))
        plt.show()

    return MultiClassExpResults(
        scores,
        roc_auc_score_df,
        accuracy_score_df,
        h0,
        stacked_embeddings,
        data_to_embed,
    )


def prep_dataloaders(batch_size, train_dataset, data_to_embed, test_dataset):
    """
    Build the four dataloaders used by the experiment.

    :param batch_size: int, batch size shared by every loader
    :param train_dataset: dataset used for training and for train-set scoring
    :param data_to_embed: dataset that gets embedded
    :param test_dataset: dataset used for evaluation
    :returns: tuple ``(train_dataloader, train_dataloader_scores,
        embed_dataloader, test_dataloader)``; only the first one shuffles
    """

    def _in_order(data):
        # scoring / embedding loaders must keep the dataset order
        return DataLoader(data, batch_size=batch_size, shuffle=False)

    shuffled_train = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    ordered_train = _in_order(train_dataset)
    ordered_embed = _in_order(data_to_embed)
    ordered_test = _in_order(test_dataset)

    return shuffled_train, ordered_train, ordered_embed, ordered_test


def load_and_prep_dataset(dataset_name):
    """
    Load the "train" split of a dataset and make it usable by ``run_exp``.

    ``tweet_eval`` and ``financial_phrasebank`` are loaded with a fixed
    configuration ("emotion" and "sentences_75agree" respectively). Any other
    dataset is loaded without a configuration and its "label" column is
    class-encoded unless it already is a ``ClassLabel``. A "sentence" column,
    when present, is renamed to "text".

    :param dataset_name: str, dataset identifier passed to ``load_dataset``
    :returns: the loaded (and possibly re-encoded / renamed) dataset
    :raises ValueError: from ``class_encode_column`` when a generic dataset has
        no "label" column
    """
    # local import: ClassLabel is only needed in this function
    from datasets import ClassLabel

    if dataset_name == "tweet_eval":
        dataset = load_dataset(dataset_name, "emotion", split="train")
    elif dataset_name == "financial_phrasebank":
        dataset = load_dataset(dataset_name, "sentences_75agree", split="train")
    else:
        dataset = load_dataset(dataset_name, split="train")
        # class_encode_column rejects columns that are already ClassLabel, so
        # only encode when needed. A missing "label" column still reaches
        # class_encode_column and raises ValueError as before.
        if not isinstance(dataset.features.get("label"), ClassLabel):
            dataset = dataset.class_encode_column("label")

    # Best-effort rename: skip silently when there is nothing to rename or when
    # a "text" column already exists (the cases rename_column would reject).
    columns = dataset.column_names
    if "sentence" in columns and "text" not in columns:
        dataset = dataset.rename_column("sentence", "text")
    return dataset
