import pickle
import numpy as np
import pandas as pd
from sklearn.metrics import roc_curve

import os
import statsmodels.api as sm
import seaborn as sns
import matplotlib.pyplot as plt

# Global figure styling, applied once at import time.
# seaborn's "whitegrid" style is set first ...
sns.set(style="whitegrid")
from tueplots import bundles
import matplotlib.pyplot as plt

# ... and the tueplots NeurIPS 2024 bundle is applied afterwards, so any
# rcParams it defines override the ones seaborn has just set.
plt.rcParams.update(bundles.neurips2024())
from tueplots.figsizes import neurips2024


def _load_data(path):
    """
    Load the in-indices pickle at ``path`` and its matching scores pickle.

    The scores file is located by replacing ``"in_indices"`` with ``"scores"``
    in ``path``. It must unpickle to a mapping with the keys ``"scores"`` and
    ``"y_true"``, each either already shaped ``(M, N)`` or flat with ``M * N``
    entries laid out as M consecutive chunks of N values.

    Warning: This expects that scores and y_true have the same order as N.

    Returns:
        A 7-tuple ``(in_indices, scores, y_true, M, N, S, path)`` where
        ``M = len(in_indices)``, ``N = len(in_indices[0])`` and ``S`` is the
        10th ``"_"``-separated token of ``path`` (kept as a string).

    Raises:
        ValueError: if a flat array cannot be reshaped to ``(M, N)``.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only point this
    # at files you produced yourself.
    with open(path, "rb") as f:
        in_indicies = pickle.load(f)

    M = len(in_indicies)
    N = len(in_indicies[0])
    # NOTE(review): the token index is taken over the *whole* path, so
    # underscores in directory names shift it -- confirm against your layout.
    S = path.split("_")[9]

    print(f"Loading {path}")
    print(f"M={M} N={N}")

    with open(path.replace("in_indices", "scores"), "rb") as f:
        scores_dict = pickle.load(f)

    raw_scores = np.asarray(scores_dict["scores"])
    raw_y_true = np.asarray(scores_dict["y_true"])

    if raw_y_true.shape == (M, N) and raw_scores.shape == (M, N):
        scores, y_true = raw_scores, raw_y_true
    else:
        # Flat layout: row m is entries [m * N, (m + 1) * N). reshape raises
        # ValueError when the total size is not exactly M * N.
        scores = raw_scores.astype(float).reshape(M, N)
        y_true = raw_y_true.astype(float).reshape(M, N)

    # Both layouts return the same 7-tuple; callers unpack all seven values.
    return in_indicies, scores, y_true, M, N, S, path


def _compute_tpr_at_fpr(y_true, scores):
    """
    Evaluate the ROC curve of ``scores`` against ``y_true`` at fixed FPRs.

    The curve returned by ``roc_curve`` is linearly interpolated at the
    false-positive rates 1e-3, 1e-2, 1e-1 and 0.5; the four interpolated
    true-positive rates are returned as an array in that order.
    """
    target_fprs = [1e-3, 1e-2, 1e-1, 0.5]
    roc_fpr, roc_tpr, _thresholds = roc_curve(y_true=y_true, y_score=scores)
    return np.interp(x=target_fprs, xp=roc_fpr, fp=roc_tpr)


def generate_vector_space_tpr_data(mia_path: str):
    """
    Build a per-data-point table of TPR at fixed FPRs for one configuration.

    For each of the N columns of the loaded ``(M, N)`` score / label arrays,
    the TPR at FPR 1e-3, 1e-2, 1e-1 and 0.5 is computed. The result has one
    row per data point, one (float-labelled) column per FPR, plus the constant
    columns ``M``, ``N``, ``S`` and ``path`` describing the configuration.
    """
    _, scores, y_true, M, N, S, path = _load_data(mia_path)

    fpr_columns = [1e-3, 1e-2, 1e-1, 0.5]
    tpr_table = np.zeros((N, 4))
    for point_idx in range(N):
        tpr_table[point_idx] = _compute_tpr_at_fpr(
            y_true=y_true[:, point_idx],
            scores=scores[:, point_idx],
        )

    frame = pd.DataFrame(tpr_table, columns=fpr_columns)
    # Broadcast the configuration metadata onto every row.
    for column_name, constant in (("M", M), ("N", N), ("S", S), ("path", path)):
        frame[column_name] = constant
    return frame


from matplotlib.ticker import FormatStrFormatter


def set_shot_axis(ax):
    """
    Configure ``ax`` as a log-log axis with a shots x-axis.

    Both axes become logarithmic; x ticks (and their labels) sit at the even
    powers of two from 2**4 to 2**16, y ticks at 1, 0.1, 0.01 and 0.001, and
    y tick labels are rendered with the ``%g`` format.
    """
    shot_ticks = [2**exponent for exponent in range(4, 17, 2)]

    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xticks(shot_ticks)
    ax.set_xticklabels(list(shot_ticks))
    ax.set_yticks([1, 0.1, 0.01, 0.001])
    ax.set_xlabel("$S$ (shots)")

    plain_number_format = FormatStrFormatter("%g")
    ax.yaxis.set_major_formatter(plain_number_format)


def add_box_ticks(axes):
    """
    Draw a black box and short tick marks around every axis in ``axes``.

    Bottom ticks are enabled on all axes; left ticks only on the first one.
    Every spine is made visible, coloured black and given line width 1.
    """
    for position, axis in enumerate(axes):
        is_leftmost = position == 0
        axis.tick_params(bottom=True, left=is_leftmost, length=3)
        for spine in axis.spines.values():
            # Spines must be switched on before colour / width have any effect.
            spine.set_visible(True)
            spine.set_color("black")
            spine.set_linewidth(1)


def add_legend_below(ax, order=None, x_pos=0.5):
    """
    Place the legend of ``ax`` underneath the axis.

    The axis is shrunk to 90% of its height (the freed 10% is at the bottom)
    and the legend is anchored below it at horizontal position ``x_pos``.
    If ``order`` is given, it is a sequence of indices selecting and ordering
    the legend entries.
    """
    bbox = ax.get_position()
    shrunk_bounds = [bbox.x0, bbox.y0 + bbox.height * 0.1, bbox.width, bbox.height * 0.9]
    ax.set_position(shrunk_bounds)

    handles, labels = ax.get_legend_handles_labels()
    if order is not None:
        handles = [handles[position] for position in order]
        labels = [labels[position] for position in order]

    legend_options = dict(
        handles=handles,
        labels=labels,
        loc="upper center",
        bbox_to_anchor=(x_pos, -0.2),
        fancybox=True,
        shadow=False,
        ncol=6,
        columnspacing=0.8,
    )
    ax.legend(**legend_options)


def errorbar_min_max(v):
    """Return ``[min(v), max(v)]``; passed to seaborn as the ``errorbar`` callable."""
    return [min(v), max(v)]


def plot(df_new, fpr, ylims):
    """
    Plot upper-tail statistics of a per-data-point metric against S.

    Args:
        df_new: DataFrame with the columns ``path``, ``S``, ``N`` and the
            column named by ``fpr`` (one row per data point).
        fpr: label of the metric column to summarise; also interpolated into
            the y-axis label.
        ylims: y-axis limits handed to ``set_ylim``.

    Returns:
        The matplotlib figure. Rows are first reduced per ``path`` to the max
        and the 0.999 / 0.99 / 0.95 quantiles; each statistic is then drawn as
        a line of per-S medians with min-max error bars.
    """
    # Half of the NeurIPS 2024 figure width, three quarters of its height.
    fig_size = neurips2024()["figure.figsize"]
    fig_size = (fig_size[0] * 0.5, 0.75 * fig_size[1])

    data = (
        df_new.groupby(["path"])
        .agg(
            {
                fpr: [
                    "mean",
                    "max",
                    lambda x: np.quantile(x, 0.999),
                    lambda x: np.quantile(x, 0.99),
                    lambda x: np.quantile(x, 0.95),
                ],
                "S": ["max"],
                "N": ["max"],
            }
        )
        .reset_index(drop=True)
    )
    data.columns = ["mean", "Max", "0.999 Quantile", "0.99 Quantile", "0.95 Quantile", "S", "N"]

    df_plot = data.copy()
    # NOTE(review): S is halved before plotting -- presumably the stored value
    # counts both halves of a split; confirm against the data producer.
    df_plot["S"] = df_plot["S"].astype(float) / 2

    # The second subplot only reserves room for the figure-level legend; it is
    # deleted again below.
    fig, orig_axes = plt.subplots(figsize=fig_size, nrows=1, ncols=2, gridspec_kw={"width_ratios": [1, 1]})
    axes = orig_axes[0]

    for c in ["Max", "0.999 Quantile", "0.99 Quantile", "0.95 Quantile"]:
        sns.lineplot(df_plot, x="S", y=c, estimator="median", label=c, ax=axes, marker="o", errorbar=errorbar_min_max)

    set_shot_axis(axes)
    add_box_ticks([axes])
    axes.set_ylabel(f"TPR-FPR\nat FPR={fpr}")
    # Override the generic ticks from set_shot_axis with the range shown here.
    axes.set_yticks([1, 0.8, 0.6, 0.4, 0.2, 0.05])
    axes.set_xticks([2**i for i in [14, 14.5, 15, 15.5, 16]])
    axes.set_xticklabels([int(2**14), "", int(2**15), "", int(2**16)])
    axes.set_xlim([2**14, 2**16])
    axes.minorticks_off()
    axes.set_ylim(ylims)

    # Put a legend to the right of the current axis
    fig.legend(loc="center left", bbox_to_anchor=(0.6, 0.55))
    fig.delaxes(orig_axes[1])
    # Drop the per-axis legend seaborn created; the figure legend replaces it.
    axes.legend_.remove()
    return fig


def fit_model(df, fpr, y_column):
    """
    Fit ``log10(y_column) ~ const + log10(S)`` by ordinary least squares.

    ``df`` is reduced per ``path`` to the mean, max and the 0.999 / 0.99 /
    0.95 quantiles of the column ``str(fpr)`` (plus the max of ``S`` and
    ``N``); ``y_column`` selects which of those summary columns is regressed.
    Returns the statsmodels summary of the fit.
    """
    metric_column = str(fpr)
    summary_stats = [
        "mean",
        "max",
        lambda values: np.quantile(values, 0.999),
        lambda values: np.quantile(values, 0.99),
        lambda values: np.quantile(values, 0.95),
    ]
    per_path = df.groupby(["path"]).agg({metric_column: summary_stats, "S": ["max"], "N": ["max"]})
    per_path = per_path.reset_index(drop=True)
    per_path.columns = ["mean", "Max", "0.999 Quantile", "0.99 Quantile", "0.95 Quantile", "S", "N"]

    # Regressor: log10(S) plus an intercept column.
    design = per_path[["S"]].astype(float)
    design["S"] = np.log10(design["S"])
    design = sm.add_constant(design)

    # Response: log10 of the chosen summary statistic.
    response = np.log10(per_path[y_column])

    fitted = sm.OLS(response, design).fit()
    return fitted.summary()


def main():
    """
    Run the full pipeline: compute per-data-point TPR at fixed FPRs for every
    scores file in ``dir_path``, cache the table as CSV, print a LaTeX table,
    save the FPR=0.1 figure and append the OLS fit summaries to ``output.txt``.
    """
    dir_path = "SET PATH HERE"
    cache_path = "cached_tpr_fpr_individual.csv"

    # Every "scores_*" file has a sibling "in_indices" file, which is the path
    # generate_vector_space_tpr_data expects.
    index_files = sorted(f.replace("scores", "in_indices") for f in os.listdir(dir_path) if "scores_" in f)
    # os.path.join keeps this correct whether or not dir_path ends in a separator.
    dfs = [generate_vector_space_tpr_data(mia_path=os.path.join(dir_path, file)) for file in index_files]

    df = pd.concat(dfs)
    df.to_csv(cache_path, index=False)

    # Load the data and create table.
    # Reading the cache back is required, not optional: the in-memory frame has
    # float column labels (0.1, 0.5) and string S values, while everything
    # below indexes the string labels "0.1" / "0.5" and compares S numerically.
    df_new = pd.read_csv(cache_path)
    df_new["0.5"] = df_new["0.5"] - 0.5
    df_new["0.1"] = df_new["0.1"] - 0.1

    large_shots = df_new[(df_new["S"] > 32000) & (df_new["S"] != 49152)]

    grouped = (
        large_shots.groupby(["path"])
        .agg(
            {
                str(0.1): [
                    "max",
                    lambda x: np.quantile(x, 0.999),
                    lambda x: np.quantile(x, 0.99),
                    lambda x: np.quantile(x, 0.95),
                ],
                "S": ["max"],
            }
        )
        .reset_index(drop=True)
    )

    grouped["S"] = grouped["S"].astype(int) // 2
    grouped = grouped.droplevel(1, axis=1)
    grouped.columns = ["Max", "0.999 Quantile", "0.99 Quantile", "0.95 Quantile", "S"]
    grouped = grouped.groupby(by=["S"]).agg("median").reset_index()
    string_tex = grouped.to_latex(float_format="%.2f", index=False)
    print(string_tex)

    # plot the data
    fig = plot(large_shots, fpr="0.1", ylims=[0.05, 1])
    fig.savefig("individual_tpr_fpr_0.1.pdf", dpi=300, bbox_inches="tight")

    # fit the model
    # The cast is loop-invariant, so it is done once up front.
    df_new["S"] = df_new["S"].astype(float)
    # NOTE(review): the fit uses S > 34000 while table and plot use S > 32000
    # -- kept as in the original; confirm this is intentional.
    fit_data = df_new[(df_new["S"] > 34000) & (df_new["S"] != 49152)]
    for fpr in [0.1]:
        for quantile in ["Max", "0.999 Quantile", "0.99 Quantile", "0.95 Quantile"]:
            result_summary = fit_model(fit_data, fpr, quantile)
            print(f"fpr: {fpr} quantile: {quantile}")
            print(result_summary)
            with open("output.txt", "a") as f:
                f.write(f"fpr: {fpr} quantile: {quantile}\n")
                f.write(str(result_summary))
                f.write("\n\n")


if __name__ == "__main__":
    main()
