from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt

# Legend labels shared by several entries of KEY_TO_PLOT_CONFIG below.
PPI_convex_label = "Baselines"
ML_ALL_label = "Baselines + Ours"

# Plot style per statistic / experiment key: (legend label, matplotlib colour, marker).
# ``plot_multiple_curves`` looks every series key up here to style its curve.
KEY_TO_PLOT_CONFIG = {
    "NB_convex_lrt_growth": (ML_ALL_label, "tab:blue", "o"),
    "NB_convex_lrt_growth_positive_lambda": (
        f"{ML_ALL_label} with one-sided PPI",
        "tab:blue",
        "o",
    ),
    "PPI_convex": (PPI_convex_label, "tab:orange", "s"),
    "PPI_convex_positive_lambda": (
        f"{PPI_convex_label} with one-sided PPI",
        "tab:orange",
        "s",
    ),
    "NB_max_predictions_only": ("Ours", "tab:red", "X"),
    "NB_predictions_only": ("Ours", "tab:red", "X"),
    "lrt_y": ("$e^{Y}_{LR}$", "tab:green", "v"),
    "lrt_x": ("$e^{X}_{LR}$", "tab:purple", "<"),
    "cond_lrt": ("$e^{Y\\mid X}_{LR}$", "tab:cyan", "D"),
    "PPI": ("PPI", "tab:olive", "^"),
    "PPI_positive_lambda": ("One-Sided PPI", "cyan", "v"),
    "NB_convex_lrt_y_lrt_x": ("ML-LRT", "orange", "X"),
    "convex_lrt_x_y_growth_rate": ("LRTs", "brown", "<"),
    "NB_max_prod_predictions_only": ("Predictions Prod", "cyan", "p"),
    # Non-"max" counterpart of the entry above; plot_power_vs_steps_concept_shift
    # requests this key, which previously had no style entry (KeyError).
    "NB_prod_predictions_only": ("Predictions Prod", "cyan", "p"),
    "PPI_no_unlabeled": ("PPI without unlabeled", "tab:blue", "X"),
    "label_shift_high": ("Label Shift N=100", "red", "o"),
    "label_shift_low": ("Label Shift N=10", "green", "s"),
    "concept_shift": ("Concept Shift", "blue", "D"),
    "M=2": ("M=2", "tab:blue", "o"),
    "M=5": ("M=5", "tab:orange", "s"),
    "M=16": ("M=16", "tab:green", "D"),
    "M=32": ("M=32", "tab:red", "^"),
    "M=128": ("M=128", "tab:purple", "v"),
    "alpha=0.05": ("$\\alpha=0.05$", "black", "p"),
}


