import itertools
import pytest
import numpy as np
from utils.plot import set_latex_font, add_legend
from utils.io import save_fig
import matplotlib.pyplot as plt
import scienceplots
from tueplots import bundles

# Global matplotlib styling for every figure produced by this module.
# The otherwise-unused `scienceplots` import above is what makes the
# "science" style available here. The tueplots NeurIPS 2024 bundle is applied
# afterwards, so its rcParams win wherever the two overlap.
plt.style.use(["science"])
plt.rcParams.update(bundles.neurips2024())


@pytest.mark.parametrize(
    "case",
    [
        1,
        2,
        3,
    ],
)
def test_rcl(out, case):
    """Plot a calibration curve against the decision threshold t* = 0.25.

    The default curve is ``c(p) = -p(1 - p) + p`` (algebraically ``p**2``).

    * case 1: threshold on p at t*.
    * case 2: threshold moved to t_f = sqrt(t*), the point where
      ``c(t_f) == t*`` for the default curve.
    * case 3: identity curve ``c(p) = p`` with threshold t*.

    The vertical line p = th and the horizontal line c = t* split the unit
    square into "Agreement" (green) and "Disagreement" (grey) quadrants. The
    red region covers p in [th, 1] where c(p) still lies at or below t*.

    Raises:
        ValueError: if *case* is not 1, 2 or 3.
    """
    t = 0.25

    def c(x):
        return -x * (1 - x) + x

    if case == 1:
        th = t
        th_label = r"$t^{\star}$"
    elif case == 2:
        th = np.sqrt(t)
        th_label = r"$t_f$"
    elif case == 3:
        th = t
        th_label = r"$t^{\star}$"

        def c(x):
            return x

    else:
        # Without this, an unknown case fails later with an UnboundLocalError.
        raise ValueError(f"unknown case: {case!r}")

    set_latex_font()
    fig, ax = plt.subplots(figsize=(2.1, 1.8))

    XX = np.linspace(0, 1, 100)
    CC = c(XX)

    # Dashed identity diagonal for reference, then the curve itself.
    ax.plot(XX, XX, color="black", ls="--", lw=0.5)
    ax.plot(XX, CC, color="black", label="$c$")

    # Threshold on p (vertical) and t* on c(p) (horizontal).
    ax.axvline(th, color="black", lw=0.5)
    ax.axhline(t, color="black", lw=0.5)

    ax.set(
        xlabel="$p$",
        ylabel="$c(p)$",
        xlim=(0, 1),
        ylim=(0, 1),
    )

    # Red region: p in [th, 1] where c(p) <= t*.
    XX2 = np.linspace(th, 1, 100)
    TT2 = np.full_like(XX2, t)
    CC2 = c(XX2)

    ax.fill_between(
        XX2,
        TT2,
        CC2,
        where=CC2 <= t,
        color="tab:red",
        linewidth=0,
        label=r"$R_{f\!,t^{\star}}^{\mathrm{CL}}$",
    )

    # Background quadrants (zorder=0 keeps them behind the lines).
    # Lower-left and upper-right are "Agreement"; the other two are
    # "Disagreement". Only one patch per colour is labelled for the legend.
    ax.fill_between(
        [0, th],
        [0, 0],
        [t, t],
        facecolor="tab:green",
        zorder=0,
        alpha=0.2,
        label="Agreement",
    )
    ax.fill_between([th, 1], [t, t], [1, 1], facecolor="tab:green", zorder=0, alpha=0.2)
    ax.fill_between(
        [0, th],
        [t, t],
        [1, 1],
        facecolor="tab:gray",
        zorder=0,
        alpha=0.2,
        label="Disagreement",
    )
    ax.fill_between([th, 1], [0, 0], [t, t], facecolor="tab:gray", zorder=0, alpha=0.2)

    # t* label sits just above the horizontal line, centred at p = t*/2.
    ax.text(t / 2, t, r"$t^{\star}$", ha="center", va="bottom", fontsize=8)
    # Threshold label sits right of the vertical line, halfway between th and 1.
    ax.text(
        th + 0.01,
        (1 - th) / 2 + th,
        th_label,
        ha="left",
        va="center",
        fontsize=8,
    )

    add_legend(ax)
    save_fig(fig, out, pad_inches=0, c=case)
    # Release the figure so parametrized runs don't accumulate open figures.
    plt.close(fig)


