import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from matplotlib import colors as mcolors
from pathlib import Path
import numpy as np


def setup_neurips_style() -> None:
    """Configure global Matplotlib defaults for the paper figures.

    Updates ``mpl.rcParams`` in place with the following settings:

    - 120 dpi figures;
    - an 11 pt base font;
    - a faint dashed grid on every axes;
    - tight bounding boxes when saving.
    """
    overrides = {
        # Figure resolution and base text size.
        "figure.dpi": 120,
        "font.size": 11,
        # Light, dashed background grid on all axes.
        "axes.grid": True,
        "grid.linestyle": "--",
        "grid.alpha": 0.3,
        "grid.linewidth": 0.5,
        # Trim surrounding whitespace when a figure is saved.
        "savefig.bbox": "tight",
    }
    mpl.rcParams.update(overrides)


# Apply the shared rcParams once, before any figure below is created.
setup_neurips_style()

# One colour per method; reused by every panel and by the standalone legend.
COLOR_CLASS = "#a9a9a9"  # "Only labelled data" baseline
COLOR_PPI = "#e6194B"  # vanilla PPI (currently commented out of the plots)
COLOR_PPIPP = "#4363d8"  # PPI++
COLOR_FAB = "#000075"  # FAB
COLOR_CPPI = "#469990"  # Conformal PPI (ours)


def beautify_axis(ax, colors):
    """Apply the shared cosmetic tweaks to a Matplotlib axes.

    The tweaks are:

    - a faint dashed grid;
    - slightly thinner left and bottom spines;
    - outward-pointing ticks.

    The axes colour cycle is also reset to *colors*. Per Matplotlib's
    ``set_prop_cycle`` semantics, the reset only affects artists added after
    this call.
    """
    ax.grid(True, linestyle="--", alpha=0.3, linewidth=0.5)
    for side in ("left", "bottom"):
        ax.spines[side].set_linewidth(0.8)
    ax.set_prop_cycle(color=colors)
    ax.tick_params(direction="out")


def plot_conf_int(
    ax, means, intervals, true_theta, title, ylabels, colors, whisker_size=0.05
):
    """Draw one horizontal confidence interval per method on *ax*.

    Method ``i`` is drawn at y-position ``i``, so the first entry ends up at
    the bottom. A dashed black vertical line marks *true_theta*.

    Args:
        ax: Matplotlib axes to draw on.
        means: Centre of each error bar, one per method.
        intervals: ``(lower, upper)`` bounds per method, same order as *means*.
        true_theta: Reference value drawn as a dashed vertical line.
        title: Axes title.
        ylabels: Tick labels placed at y = 0, 1, ..., ``len(ylabels) - 1``.
        colors: One colour per method, same order as *means*; extra colours
            are ignored.
        whisker_size: Half-height, in y data units, of the vertical end marks
            drawn at each interval bound.

    Raises:
        ValueError: If *means* and *intervals* differ in length, or if there
            are fewer *colors* than methods. The check runs before anything is
            drawn.
    """
    means = list(means)
    intervals = list(intervals)
    if len(means) != len(intervals):
        raise ValueError(
            f"means ({len(means)}) and intervals ({len(intervals)}) "
            "must have the same length"
        )
    if len(colors) < len(means):
        raise ValueError(
            f"need one colour per method: got {len(colors)} colours "
            f"for {len(means)} methods"
        )

    for i, (mu, (lower, upper)) in enumerate(zip(means, intervals)):
        color = colors[i]
        ax.errorbar(
            mu,
            i,
            # Asymmetric horizontal error bar spanning [lower, upper].
            xerr=[[mu - lower], [upper - mu]],
            # NOTE(review): fmt="o" is commented out, so no marker is drawn at
            # `mu`; the marker* keywords below only matter if it is restored.
            # fmt="o",
            markersize=6,
            capsize=4,
            ecolor=color,
            markeredgecolor=color,
            markerfacecolor="white",
            linewidth=1.2,
        )
        # Explicit vertical end marks at both bounds, sized in data units.
        for bound in (lower, upper):
            ax.plot(
                [bound, bound],
                [i - whisker_size, i + whisker_size],
                color=color,
                lw=1.2,
            )

    ax.axvline(
        true_theta, color="k", linestyle="--", linewidth=1, label=r"$\theta^\ast$"
    )

    ax.set_title(title)
    ax.set_yticks(range(len(ylabels)))
    ax.set_yticklabels(ylabels)
    ax.set_xlabel("$\\theta$")
    beautify_axis(ax, colors)