def plot_multiple_curves(
    y_lists,
    x,
    y_err,
    keys,
    xlabel="X",
    ylabel="Y",
    title=None,
    show=True,
    save_path=None,
    ylim=(0.0, 1.0),
    show_legend=True,
):
    """
    Plot several series against a shared x-axis, one styled line per key.

    Each series is styled from ``KEY_TO_PLOT_CONFIG[key]`` (legend label,
    colour, marker). A key missing from that table falls back to the key
    itself as label, the next colour of the axes' colour cycle and an ``"o"``
    marker instead of raising ``KeyError``.

    Parameters
    ----------
    y_lists : sequence of sequences of float
        One sequence of y-values per series; each must have ``len(x)`` entries.
    x : sequence of float
        X-axis values shared by all series.
    y_err : sequence of sequences of float | None
        Per-series half-widths. When given, the band ``y - err .. y + err`` is
        shaded around each curve in the curve's colour. ``None`` skips bands.
    keys : sequence of str
        Series identifiers aligned with ``y_lists`` by position; only the first
        ``len(y_lists)`` keys are used.
    xlabel, ylabel : str
        Axis labels.
    title : str | None
        Plot title; omitted when falsy.
    show : bool
        Whether to call ``plt.show()`` at the end.
    save_path : str | None
        If given, the figure is saved there (tight bbox, 300 dpi).
    ylim : tuple[float, float]
        Y-axis view limits. Data outside is not modified, only out of view.
    show_legend : bool
        Whether to draw a legend on the axes.

    Returns
    -------
    (fig, ax)
    """
    fig, ax = plt.subplots()

    for i, y in enumerate(y_lists):
        key = keys[i]
        label, color, marker = KEY_TO_PLOT_CONFIG.get(key, (key, None, "o"))
        (line,) = ax.plot(
            x,
            y,
            marker=marker,
            linewidth=1,
            label=label,
            markersize=6,
            color=color,
        )
        if y_err is not None:
            y_arr = np.asarray(y, dtype=float)
            err_arr = np.asarray(y_err[i], dtype=float)
            # Use the line's resolved colour so the band matches the curve even
            # when the colour came from the default cycle (fallback keys).
            ax.fill_between(
                x,
                y_arr - err_arr,
                y_arr + err_arr,
                alpha=0.2,
                color=line.get_color(),
            )

    ax.set_xlabel(xlabel, fontsize=40)
    ax.set_ylabel(ylabel, fontsize=40)
    ax.set_ylim(ylim)
    if title:
        ax.set_title(title, fontsize=40)
    ax.grid(True, linestyle="--", alpha=0.5)
    if show_legend:
        ax.legend(fontsize=20)
    ax.tick_params(axis="both", which="major", labelsize=25)

    if save_path:
        fig.savefig(save_path, bbox_inches="tight", dpi=300)
    if show:
        plt.show()

    return fig, ax


def save_legend(axes, save_path, font_size=20, show=True):
    """
    Save a standalone single-row legend built from the entries of ``axes``.

    Handles are gathered from every axes and de-duplicated by label, with the
    first occurrence kept. The saved image is cropped to the legend's bounding
    box, padded by 10%.

    Parameters
    ----------
    axes : iterable of matplotlib Axes
        Axes whose legend handles/labels are collected.
    save_path : str
        Output path passed to ``Figure.savefig``.
    font_size : int
        Legend font size.
    show : bool
        If True (default, the original behaviour), call ``plt.show()`` after
        saving. If False, the legend-only figure is closed instead, so repeated
        calls do not accumulate open figures.
    """
    lines = []
    labels = []
    seen = set()

    for ax in axes:
        ax_lines, ax_labels = ax.get_legend_handles_labels()
        for line, label in zip(ax_lines, ax_labels):
            if label not in seen:
                seen.add(label)
                lines.append(line)
                labels.append(label)

    if not lines:
        print("No legend entries found.")
        return

    # Separate figure for the legend only; width grows with the number of
    # entries (heuristic, ~3 inches per entry).
    fig_legend = plt.figure(figsize=(len(labels) * 3, 1.5))
    ax_legend = fig_legend.add_subplot(111)
    ax_legend.axis("off")

    legend = ax_legend.legend(
        lines,
        labels,
        loc="center",
        ncol=len(labels),
        frameon=True,
        fontsize=font_size,
        edgecolor="black",
        fancybox=False,
    )

    # The legend's extent is only known after a draw; convert it from display
    # pixels to inches so it can be used as the savefig crop box.
    fig_legend.canvas.draw()
    bbox = (
        legend.get_window_extent()
        .transformed(fig_legend.dpi_scale_trans.inverted())
        .expanded(1.1, 1.1)
    )

    fig_legend.savefig(save_path, bbox_inches=bbox)
    if show:
        plt.show()
    else:
        plt.close(fig_legend)
    print(f"Saved legend to {save_path}")