def test_rcl2(out):
    """Draw three calibration cases as panels a-c of one compact 1x3 figure.

    All panels use t* = 0.25.

    * Panel a: ``c(p) = -p(1 - p) + p`` (i.e. ``p**2``) with threshold t*.
    * Panel b: the same curve with the threshold at t_f = sqrt(t*).
    * Panel c: the identity ``c(p) = p`` with threshold t*.

    A single legend is attached to the middle panel.
    """
    t = 0.25

    def plot_case(ax, case):
        """Draw case 1, 2 or 3 onto *ax*; raise ValueError otherwise."""

        def c(x):
            return -x * (1 - x) + x

        if case == 1:
            th = t
            th_label = r"$t^{\star}$"
        elif case == 2:
            th = np.sqrt(t)
            th_label = r"$t_f$"
        elif case == 3:
            th = t
            th_label = r"$t^{\star}$"

            def c(x):
                return x

        else:
            raise ValueError(f"unknown case: {case!r}")

        XX = np.linspace(0, 1, 100)
        CC = c(XX)

        ax.plot(XX, XX, color="black", ls="--", lw=0.5)
        ax.plot(XX, CC, color="black", label="$c$")

        ax.axvline(th, color="black", lw=0.5)
        ax.axhline(t, color="black", lw=0.5)

        ax.set(
            xlabel="$p$",
            xlim=(0, 1),
            ylim=(0, 1),
        )

        # Ticks every 0.1 on both axes; only the end points are labelled.
        ticks = np.linspace(0, 1, 11)
        ax.set_yticks(ticks)
        ax.set_xticks(ticks)
        ticklabels = [""] * len(ticks)
        ticklabels[0] = "0"
        ticklabels[-1] = "1"
        ax.set_xticklabels(ticklabels)
        if case != 1:
            # Only the left-most panel shows y tick labels and a y label.
            ax.set_yticklabels([])
        else:
            ax.set_yticklabels(ticklabels)
            ax.set(
                ylabel="$c(p)$",
            )
            # Pull the y label close to the axis.
            ax.yaxis.set_label_coords(-0.05, 0.5)

        # Pull the x label close to the axis.
        ax.xaxis.set_label_coords(0.5, -0.03)

        ax.minorticks_off()

        # Red region: p in [th, 1] where c(p) <= t*.
        XX2 = np.linspace(th, 1, 100)
        TT2 = np.full_like(XX2, t)
        CC2 = c(XX2)

        ax.fill_between(
            XX2,
            TT2,
            CC2,
            where=CC2 <= t,
            color="tab:red",
            linewidth=0,
            label=r"$R_{f\!,t^{\star}}^{\mathrm{CL}}$",
        )

        # Background quadrants behind the lines (zorder=0); one label per colour.
        ax.fill_between(
            [0, th],
            [0, 0],
            [t, t],
            facecolor="tab:green",
            zorder=0,
            alpha=0.2,
            label="Agreement",
        )
        ax.fill_between(
            [th, 1], [t, t], [1, 1], facecolor="tab:green", zorder=0, alpha=0.2
        )
        ax.fill_between(
            [0, th],
            [t, t],
            [1, 1],
            facecolor="tab:gray",
            zorder=0,
            alpha=0.2,
            label="Disagreement",
        )
        ax.fill_between(
            [th, 1], [0, 0], [t, t], facecolor="tab:gray", zorder=0, alpha=0.2
        )

        # t* label just above the horizontal line, centred at p = t*/2.
        ax.text(t / 2, t, r"$t^{\star}$", ha="center", va="bottom", fontsize=8)
        # Threshold label right of the vertical line, halfway between th and 1.
        ax.text(
            th + 0.01,
            (1 - th) / 2 + th,
            th_label,
            ha="left",
            va="center",
            fontsize=8,
        )

    set_latex_font()
    fig, axes = plt.subplots(
        1, 3, figsize=(3, 0.9), gridspec_kw={"wspace": 0}, constrained_layout=False
    )

    for i, ax in enumerate(axes):
        # Manual (x0, y0, width, height) placement: three panels, small gaps.
        ax.set_position([i * (1 / 3), 0, 1 / 3 - 0.02, 1])

    ax1, ax2, ax3 = axes
    # Compact legend styling, scoped to this figure. Before, plt.rc mutated the
    # global rcParams, which leaked into every test run afterwards.
    with plt.rc_context(
        {
            "legend.borderpad": 0.2,
            "legend.handletextpad": 0.4,
            "legend.columnspacing": 1.5,
        }
    ):
        plot_case(ax1, 1)
        plot_case(ax2, 2)
        plot_case(ax3, 3)
        add_legend(ax2, ncols=4, dy=0.2)
        ax1.set_title(r"a. Nonzero $R^{\mathrm{CL}}_{f, t^{\star}}$", fontsize=8)
        ax2.set_title(r"b. $t_f$ adjustment", fontsize=8)
        ax3.set_title("c. Recalibration", fontsize=8)
        ax1.title.set_position([0.5, 1.0])
        save_fig(fig, out, pad_inches=0)
    plt.close(fig)


