import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D
import matplotlib as mpl


# Load the measurements plotted below (columns read later include "sampler",
# "N", "H", the timing/energy columns and "sample_bits").
data = pd.read_csv("data/c_measurements.csv")
data = data.reset_index()

# Samplers in order of first appearance in the CSV; this order fixes the
# colour/marker assignment and therefore the legend order.
SAMPLERS = list(data["sampler"].unique())
print(SAMPLERS)

SAMPLER_PALETTE = ["#1e90ff", "#e16f00", "#10a674", "#E34234", "#8B008B"]
SAMPLER_MARKERS = ["P", "s", "^", "o", "*", "X"]

# Assign styles by index, wrapping around each palette, so that every sampler
# receives both a colour and a marker. A plain zip() would silently drop any
# sampler beyond the shorter palette (5 colours vs. 6 markers), which then
# surfaces as a KeyError when the points are plotted.
sampler_colors = {
    sampler: SAMPLER_PALETTE[i % len(SAMPLER_PALETTE)]
    for i, sampler in enumerate(SAMPLERS)
}
n_markers = {
    sampler: SAMPLER_MARKERS[i % len(SAMPLER_MARKERS)]
    for i, sampler in enumerate(SAMPLERS)
}

# Base style from the paper stylesheet, then force a closed frame with ticks
# and minor ticks on all four sides (these overrides must come after use()).
plt.style.use("paper.mplstyle")
mpl.rcParams["axes.spines.right"] = True
mpl.rcParams["axes.spines.top"] = True
mpl.rcParams["ytick.right"] = True
mpl.rcParams["xtick.top"] = True
mpl.rcParams["xtick.minor.visible"] = True
mpl.rcParams["ytick.minor.visible"] = True

PT = 1.0 / 72.27  # 72.27 points to an inch.
GOLDEN_RATIO = (1 + 5**0.5) / 2
# Width in TeX points converted to inches; presumably the target document's
# column width — confirm against the paper's \columnwidth.
IMAGE_WIDTH = 239.39438 * PT

def _scatter_by_sampler(ax, x_col, y_col):
    """Scatter ``data[y_col]`` against ``data[x_col]`` on *ax*, one call per sampler.

    Each sampler's points use its colour from ``sampler_colors`` and its marker
    from ``n_markers``. Grouping by sampler draws one collection per sampler
    instead of one scatter artist per CSV row.

    Args:
        ax: Matplotlib axes to draw on.
        x_col: Name of the ``data`` column used for the x coordinates.
        y_col: Name of the ``data`` column used for the y coordinates.
    """
    # sort=False keeps first-appearance order, matching SAMPLERS.
    for sampler, group in data.groupby("sampler", sort=False):
        ax.scatter(
            group[x_col],
            group[y_col],
            color=sampler_colors[sampler],
            marker=n_markers[sampler],
            s=10,
            label=f"{sampler}",
            edgecolors="black",
            linewidth=0.1,
        )


if __name__ == "__main__":
    fig, axs = plt.subplots(
        1, 4, figsize=(2 * IMAGE_WIDTH, 1.5 * IMAGE_WIDTH / GOLDEN_RATIO)
    )

    # Panel 0: sampling time vs. number of outcomes.
    axs[0].grid(True, which="both", alpha=0.25)
    _scatter_by_sampler(axs[0], "N", "sample_time")
    axs[0].set_xscale("log", base=10)
    axs[0].set_yscale("log", base=2)
    axs[0].minorticks_on()
    axs[0].set_xlabel("Number of outcomes " + r"$n$")
    axs[0].set_ylabel("Sampling time [s]")
    axs[0].set_xticks([1e4, 1e6, 1e8])
    axs[0].set_yticks([2**p for p in [-24, -25, -26, -27]])

    # Panel 1: warm preprocessing time vs. number of outcomes.
    axs[1].grid(True, which="both", alpha=0.25)
    _scatter_by_sampler(axs[1], "N", "preprocess_time_warm")
    axs[1].set_xscale("log", base=10)
    axs[1].set_yscale("log", base=2)
    axs[1].set_xlabel("Number of outcomes " + r"$n$")
    axs[1].set_ylabel("Preprocessing time [s]")
    axs[1].set_xticks([1e4, 1e6, 1e8])

    # Panel 2: package energy while sampling vs. number of outcomes.
    axs[2].grid(True, which="both", alpha=0.25)
    _scatter_by_sampler(axs[2], "N", "sample-energy-pkg-0")
    axs[2].set_xscale("log", base=10)
    axs[2].set_yscale("log", base=2)
    axs[2].set_xlabel("Number of outcomes " + r"$n$")
    axs[2].set_ylabel("Sampling pkg energy [J]")
    axs[2].set_xticks([1e4, 1e6, 1e8])
    # Restyle the minor grid of this panel only (overrides the call above).
    axs[2].grid(which="minor", color="#EEEEEE", linestyle=":", linewidth=0.5)
    axs[2].minorticks_on()

    # Panel 3: average consumed random bits vs. entropy. The line f(H) = H is
    # labelled "theoretic minimum"; the grey band extends it up to H + 2.
    H_vals = np.linspace(data["H"].min(), data["H"].max(), 400)
    fH = H_vals
    axs[3].plot(H_vals, fH, color="black", lw=1.5)
    axs[3].fill_between(H_vals, fH, fH + 2, facecolor="gray", alpha=0.3)

    axs[3].grid(True, which="both", alpha=0.25)
    _scatter_by_sampler(axs[3], "H", "sample_bits")
    axs[3].set_xlabel("Entropy " + r"$H(\mathbf{p})$")
    axs[3].set_ylabel("Average consumed bits")

    # Two-line annotation in the lower-right corner (axes coordinates).
    for y_pos, text in ((0.12, "theoretic"), (0.05, "minimum")):
        axs[3].text(
            0.95,
            y_pos,
            text,
            color="black",
            transform=axs[3].transAxes,
            ha="right",
            va="bottom",
        )

    # One shared figure legend built from proxy artists, so each sampler
    # appears exactly once regardless of how many points it has.
    legend_labels = list(sampler_colors)
    legend_handles = [
        Line2D(
            [0],
            [0],
            marker=n_markers[sampler],
            color="none",
            markerfacecolor=sampler_colors[sampler],
            markeredgecolor="black",
            linestyle="",
            markersize=8,
        )
        for sampler in legend_labels
    ]

    fig.legend(
        handles=legend_handles,
        labels=legend_labels,
        loc="upper center",
        ncol=5,
        frameon=False,
        bbox_to_anchor=(0.5, 1.05),
        handletextpad=0.5,
        columnspacing=1.2,
        title=None,
    )

    # A single layout pass with matplotlib's default padding. Calling
    # tight_layout(pad=0.5) first has no effect, because the next call recomputes the layout.
    fig.tight_layout()
    fig.savefig(
        "figures/pdf/pdf_c_combined.pdf", format="pdf", dpi=600, bbox_inches="tight"
    )