def get_power_vs_steps(statistic_to_rejections, steps=None):
    """
    Compute empirical power and its standard error on a grid of steps.

    For each statistic, the power at ``step`` is the fraction of runs whose
    first rejection step is ``<= step``. The standard error is the sample
    standard deviation (``ddof=1``) of those 0/1 outcomes divided by
    ``sqrt(#runs)``.

    Parameters
    ----------
    statistic_to_rejections : Mapping[str, Sequence[float]]
        Maps a statistic name to the first-rejection step of each run. Values
        greater than a step, and NaN, count as "not rejected yet".
        Presumably non-rejecting runs are encoded as ``inf`` or a value beyond
        the horizon — TODO confirm with the producer.
    steps : iterable of int | None
        Steps at which to evaluate power. Defaults to ``range(1, 500, 20)``,
        i.e. 1, 21, 41, ..., 481. The iterable is traversed once.

    Returns
    -------
    steps, power, stderr
        ``steps`` as used, and two ``defaultdict(list)`` mapping each statistic
        (in the input's order) to one value per step. A statistic with no runs
        gets NaN power, and one with fewer than two runs gets NaN standard
        error. These are the same values as before, minus numpy's warnings.
    """
    if steps is None:
        steps = range(1, 500, 20)

    # Convert each statistic's runs once, instead of once per step.
    statistic_to_array = {
        statistic: np.asarray(first_rejection_steps, dtype=float)
        for statistic, first_rejection_steps in statistic_to_rejections.items()
    }

    statistic_to_power_over_time = defaultdict(list)
    statistic_to_power_stderr_over_time = defaultdict(list)
    for step in steps:
        for statistic, first_rejections in statistic_to_array.items():
            n_runs = first_rejections.size
            rejected = (first_rejections <= step).astype(float)
            power = rejected.mean() if n_runs > 0 else np.nan
            # ddof=1 is undefined for a single run; return NaN explicitly.
            stderr = rejected.std(ddof=1) / np.sqrt(n_runs) if n_runs > 1 else np.nan
            statistic_to_power_over_time[statistic].append(power)
            statistic_to_power_stderr_over_time[statistic].append(stderr)

    return steps, statistic_to_power_over_time, statistic_to_power_stderr_over_time


def plot_power_vs_steps(statistic_to_rejections, title=None, save_path=None):
    """
    Plot power vs. step for the full set of label-shift statistics.

    Statistics listed below that are absent from ``statistic_to_rejections``
    are skipped instead of breaking the plot. Legend labels, colours and
    markers come from ``KEY_TO_PLOT_CONFIG``.

    Returns the ``(fig, ax)`` pair produced by ``plot_multiple_curves``.
    """
    steps, statistic_to_power_over_time, statistic_to_power_stderr_over_time = (
        get_power_vs_steps(statistic_to_rejections)
    )

    # Ordered statistics to draw. The values are descriptive only; the legend
    # text is taken from KEY_TO_PLOT_CONFIG.
    key_to_label = {
        "lrt_y": "LRT-Y",
        "lrt_x": "LRT-X",
        "PPI": "PPI",
        "PPI_positive_lambda": "PPI lambda>0",
        "PPI_convex": "PPI_LRT",
        "PPI_convex_positive_lambda": "PPI_LRT lambda>0",
        "NB_convex_lrt_y_lrt_x": "ML-LRT",
        "NB_convex_lrt_growth": "ML-ALL",
        "convex_lrt_x_y_growth_rate": "LRTs",
        "NB_max_predictions_only": "Predictions",
        "NB_max_prod_predictions_only": "Predictions Prod",
        "NB_convex_lrt_growth_positive_lambda": "ML-ALL lambda>0",
        "PPI_no_unlabeled": "PPI without unlabeled",
    }
    # Indexing the defaultdicts with an absent key would insert an empty list,
    # whose length no longer matches `steps`; keep only computed statistics.
    keys = [key for key in key_to_label if key in statistic_to_power_over_time]
    y_lists = [statistic_to_power_over_time[key] for key in keys]
    y_lists_err = [statistic_to_power_stderr_over_time[key] for key in keys]

    return plot_multiple_curves(
        y_lists,
        steps,
        y_err=y_lists_err,
        keys=keys,
        xlabel="# steps",
        ylabel="Power",
        title=title,
        show=True,
        save_path=save_path,
    )