def test_vmin(out):
    """Plot V_min and V_max against c(p) for t* = 0.25.

    The shared module-level ``vmin``/``vmax`` helpers supply both curves.
    The band between V_min and V_max is shaded green and labelled
    L_f^GL(p) > 0. The band between 0 and V_min is shaded grey and labelled
    L_f^GL(p) = 0.
    """
    t = 0.25

    set_latex_font()
    fig, ax = plt.subplots(figsize=(2.1, 1.8))

    XX = np.linspace(0, 1, 1001)

    # Reuse the module-level helpers rather than redefining (and shadowing)
    # identical local copies.
    Vmin = vmin(XX, t)
    Vmax = vmax(XX)

    ax.axvline(t, color="black", lw=0.5)
    ax.plot(XX, Vmin, label=r"$V_{\mathrm{min}}$", color="black", ls=":", lw=0.5)
    ax.plot(XX, Vmax, label=r"$V_{\mathrm{max}}$", color="black", ls="--", lw=0.5)

    ax.fill_between(
        XX,
        Vmin,
        Vmax,
        facecolor="tab:green",
        zorder=0,
        alpha=0.2,
        label=r"$L_f^{\mathrm{GL}}(p) > 0$",
    )
    ax.fill_between(
        XX,
        np.zeros_like(XX),
        Vmin,
        facecolor="tab:gray",
        zorder=0,
        alpha=0.2,
        label=r"$L_f^{\mathrm{GL}}(p) = 0$",
    )

    # t* label left of the vertical line, halfway between V_max(t*) and 1/4.
    ax.text(
        t - 0.01,
        (1 / 4 + vmax(t)) / 2,
        r"$t^{\star}$",
        ha="right",
        va="center",
        fontsize=8,
    )

    ax.set(
        xlabel="$c(p)$",
        xlim=(0, 1),
        ylim=(0, None),
    )

    add_legend(ax, ncol=2)
    save_fig(fig, out, "vmin")
    plt.close(fig)


def vmax(c):
    """Return ``c * (1 - c)``, the V_max curve used throughout this module.

    Elementwise, so *c* may be a scalar or a NumPy array.
    """
    complement = 1 - c
    return c * complement


