from collections.abc import Sequence

from matplotlib import pyplot as plt
from sklearn import metrics
import torch
from torch import Tensor
from torch.nn import Module
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

from data_generation import Dataset
from ._util import prepare_model


@torch.no_grad()
def plot_roc_curve(
    dataset_name: str = "simple-graph-plus--n-100--k-2--max-cut-value-30",
    checkpoint_paths: Sequence[str] = (
        "../logs/important/2023-10-10T16-33-01--simple-gnn--on--simple-graph-plus--n-100--k-2--max-cut-value-30/"
            "5-checkpoint--new-format.pt",
        "../logs/important/2024-02-22T15-15-05--simple-graph-plus--n-100--k-2--max-cut-value-30--"
            "supervised_imle-loss/2-checkpoint.pt",
        "../logs/important/2024-02-15T18-05-10--simple-graph-plus--n-100--k-2--max-cut-value-30--"
            "unsupervised_imle-loss/3-checkpoint.pt",
        "../logs/important/2024-04-11T13-10-00--simple-graph-plus--n-100--k-2--max-cut-value-30--"
            "direct_gradient-loss/2-checkpoint.pt",
    ),
    model_labels: Sequence[str] = (
        "Supervised Binary Cross Entropy Loss",
        "Supervised I-MLE with Hamming Loss",
        "Self-Supervised I-MLE with Size of Cut Loss",
        "Self-Supervised with Direct Gradient",
    ),
):
    """
    Plots the ROC curve for the given models on the validation graphs of a dataset.

    Each checkpoint is loaded and run on `dataset.val_graphs`. Its ROC curve and area under
    the curve are computed with scikit-learn. All curves are then drawn on one figure and
    saved as `roc-curves--on-<dataset config name>.svg` in the current working directory.

    Parameters:
    - `dataset_name`: The name of the dataset to evaluate the models on.
                      The dataset must be in the `data/` directory, in a `.pt` file of the same name.
    - `checkpoint_paths`: The paths to the models to evaluate.
    - `model_labels`: The label shown for each model.
                      The sequence must be the same length as `checkpoint_paths`.

    Raises:
    - `ValueError`: if `checkpoint_paths` and `model_labels` have different lengths.
    """
    # Materialize once: accepts any sequence/iterable and avoids sharing default objects.
    checkpoint_paths = list(checkpoint_paths)
    model_labels = list(model_labels)

    # Fail before loading anything. Otherwise the mismatch would only surface as silently
    # missing curves (zip truncation) after all models had been run.
    if len(checkpoint_paths) != len(model_labels):
        raise ValueError(
            f"Got {len(checkpoint_paths)} checkpoint paths but {len(model_labels)} model labels; "
            "they must have the same length."
        )

    dataset = Dataset.load(dataset_name)
    val_graphs = dataset.val_graphs

    # sklearn's roc_curve returns array-likes (not torch tensors); one entry per model.
    false_positive_rates: list = []
    true_positive_rates: list = []
    areas_under_curve: list[float] = []

    for checkpoint_path in checkpoint_paths:
        model = prepare_model(checkpoint_path)

        edge_predictions, ground_truth_cuts = _run_model(model, val_graphs)

        # The thresholds (third return value) are not needed for plotting.
        current_model_fpr, current_model_tpr, _ = metrics.roc_curve(ground_truth_cuts, edge_predictions)
        current_model_auc = metrics.auc(current_model_fpr, current_model_tpr)

        false_positive_rates.append(current_model_fpr)
        true_positive_rates.append(current_model_tpr)
        areas_under_curve.append(current_model_auc)

    _plot(false_positive_rates, true_positive_rates, areas_under_curve, model_labels, dataset.config.name)


@torch.no_grad()
def _run_model(model: Module, graphs: list[Data], batch_size: int = 64) -> tuple[Tensor, Tensor]:
    """
    Runs the model on the given graphs in eval mode, without gradient tracking.

    Parameters:
    - `model`: The model to evaluate. It is switched to eval mode and left in that mode.
    - `graphs`: The graphs to evaluate on. Each graph's ground truth is read from `graph.y`.
    - `batch_size`: The number of graphs collated into each forward pass.

    Returns (both moved to the CPU, so callers can hand them straight to scikit-learn):
    - The edge predictions (logits) for all graphs. Size `[total_num_edges]`
    - The ground truth edge values for all graphs. Size `[total_num_edges]`
    """
    model.eval()
    val_loader = DataLoader(graphs, batch_size)

    edge_predictions: list[Tensor] = []
    ground_truth_cuts: list[Tensor] = []
    # Single pass over the loader: each batch is collated only once. Each prediction is
    # paired with the labels of the very batch it was computed from, instead of relying on
    # two separate iterations yielding batches in the same order.
    for batch in val_loader:
        edge_predictions.append(model(batch))
        ground_truth_cuts.append(batch.y)

    # `.cpu()` is a no-op for CPU tensors. It guards the NumPy conversion done by
    # sklearn in case the model produced device-resident outputs.
    return torch.cat(edge_predictions).cpu(), torch.cat(ground_truth_cuts).cpu()


def _plot(
    false_positive_rates: list[Tensor],
    true_positive_rates: list[Tensor],
    areas_under_curve: list[float],
    model_labels: list[str],
    dataset_name: str,
):
    fig, ax = plt.subplots()
    fig.set_size_inches(9, 4)

    for fpr, tpr, auc, model_label in zip(false_positive_rates, true_positive_rates, areas_under_curve, model_labels):
        ax.plot(fpr, tpr, label=f"{model_label} (area under curve: {auc:.3f})")

    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.plot([0, 1], [0, 1], linestyle="--", color="grey")
    ax.set_title(f"ROC Curves on {dataset_name}")
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.legend(loc="lower right")

    plt.savefig(f"roc-curves--on-{dataset_name}.svg")


if __name__ == "__main__":
    # Running this module directly plots the ROC curves using plot_roc_curve's default
    # dataset, checkpoints and labels.
    plot_roc_curve()