def plot_power_vs_steps_concept_shift(
    statistic_to_rejections, title=None, save_path=None
):
    """
    Plot power vs. step for the concept-shift statistics.

    Statistics listed below that are absent from ``statistic_to_rejections``
    are skipped. Each plotted key is styled via ``KEY_TO_PLOT_CONFIG``.

    Returns the ``(fig, ax)`` pair produced by ``plot_multiple_curves``.
    """
    steps, statistic_to_power_over_time, statistic_to_power_stderr_over_time = (
        get_power_vs_steps(statistic_to_rejections)
    )

    # Ordered statistics to draw (values are descriptive only).
    key_to_label = {
        "lrt_y": "LRT-Y",
        "cond_lrt": "COND-LRT",
        # "convex_lrt_y_cond_lrt_growth_rate": "LRTs",
        "PPI": "PPI",
        "PPI_convex": "PPI_LRT",
        "NB_convex_lrt_growth": "ML-ALL",
        # "NB_convex_lrt_y_cond_lrt": "ML-LRT",
        "NB_predictions_only": "Predictions",
        "NB_prod_predictions_only": "Predictions Prod",
    }

    # Skip statistics that were not computed: indexing the defaultdicts would
    # insert empty lists whose length does not match `steps`.
    keys = [key for key in key_to_label if key in statistic_to_power_over_time]
    y_lists = [statistic_to_power_over_time[key] for key in keys]
    y_lists_err = [statistic_to_power_stderr_over_time[key] for key in keys]

    return plot_multiple_curves(
        y_lists,
        steps,
        y_err=y_lists_err,
        keys=keys,
        xlabel="# steps",
        ylabel="Power",
        title=title,
        show=True,
        save_path=save_path,
    )


def plot_power_vs_steps_ppi_compared_labeled_data(
    statistic_to_rejections, title=None, save_path=None
):
    """
    Plot PPI against PPI without unlabeled data (``PPI_no_unlabeled``).

    Statistics absent from ``statistic_to_rejections`` are skipped. No legend
    is drawn on the axes.

    Returns the ``(fig, ax)`` pair produced by ``plot_multiple_curves``.
    """
    steps, statistic_to_power_over_time, statistic_to_power_stderr_over_time = (
        get_power_vs_steps(statistic_to_rejections)
    )

    # Only referenced by the commented-out `pred_name` toggle below.
    pred_name = (
        "NB_max_predictions_only"
        if "NB_max_predictions_only" in statistic_to_power_over_time
        else "NB_predictions_only"
    )
    key_to_label = {
        "PPI": "PPI",
        # "lrt_y": "LRT-Y",
        # "PPI_positive_lambda": "PPI $\lambda>0$",
        # "PPI_convex": "PPI_LRT",
        # "PPI_convex_positive_lambda": "Baselines $\lambda>0$",
        # "NB_convex_lrt_growth": "ML-ALL",
        # "NB_convex_lrt_growth_positive_lambda": "Baselines + Ours $\lambda>0$",
        # Escaped backslash: "\e" is an invalid escape (SyntaxWarning on 3.12+).
        "PPI_no_unlabeled": "PPI $\\epsilon=0$",
        # "NB_convex_lrt_y_lrt_x": "ML-LRT",
        # "convex_lrt_x_y_growth_rate": "LRTs",
        # pred_name: "Our Method",
        # "NB_max_prod_predictions_only": "Predictions Prod",
    }

    # Skip statistics that were not computed, since indexing the defaultdicts
    # would insert empty lists.
    keys = [key for key in key_to_label if key in statistic_to_power_over_time]
    y_lists = [statistic_to_power_over_time[key] for key in keys]
    y_lists_err = [statistic_to_power_stderr_over_time[key] for key in keys]

    return plot_multiple_curves(
        y_lists,
        steps,
        y_err=y_lists_err,
        keys=keys,
        xlabel="# steps",
        ylabel="Power",
        title=title,
        show=True,
        save_path=save_path,
        show_legend=False,
    )