def vmin(c, t):
    """Return the piecewise V_min curve for threshold *t*.

    The value is ``c * (t - c)`` where ``c < t`` and ``(1 - c) * (c - t)``
    where ``c >= t``. It is computed elementwise, so *c* may be a scalar or
    a NumPy array.
    """
    at_or_above = np.asarray(c >= t).astype(int)
    below = 1 - at_or_above
    lower_branch = c * (t - c) * below
    upper_branch = (1 - c) * (c - t) * at_or_above
    return lower_branch + upper_branch


def LB(C, V, t):
    """Return the lower bound ``max(V - vmin(C, t), 0)``, elementwise."""
    excess = V - vmin(C, t)
    return np.clip(excess, 0, None)


def UB(C, V, t):
    """Return the upper bound ``(sqrt(V + (C - t)**2) - |C - t|) / 2``.

    Elementwise, so *C* and *V* may be scalars or broadcastable arrays.
    """
    gap = C - t
    root = np.sqrt(V + np.square(gap))
    return 0.5 * (root - np.abs(gap))


def test_gap(out):
    """Contour-plot the gap ``UB - LB`` over the feasible (C, V) region, t = 0.4.

    Grid points with ``V > vmax(C)`` are infeasible. They are set to NaN, so
    ``contourf`` leaves them blank.
    """
    t = 0.4

    CS = np.linspace(0, 1, 1000)
    VS = np.linspace(0, 1 / 4, 1000)

    # Both grids have shape (len(VS), len(CS)): CC[j, i] = CS[i], VV[j, i] = VS[j].
    CC, VV = np.meshgrid(CS, VS)

    # LB, UB and vmax are elementwise, so evaluate the whole grid at once
    # instead of looping over 10^6 (C, V) pairs in Python.
    gap = np.where(VV > vmax(CC), np.nan, UB(CC, VV, t) - LB(CC, VV, t))

    set_latex_font()
    fig, ax = plt.subplots(figsize=(3, 2))

    crf = ax.contourf(CC, VV, gap, levels=20, cmap="Blues_r")
    fig.colorbar(crf, ax=ax)

    ax.set(
        xlabel="$C$",
        ylabel="$V$",
        title=r"$\mathrm{UB} - \mathrm{LB}$",
    )

    save_fig(fig, out)
    plt.close(fig)


@pytest.mark.parametrize("C", [0.1, 0.2, 0.4])
def test_bounds(out, C):
    """Plot LB and UB against the variance V for a fixed value C, with t = 0.4.

    V sweeps ``[0, vmax(C)]``. A dashed vertical line marks V_min(C, t) and
    a dash-dotted one marks V_max(C).
    """
    t = 0.4
    v_low = vmin(C, t)
    v_high = vmax(C)

    variances = np.linspace(0, v_high, 100)
    lower = LB(C, variances, t)
    upper = UB(C, variances, t)

    set_latex_font()
    fig, ax = plt.subplots(figsize=(2.5, 2))

    ax.axvline(v_low, color="black", ls="--", lw=0.5)
    ax.axvline(v_high, color="black", ls="-.", lw=0.5)

    ax.plot(variances, lower, label="LB")
    ax.plot(variances, upper, label="UB")

    label_y = 3 / 16
    ax.text(
        v_low + 0.002,
        label_y,
        r"$V_{\mathrm{min}}$",
        ha="left",
        va="center",
        fontsize=8,
    )

    # The x axis ends at 0.25. For V_max < 0.2 the label goes to the right of
    # its line; otherwise it goes to the left.
    if v_high < 0.2:
        vmax_x, vmax_ha = v_high + 0.002, "left"
    else:
        vmax_x, vmax_ha = v_high - 0.002, "right"
    ax.text(
        vmax_x,
        label_y,
        r"$V_{\mathrm{max}}$",
        ha=vmax_ha,
        va="center",
        fontsize=8,
    )

    ax.set(
        xlabel=r"$\mathbb{V}[F|H]$",
        xlim=(-0.01, 0.25),
        ylim=(-0.01, 0.25),
    )

    add_legend(ax, ncol=3)

    save_fig(fig, out, c=C, t=t)


