import numpy as np
import seaborn as sns
from matplotlib.axes._axes import Axes
from matplotlib.ticker import MaxNLocator


def reference_scatter(
    ax: Axes,
    posterior_draws: np.ndarray,
    accurate_posterior: np.ndarray,
    no_sc_posterior=None,
    posterior_label="posterior",
    accurate_label="reference",
    title="",
):
    """Scatter posterior draws on top of a reference sample on ``ax``.

    Only columns 0 and 1 of each array are plotted (x and y respectively).
    The top and right spines of ``ax`` are hidden and ``title`` is set.

    Parameters
    ----------
    ax
        Axes to draw on; it is modified in place.
    posterior_draws
        Samples drawn in dark blue and labelled ``posterior_label``.
    accurate_posterior
        Samples drawn in grey and labelled ``accurate_label``.
    no_sc_posterior
        Optional third sample set. When given, it is drawn in red with
        ``zorder=0`` under the fixed label ``"no_sc_posterior"``.
    posterior_label, accurate_label
        Legend labels for the first two sample sets.
    title
        Axes title.

    Returns
    -------
    The same ``ax`` that was passed in.
    """
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    ax.set_title(title)

    # Layers are drawn in list order: reference first, then the posterior.
    layers = [
        (
            accurate_posterior,
            {"label": accurate_label, "color": "grey", "alpha": 0.8, "s": 0.5},
        ),
        (
            posterior_draws,
            {"label": posterior_label, "color": "#001842", "alpha": 0.8, "s": 0.5},
        ),
    ]

    if no_sc_posterior is not None:
        layers.append(
            (
                no_sc_posterior,
                {
                    "label": "no_sc_posterior",
                    "s": 1,
                    "alpha": 0.8,
                    "color": "#b33b4b",
                    "zorder": 0,
                },
            )
        )

    for samples, style in layers:
        ax.scatter(samples[:, 0], samples[:, 1], **style)

    return ax


def reference_contour(
    ax: Axes,
    posterior_draws: np.ndarray,
    accurate_posterior: np.ndarray,
    no_sc_posterior=None,
    posterior_label="posterior",
    accurate_label="reference",
    title="",
    title_fontsize=12,
    no_sc_contour_max_mean=1.6,
):
    """Draw KDE contours and mean markers for posterior vs. reference samples.

    Only columns 0 and 1 of each array are used (x and y respectively).
    The top and right spines of ``ax`` are hidden and ``title`` is set.
    Drawing order is: reference contours, posterior contours, reference
    mean, posterior mean, and then (if supplied) the ``no_sc_posterior``
    mean and contours.

    Parameters
    ----------
    ax
        Axes to draw on; it is modified in place.
    posterior_draws
        Samples whose 6-level KDE contours are drawn in blue and labelled
        ``posterior_label``; their mean is marked with a navy square.
    accurate_posterior
        Samples whose 6-level KDE contours are drawn in grey and labelled
        ``accurate_label``; their mean is marked with a grey triangle.
    no_sc_posterior
        Optional third sample set. When given, its mean is marked with a
        red ``+`` and, subject to ``no_sc_contour_max_mean``, its KDE
        contours are drawn in red. Both use the label
        ``"no_sc_posterior"``.
    posterior_label, accurate_label
        Legend labels for the contours. The corresponding mean markers are
        labelled ``"<label> mean"``.
    title, title_fontsize
        Axes title and its font size.
    no_sc_contour_max_mean
        The ``no_sc_posterior`` contours are drawn only when the mean of
        column 0 of ``accurate_posterior`` is strictly below this value.
        Pass ``None`` to always draw them. The default of 1.6 keeps the
        previously hard-coded behaviour.
        NOTE(review): the rationale for 1.6 is not visible in this file.

    Returns
    -------
    The same ``ax`` that was passed in.
    """
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    ax.set_title(title, fontsize=title_fontsize)

    sns.kdeplot(
        x=accurate_posterior[:, 0],
        y=accurate_posterior[:, 1],
        ax=ax,
        color="grey",
        levels=6,
        linewidths=1,
        alpha=0.9,
        label=accurate_label,
    )

    sns.kdeplot(
        x=posterior_draws[:, 0],
        y=posterior_draws[:, 1],
        ax=ax,
        label=posterior_label,
        color="#2a63bb",
        levels=6,
        alpha=0.9,
        linewidths=1,
    )

    # Computed once: the x-mean is reused below to gate the extra contours.
    accurate_mean_x = np.mean(accurate_posterior[:, 0])
    accurate_mean_y = np.mean(accurate_posterior[:, 1])

    # The mean markers previously all carried the label "no_sc_posterior"
    # (copy-paste slip), which mislabelled the legend. Each marker is now
    # labelled after the sample set it summarises.
    ax.scatter(
        accurate_mean_x,
        accurate_mean_y,
        label=f"{accurate_label} mean",
        color="dimgrey",
        marker="v",
        zorder=120,
        s=20,
    )

    ax.scatter(
        np.mean(posterior_draws[:, 0]),
        np.mean(posterior_draws[:, 1]),
        label=f"{posterior_label} mean",
        color="navy",
        marker="s",
        zorder=100,
        s=20,
    )

    if no_sc_posterior is None:
        return ax

    ax.scatter(
        np.mean(no_sc_posterior[:, 0]),
        np.mean(no_sc_posterior[:, 1]),
        label="no_sc_posterior",
        color="#7b001b",
        marker="+",
        zorder=5,
        s=20,
    )

    draw_no_sc_contours = (
        no_sc_contour_max_mean is None
        or accurate_mean_x < no_sc_contour_max_mean
    )
    if draw_no_sc_contours:
        sns.kdeplot(
            x=no_sc_posterior[:, 0],
            y=no_sc_posterior[:, 1],
            ax=ax,
            label="no_sc_posterior",
            levels=6,
            alpha=1,
            color="#b33b4b",
            linewidths=1,
        )

    return ax