def boxplot_widths(ax, data, colors, title, ylim=None, y_label=True):
    """Draw one white-filled, colour-outlined box per column of *data*.

    Outliers are hidden. The x tick labels are suppressed, so the boxes are
    identified by colour only.

    Args:
        ax: Matplotlib axes to draw on.
        data: Values passed straight to ``Axes.boxplot``. A 2-D array gives
            one box per column.
        colors: Outline colour for each box, in column order. Surplus colours
            are ignored; boxes beyond ``len(colors)`` keep default styling.
        title: Axes title.
        ylim: Optional ``(bottom, top)`` forwarded to ``Axes.set_ylim``.
        y_label: Whether to label the y axis "Interval Sizes".
    """
    # NOTE(review): recent Matplotlib releases deprecate `vert` in favour of
    # `orientation="vertical"`; switch once the minimum supported version allows.
    bp = ax.boxplot(
        data,
        vert=True,
        showfliers=False,
        patch_artist=True,
        widths=0.55,
    )
    for patch, color in zip(bp["boxes"], colors):
        patch.set(facecolor="white", edgecolor=color, linewidth=1.2)
    for median, color in zip(bp["medians"], colors):
        median.set(color=color, linewidth=1.2)
    # boxplot returns two whiskers and two caps per box, stored consecutively.
    # Pairing them via slicing + zip keeps this consistent with the loops above
    # and avoids the IndexError that 2*idx indexing hit when colors > boxes.
    for elem in ("whiskers", "caps"):
        artists = bp[elem]
        for first, second, color in zip(artists[0::2], artists[1::2], colors):
            first.set(color=color, linewidth=1.2)
            second.set(color=color, linewidth=1.2)

    ax.tick_params(axis="x", which="both", labelbottom=False)
    ax.set_xlabel("")
    ax.set_title(title)
    if y_label:
        ax.set_ylabel("Interval Sizes")
    if ylim is not None:
        ax.set_ylim(*ylim)
    beautify_axis(ax, colors)


def lighter(c, alpha=0.35):
    """Blend the colour *c* toward white.

    Each RGB channel becomes ``(1 - alpha) * channel + alpha``. With
    ``alpha=0`` the colour is unchanged; with ``alpha=1`` it becomes pure
    white.

    Returns:
        A 3-tuple of floats, with no alpha channel.
    """
    rgb = mcolors.to_rgb(c)
    return tuple((1 - alpha) * channel + alpha for channel in rgb)


# --- Load results ------------------------------------------------------------
# NOTE(review): rows are presumably one repetition / data split each. The CI
# panels below use only the first row; the box plots summarise all rows.
df_p = pd.read_csv(Path("results/mean_estimation_phishing_dataset_numeric.csv"))
df_c = pd.read_csv(Path("results/quantile_estimation_gene_expression.csv"))

# Sanity check: average of the per-row coverage columns on the median task.
print(f"PPI++ coverage (median task): {np.mean(df_c['PPIPP_split_coverage'])}")
print(f"FAB coverage (median task): {np.mean(df_c['FAB_coverage'])}")
row_p, row_c = df_p.iloc[0], df_c.iloc[0]