@pytest.mark.parametrize("case", ["lb", "ub"])
def test_lb_ub(out, case):
    """Contour-plot either LB or UB over the feasible (C, V) region, t = 0.4.

    Args:
        case: "lb" to plot ``LB``, "ub" to plot ``UB``.

    Raises:
        ValueError: if *case* is neither "lb" nor "ub".
    """
    t = 0.4

    # Dispatch on the string case. The old check `case == 0` never matched
    # "lb", so the LB figure was wrongly titled "UB".
    bounds = {
        "lb": (LB, r"$\mathrm{LB}$"),
        "ub": (UB, r"$\mathrm{UB}$"),
    }
    try:
        bound, title = bounds[case]
    except KeyError:
        raise ValueError(f"unknown case: {case!r}") from None

    CS = np.linspace(0, 1, 1000)
    VS = np.linspace(0, 1 / 4, 1000)

    # Both grids have shape (len(VS), len(CS)): CC[j, i] = CS[i], VV[j, i] = VS[j].
    CC, VV = np.meshgrid(CS, VS)

    # Vectorised evaluation. Infeasible points (V > vmax(C)) become NaN.
    Z = np.where(VV > vmax(CC), np.nan, bound(CC, VV, t))

    set_latex_font()
    fig, ax = plt.subplots(figsize=(3, 2))

    crf = ax.contourf(CC, VV, Z, levels=20, cmap="Blues_r")
    fig.colorbar(crf, ax=ax)

    ax.set(
        xlabel="$C$",
        ylabel="$V$",
        title=title,
    )
    save_fig(fig, out, bound=case)
    plt.close(fig)


@pytest.mark.parametrize("alpha", [1, 0.5, 0.01])
def test_ub(out, alpha):
    """Plot L_f^GL(p) and U_f^GL(p) against c(p) for variance V = alpha * vmax(c).

    The threshold is fixed at t* = 0.4 and marked by a vertical line. One
    figure is saved per ``alpha`` (1, 0.5, 0.01), with y limited to [0, 0.25].
    """
    t = 0.4
    # alpha = 0.01

    # VV = np.linspace(0, vmax(C), 100)

    # At every c the variance is the fixed fraction alpha of its maximum vmax(c).
    CC = np.linspace(0, 1, 1001)
    VV = alpha * vmax(CC)

    ub = UB(CC, VV, t)
    lb = LB(CC, VV, t)

    set_latex_font()
    fig, ax = plt.subplots(figsize=(2.1, 1.4))

    # No explicit colours are given, so plot order (LB first, then UB)
    # determines colour-cycle assignment and legend order.
    # ax.plot(CC, np.abs(CC - t), label=r"$R_f^{\mathrm{CL}}$", color="tab:red")
    ax.plot(CC, lb, label=r"$L_f^{\mathrm{GL}}(p)$")
    ax.plot(CC, ub, label=r"$U_f^{\mathrm{GL}}(p)$")
    ax.axvline(t, color="black", lw=0.5)

    # t* label just left of the vertical line, at mid-height of the y range.
    ax.text(
        t - 0.01,
        0.25 / 2,
        # ub.max() / 2,
        r"$t^{\star}$",
        ha="right",
        va="center",
        fontsize=8,
    )

    ax.set(
        xlabel=r"$c(p)$",
        # ylabel=r"$t^{\star}R_{\mathrm{GL}}$",
        xlim=(0, 1),
        ylim=(0, 0.25),
    )

    add_legend(ax, ncol=3)

    # ax.yaxis.set_label_coords(-0.12, 0.7)  # Adjust the position of the label
    # ax.set_ylabel(r"$\times\!U_{\!\!\Delta}$")

    # Place the "×U_Δ" marker at the top-left corner of the axes (axes-fraction
    # coordinates). It is used instead of a y label and presumably says the y
    # values are in units of U_Δ; confirm against the paper.
    ax.annotate(
        r"$\times\!U_{\!\!\Delta}$",
        xy=(0, 1),
        xycoords="axes fraction",
        # xytext=(-0.07, 1.07),
        xytext=(-0.07, 1.0),
        textcoords="axes fraction",
        ha="center",
        va="center",
    )

    # if alpha == 1:
    #     ax.set_ylabel("Regret")

    save_fig(fig, out, t=t, a=alpha, pad_inches=0)