def plot_power_vs_steps_one_sided_ppi(
    statistic_to_rejections, title=None, save_path=None
):
    """
    Plot power vs. step for the one-sided ("positive lambda") PPI variants.

    The predictions-only statistic is ``NB_max_predictions_only`` when present,
    otherwise ``NB_predictions_only``. Statistics absent from
    ``statistic_to_rejections`` are skipped. No legend is drawn on the axes.

    Returns the ``(fig, ax)`` pair produced by ``plot_multiple_curves``.
    """
    steps, statistic_to_power_over_time, statistic_to_power_stderr_over_time = (
        get_power_vs_steps(statistic_to_rejections)
    )

    pred_name = (
        "NB_max_predictions_only"
        if "NB_max_predictions_only" in statistic_to_power_over_time
        else "NB_predictions_only"
    )
    key_to_label = {
        "NB_convex_lrt_growth_positive_lambda": "Baselines + Ours",
        "PPI_convex_positive_lambda": "Baselines",
        "PPI": "PPI",
        # "lrt_y": "LRT-Y",
        "PPI_positive_lambda": "One Sided PPI",
        # "PPI_convex": "PPI_LRT",
        # "NB_convex_lrt_growth": "ML-ALL",
        # "PPI_no_unlabeled": "PPI $\epsilon=0$",
        # "NB_convex_lrt_y_lrt_x": "ML-LRT",
        # "convex_lrt_x_y_growth_rate": "LRTs",
        pred_name: "Our Method",
        # "NB_max_prod_predictions_only": "Predictions Prod",
    }

    # Skip statistics that were not computed, since indexing the defaultdicts
    # would insert empty lists.
    keys = [key for key in key_to_label if key in statistic_to_power_over_time]
    y_lists = [statistic_to_power_over_time[key] for key in keys]
    y_lists_err = [statistic_to_power_stderr_over_time[key] for key in keys]

    return plot_multiple_curves(
        y_lists,
        steps,
        y_err=y_lists_err,
        keys=keys,
        xlabel="# steps",
        ylabel="Power",
        title=title,
        show=True,
        save_path=save_path,
        show_legend=False,
    )


def plot_power_vs_steps_paper(statistic_to_rejections, title=None, save_path=None):
    """
    Plot the paper figure of combined statistics, our statistic, and LR baselines.

    The setting is inferred from the statistic names:

    * label shift, when ``"NB_max_predictions_only"`` is present: baselines are
      ``lrt_y`` and ``lrt_x``;
    * concept shift otherwise: our statistic is ``NB_predictions_only`` and the
      baselines are ``lrt_y`` and ``cond_lrt``.

    Curves are drawn in the order ``NB_convex_lrt_growth``, ``PPI_convex``, ours,
    then the baselines. Statistics absent from ``statistic_to_rejections`` are
    skipped. Styling comes from ``KEY_TO_PLOT_CONFIG``; no legend is drawn on the
    axes.

    Returns the ``(fig, ax)`` pair produced by ``plot_multiple_curves``.
    """
    steps, statistic_to_power_over_time, statistic_to_power_stderr_over_time = (
        get_power_vs_steps(statistic_to_rejections)
    )

    if "NB_max_predictions_only" in statistic_to_power_over_time:  # label shift
        ml_martingale = "NB_max_predictions_only"
        baselines = ["lrt_y", "lrt_x"]
    else:  # concept shift
        ml_martingale = "NB_predictions_only"
        baselines = ["lrt_y", "cond_lrt"]

    # Keep only statistics that were computed, since indexing the defaultdicts
    # with an absent key would insert an empty list.
    statistics_to_plot = [
        key
        for key in ["NB_convex_lrt_growth", "PPI_convex", ml_martingale, *baselines]
        if key in statistic_to_power_over_time
    ]
    y_lists = [statistic_to_power_over_time[key] for key in statistics_to_plot]
    y_lists_err = [
        statistic_to_power_stderr_over_time[key] for key in statistics_to_plot
    ]

    return plot_multiple_curves(
        y_lists,
        steps,
        y_err=y_lists_err,
        keys=statistics_to_plot,
        xlabel="Step",
        ylabel="Power",
        title=title,
        show=True,
        save_path=save_path,
        show_legend=False,
    )