# --- Method table ------------------------------------------------------------
# Single source of truth for every panel and the legend. Each entry is
# (CSV column prefix, CI tick label, legend label, colour), listed bottom-to-top
# as drawn in the CI panels. Box plots and legend use the reverse order.
# Uncomment the PPI entry to add vanilla PPI back everywhere at once.
METHODS = [
    ("CCI", "Only labelled\ndata", "Only labelled data", COLOR_CLASS),
    # (
    #     "PPI",
    #     "Vanilla PPI \n(Angelopoulos et al., 2023)",
    #     "Vanilla PPI (Angelopoulos et al., 2023)",
    #     COLOR_PPI,
    # ),
    (
        "PPIPP_split",
        "PPI++ \n(Angelopoulos et al., 2023b)",
        "PPI++ (Angelopoulos et al., 2023b)",
        COLOR_PPIPP,
    ),
    (
        "FAB",
        "FAB \n(Cortinovis & Caron, 2025)",
        "FAB (Cortinovis & Caron, 2025)",
        COLOR_FAB,
    ),
    ("CPPI", "Conformal\nPPI (Ours)", "Conformal PPI (Ours)", COLOR_CPPI),
]
method_keys = [key for key, _, _, _ in METHODS]
ci_ylabels = [ylabel for _, ylabel, _, _ in METHODS]
ci_colors = [color for _, _, _, color in METHODS]


def _interval_inputs(row):
    """Return (midpoints, (lower, upper) pairs) of every method's CI in *row*."""
    bounds = [(row[f"{key}_lower"], row[f"{key}_upper"]) for key in method_keys]
    return [(lower + upper) / 2 for lower, upper in bounds], bounds


# --- Main figure ---------------------------------------------------------------
fig_all, ax_all = plt.subplots(1, 4, figsize=(12, 3.2))
# The four panels sit side by side. The reshape only changes how they are
# indexed: row 0 holds the confidence-interval panels, row 1 the box plots.
ax_all = ax_all.reshape(2, 2)

means_p, intervals_p = _interval_inputs(row_p)
plot_conf_int(
    ax_all[0, 0],
    means_p,
    intervals_p,
    row_p["true_theta"],
    title="Confidence Interval \nfor the Mean",
    ylabels=ci_ylabels,
    colors=ci_colors,
)

means_c, intervals_c = _interval_inputs(row_c)
plot_conf_int(
    ax_all[0, 1],
    means_c,
    intervals_c,
    row_c["true_theta"],
    title="Confidence Interval \nfor the Median",
    ylabels=ci_ylabels,
    colors=ci_colors,
)

# Box plots list the methods in reverse CI order (ours first).
box_columns = [f"{key}_width" for key in reversed(method_keys)]
box_colors = list(reversed(ci_colors))
boxplot_widths(
    ax_all[1, 0],
    df_p[box_columns].to_numpy(),
    box_colors,
    title="Interval Sizes for the Mean \n Varying Data Split",
)
boxplot_widths(
    ax_all[1, 1],
    df_c[box_columns].to_numpy(),
    box_colors,
    title="Interval Sizes for the Median \n Varying Data Split",
    y_label=False,
)

# Faint watermark-style labels next to the theta* line. The coordinates are in
# each panel's data units and were placed by hand for these datasets.
ax_all[0, 0].text(
    0.56,
    1.07 + 0.3,
    "true mean",
    rotation=90,
    fontsize=13,
    verticalalignment="center",
    alpha=0.2,
    zorder=-1,
    fontweight="bold",
)
ax_all[0, 1].text(
    4.6,
    1.3 + 0.3,
    "true median",
    rotation=90,
    fontsize=13,
    verticalalignment="center",
    alpha=0.2,
    zorder=-1,
    fontweight="bold",
)

# Methods are identified by colour via the separate legend image, so the CI
# panels drop their y tick labels.
for a in ax_all[0]:
    a.set_yticks([])
    a.set_ylabel("")

fig_all.align_labels()
fig_all.tight_layout()
fig_all.savefig("results/fig1_all.png", dpi=300)
plt.close(fig_all)

# --- Stand-alone legend (box-plot order) ---------------------------------------
method_labels = [label for _, _, label, _ in reversed(METHODS)]
handles = [Patch(facecolor=lighter(c), edgecolor=c, lw=1.2) for c in box_colors]

fig_leg = plt.figure(figsize=(4.6, 0.9))
fig_leg.legend(
    handles,
    method_labels,
    loc="center",
    ncol=len(handles),  # one row for however many methods are enabled
    frameon=False,
    fontsize=14,
)

fig_leg.savefig("results/fig1_legend.png", dpi=300, bbox_inches="tight")
plt.close(fig_leg)