def comparison_lines(
    ax: Axes,
    y_values: list[list],
    std_deviations: list[list],
    title="",
    ylabel="",
    labels=None,
    colors=None,
    linestyles=None,
    offsets=None,
    ylabel_fontsize=18,
    xlabel_fontsize=17,
    title_fontsize=18,
    tick_labelsize=14,
    linewidth=1,
    *,
    xlabel="observed data location",
):
    """Plot several series as error-bar lines against their index positions.

    Series ``i`` is drawn at x positions ``0, 1, ..., len(y_values[i]) - 1``,
    each shifted by ``offsets[i]``, with ``std_deviations[i]`` as the
    vertical error bars. The top and right spines of ``ax`` are hidden, the
    x-axis uses an integer tick locator, and a dotted black horizontal line
    is drawn at ``y = 0``.

    Parameters
    ----------
    ax
        Axes to draw on; it is modified in place.
    y_values
        One sequence of y values per series.
    std_deviations
        One sequence of error-bar sizes per series, indexed in parallel
        with ``y_values``.
    title, ylabel
        Axes title and y-axis label.
    labels
        Legend label per series; defaults to the series indices
        ``0, 1, ...``.
    colors
        Colour per series; defaults to ``None`` for every series, which
        leaves the choice to matplotlib.
    linestyles
        Line style per series; defaults to ``"solid"`` for every series.
    offsets
        Horizontal shift per series; defaults to ``0`` for every series.
    ylabel_fontsize, xlabel_fontsize, title_fontsize, tick_labelsize
        Font sizes for the respective text elements.
    linewidth
        Line width passed to every ``errorbar`` call.
    xlabel
        X-axis label (keyword-only). The default keeps the previously
        hard-coded text.

    Returns
    -------
    The same ``ax`` that was passed in.
    """
    num_series = len(y_values)

    if labels is None:
        labels = list(range(num_series))
    if colors is None:
        colors = [None] * num_series
    if linestyles is None:
        linestyles = ["solid"] * num_series
    if offsets is None:
        offsets = [0] * num_series

    for side in ("top", "right"):
        ax.spines[side].set_visible(False)

    ax.set_ylabel(ylabel, fontsize=ylabel_fontsize)
    ax.set_xlabel(xlabel, fontsize=xlabel_fontsize)
    ax.set_title(title, fontsize=title_fontsize)
    ax.tick_params(axis="both", labelsize=tick_labelsize)

    # x positions are integer indices, so restrict major ticks to integers.
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))

    for i, y in enumerate(y_values):
        shift = offsets[i]
        x_positions = [x + shift for x in range(len(y))]

        ax.errorbar(
            x_positions,
            y,
            yerr=std_deviations[i],
            capsize=2,
            label=labels[i],
            linestyle=linestyles[i],
            color=colors[i],
            linewidth=linewidth,
        )

    ax.axhline(0, linestyle="dotted", color="black")

    return ax