def plot_power_vs_steps_validity(statistic_to_rejections, title=None, save_path=None):
    """
    Plot the curves of the validity runs with a flat reference line at 0.05.

    The runs ``label_shift_high``, ``label_shift_low`` and ``concept_shift`` are
    drawn in that order, skipping any that are absent. A constant 0.05 series
    (key ``"alpha=0.05"``, no error band) is always appended. The y-axis is
    limited to [0, 0.1].

    Returns the ``(fig, ax)`` pair produced by ``plot_multiple_curves``.
    """
    grid, power_by_key, stderr_by_key = get_power_vs_steps(statistic_to_rejections)

    wanted = ("label_shift_high", "label_shift_low", "concept_shift")
    present = [name for name in wanted if name in power_by_key]

    curves = [power_by_key[name] for name in present]
    bands = [stderr_by_key[name] for name in present]

    # Reference level: constant 0.05 with a zero-width band.
    n_points = len(grid)
    curves.append([0.05] * n_points)
    bands.append([0.0] * n_points)
    present.append("alpha=0.05")

    return plot_multiple_curves(
        curves,
        grid,
        y_err=bands,
        keys=present,
        xlabel="Step",
        ylabel="Power",
        title=title,
        show=True,
        save_path=save_path,
        ylim=(0, 0.1),
    )


def plot_hyperparameter_tuning(statistic_to_rejections, title=None, save_path=None):
    """
    Plot power vs. step for every statistic in ``statistic_to_rejections``.

    Curves follow the mapping's iteration order, and each key is styled via
    ``KEY_TO_PLOT_CONFIG`` (which holds entries such as ``"M=2"`` ... ``"M=128"``).

    Returns the ``(fig, ax)`` pair produced by ``plot_multiple_curves``.
    """
    grid, power_by_key, stderr_by_key = get_power_vs_steps(statistic_to_rejections)

    ordered_keys = list(statistic_to_rejections.keys())
    curves = []
    bands = []
    for name in ordered_keys:
        curves.append(power_by_key[name])
        bands.append(stderr_by_key[name])

    return plot_multiple_curves(
        curves,
        grid,
        y_err=bands,
        keys=ordered_keys,
        xlabel="Step",
        ylabel="Power",
        title=title,
        show=True,
        save_path=save_path,
    )


def plot_power_vs_steps_baselines(statistic_to_rejections, title=None, save_path=None):
    """
    Plot the individual statistics: the LR baselines, PPI, and ours.

    Only statistics present in ``statistic_to_rejections`` are drawn. The keys
    handed to the plotter are filtered the same way, so each curve keeps its own
    label and colour. Previously the unfiltered key list was passed, which
    shifted styles onto the wrong curves whenever a key was missing. No legend
    is drawn on the axes.

    Returns the ``(fig, ax)`` pair produced by ``plot_multiple_curves``.
    """
    steps, statistic_to_power_over_time, statistic_to_power_stderr_over_time = (
        get_power_vs_steps(statistic_to_rejections)
    )

    # Ordered candidate statistics (values are descriptive only).
    key_to_label = {
        "lrt_y": "$e^{Y}_{LR}$",
        "lrt_x": "$e^{X}_{LR}$",
        "cond_lrt": "$e^{Y\\mid X}_{LR}$",
        "PPI": "PPI",
        "NB_max_predictions_only": "Ours",
        "NB_predictions_only": "Ours",
    }

    present_keys = [
        key for key in key_to_label if key in statistic_to_power_over_time
    ]
    y_lists = [statistic_to_power_over_time[key] for key in present_keys]
    y_lists_err = [statistic_to_power_stderr_over_time[key] for key in present_keys]

    return plot_multiple_curves(
        y_lists,
        steps,
        y_err=y_lists_err,
        keys=present_keys,
        xlabel="Step",
        ylabel="Power",
        title=title,
        show=True,
        save_path=save_path,
        show_legend=False,
    )