# @pytest.mark.parametrize("alpha", [1, 0.5, 0.01])
def test_ub2(out):
    """Draw L_f^GL and U_f^GL for three variance levels side by side, t* = 0.4.

    From left to right the panels use V = alpha * vmax(c) with alpha = 1,
    1/2 and 1/100. A single legend is attached to the middle panel.
    """
    t = 0.4

    CC = np.linspace(0, 1, 1001)

    def plot_case(ax, alpha, i):
        """Draw the bounds for V = alpha * vmax(c) into panel *i* (0-based)."""
        VV = alpha * vmax(CC)

        ub = UB(CC, VV, t)
        lb = LB(CC, VV, t)

        # Plot order fixes the colour cycle and legend order: LB, then UB.
        ax.plot(CC, lb, label=r"$L_f^{\mathrm{GL}}(p)$")
        ax.plot(CC, ub, label=r"$U_f^{\mathrm{GL}}(p)$")
        ax.axvline(t, color="black", lw=0.5)

        # Raise the t* label in the middle panel; elsewhere keep it at mid-height.
        if i == 1:
            ypos = 0.25 * 3 / 4
        else:
            ypos = 0.25 / 2

        ax.text(
            t - 0.01,
            ypos,
            r"$t^{\star}$",
            ha="right",
            va="center",
            fontsize=8,
        )

        ax.set(
            xlabel=r"$c(p)$",
            xlim=(0, 1),
            ylim=(0, 0.25),
        )

        # x ticks at 0, 0.5 and 1; only the end points are labelled.
        ticks = np.linspace(0, 1, 3)
        ax.set_xticks(ticks)
        ticklabels = [""] * len(ticks)
        ticklabels[0] = "0"
        ticklabels[-1] = "1"
        ax.set_xticklabels(ticklabels)

        if i == 0:
            # Only the left-most panel gets the "×U_Δ" marker at its top-left.
            ax.annotate(
                r"$\times\!U_{\!\!\Delta}$",
                xy=(0, 1),
                xycoords="axes fraction",
                xytext=(-0.13, 1.0),
                textcoords="axes fraction",
                ha="center",
                va="center",
                fontsize=7,
            )
        else:
            ax.set_yticklabels([])

        # Pull the x label close to the axis.
        ax.xaxis.set_label_coords(0.5, -0.03)

    set_latex_font()
    fig, axes = plt.subplots(1, 3, figsize=(2.5, 0.7), gridspec_kw={"wspace": 0})
    ax1, ax2, ax3 = axes

    for i, ax in enumerate(axes):
        # Manual (x0, y0, width, height) placement: three panels, small gaps.
        ax.set_position([i * (1 / 3), 0, 1 / 3 - 0.02, 1])

    # Compact legend styling, scoped to this figure rather than written into
    # the global rcParams with plt.rc, which would leak into later tests.
    with plt.rc_context(
        {
            "legend.borderpad": 0.1,
            "legend.handletextpad": 0.4,
            "legend.columnspacing": 1.5,
        }
    ):
        plot_case(ax1, 1, 0)
        plot_case(ax2, 0.5, 1)
        plot_case(ax3, 0.01, 2)
        add_legend(ax2, ncol=2, dy=0.2)
        ax1.set_title(r"a. $\mathrm{GL} = V_{\mathrm{max}}$", fontsize=7)
        ax2.set_title(r"b. $\mathrm{GL} = \tfrac{V_{\mathrm{max}}}{2}$", fontsize=7)
        ax3.set_title(r"c. $\mathrm{GL} = \tfrac{V_{\mathrm{max}}}{100}$", fontsize=7)
        save_fig(fig, out, t=t, pad_inches=0)
    plt.close(fig)
