import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import matplotlib as mpl
import matplotlib.gridspec as gridspec

plt.style.use("paper.mplstyle")

# Layer extra settings on top of paper.mplstyle: show the right/top spines and
# put ticks on the right and top axes as well.
mpl.rcParams.update(
    {
        "axes.spines.right": True,
        "axes.spines.top": True,
        "ytick.right": True,
        "xtick.top": True,
    }
)

# Benchmark results for the Python samplers (one row per measurement).
data = pd.read_csv("data/python_results.csv")
# Samplers shown in the Python benchmark figures, in legend order.
SAMPLERS = ["cLUT", "NumPy", "JAX", "PyTorch"]

# Fill colour per sampler. zip() stops at len(SAMPLERS), so the trailing
# palette entry is currently unused.
sampler_colors = dict(
    zip(SAMPLERS, ("#1e90ff", "#e16f00", "#10a674", "#E34234", "#E40078"))
)

# Marker per sampler (keyed by sampler name despite the "n_" prefix); as above,
# only the first len(SAMPLERS) markers are used.
n_markers = dict(zip(SAMPLERS, ("P", "s", "^", "o", "*", "X")))


GOLDEN_RATIO = (1 + 5**0.5) / 2
# One typographic (TeX) point in inches: 72.27 pt make an inch.
PT = 1 / 72.27
# 397.48499 pt converted to inches, then scaled up by 25 %.
# NOTE(review): 397.48499 pt presumably is the document's \textwidth — confirm.
IMAGE_WIDTH = 397.48499 * PT * 1.25

# ------> SEPARATE SAMPLING / PREPROCESSING TIME FIGURES <------
# Stand-alone versions of the sampling-time and preprocessing-time panels that
# the combined figure below also draws. Disabled by default; set the flag to
# True to regenerate the two PDFs.
PLOT_SEPARATE_FIGURES = False

if PLOT_SEPARATE_FIGURES:

    def _plot_time_vs_n(column, ylabel, path):
        """Scatter ``data[column]`` against ``data["N"]`` per sampler and save a PDF.

        column: name of the ``data`` column plotted on the (log10) y axis.
        ylabel: y-axis label.
        path: output PDF path.

        A legend with one proxy marker per entry of ``SAMPLERS`` is placed
        above the axes.
        """
        fig, ax = plt.subplots(
            1, 1, figsize=(IMAGE_WIDTH, IMAGE_WIDTH / GOLDEN_RATIO)
        )
        ax.grid(True, which="both", alpha=0.25)

        # One scatter call per sampler instead of one per row.
        for sampler, group in data.groupby("sampler", sort=False):
            ax.scatter(
                group["N"],
                group[column],
                color=sampler_colors[sampler],
                marker=n_markers[sampler],
                s=10,
                label=sampler,
                edgecolors="black",
                linewidth=0.1,
            )

        ax.set_xscale("log", base=10)
        ax.set_yscale("log", base=10)
        ax.set_xlabel("Number of outcomes " + r"$n$")
        ax.set_ylabel(ylabel)

        # Proxy artists so the legend shows exactly one marker per sampler,
        # in SAMPLERS order, independent of how many points were drawn.
        legend_handles = [
            Line2D(
                [0],
                [0],
                marker=n_markers[sampler],
                color="none",
                markerfacecolor=sampler_colors[sampler],
                markeredgecolor="black",
                linestyle="",
                markersize=4,
                linewidth=0.05,
            )
            for sampler in SAMPLERS
        ]
        fig.legend(
            handles=legend_handles,
            labels=SAMPLERS,
            loc="upper center",
            ncol=4,
            frameon=False,
            bbox_to_anchor=(0.5, 1.05),
            handletextpad=0.0,
            columnspacing=0.2,
            title=None,
        )

        fig.tight_layout()
        fig.savefig(path, format="pdf", dpi=600, bbox_inches="tight")
        plt.close(fig)

    _plot_time_vs_n(
        "sampling_time",
        "Sampling time [s]",
        "figures/pdf/pdf_python_sampling_time_cython.pdf",
    )
    _plot_time_vs_n(
        "preprocessing_time",
        "Preprocessing time [s]",
        "figures/pdf/pdf_python_preprocessing_time_cython.pdf",
    )


# -----> COMBINED <-------
def _scatter_by_sampler(ax, frame, x, y):
    """Scatter ``frame[y]`` against ``frame[x]`` on ``ax``, one call per sampler.

    Rows are grouped on the ``sampler`` column; each group is drawn with that
    sampler's colour from ``sampler_colors`` and marker from ``n_markers`` --
    the same lookups used for the legend proxy handles.
    """
    for sampler, group in frame.groupby("sampler", sort=False):
        ax.scatter(
            group[x],
            group[y],
            color=sampler_colors[sampler],
            marker=n_markers[sampler],
            s=10,
            label=sampler,
            edgecolors="black",
            linewidth=0.1,
        )


# Three panels in one row: sampling time (double width), preprocessing time,
# and the cLUT compression ratio.
fig = plt.figure(figsize=(IMAGE_WIDTH, 0.5 * IMAGE_WIDTH / GOLDEN_RATIO))
gs = gridspec.GridSpec(1, 3, width_ratios=[2, 1, 1], wspace=0.5)

ax1 = fig.add_subplot(gs[0, 0])
ax2 = fig.add_subplot(gs[0, 1])
ax3 = fig.add_subplot(gs[0, 2])
axs = [ax1, ax2, ax3]

for ax in axs:
    ax.grid(True, which="both", alpha=0.25)

# Panel 1: sampling time vs. number of outcomes (Python benchmark results).
_scatter_by_sampler(ax1, data, "N", "sampling_time")
ax1.set_xscale("log", base=10)
ax1.set_yscale("log", base=10)
ax1.set_xlabel("Number of outcomes " + r"$n$")
ax1.set_ylabel("Sampling time [s]")
ax1.tick_params(axis="y", rotation=90, pad=-1)

# Panel 2: preprocessing time vs. number of outcomes.
_scatter_by_sampler(ax2, data, "N", "preprocessing_time")
ax2.set_xscale("log", base=10)
ax2.set_yscale("log", base=10)
ax2.set_xlabel(r"$n$")
ax2.set_ylabel("Preprocessing time [s]")
ax2.tick_params(axis="y", rotation=90, pad=-1)

# Shared legend above all panels: one proxy marker per sampler, in SAMPLERS
# order, so it does not depend on how many points each panel draws.
legend_handles = [
    Line2D(
        [0],
        [0],
        marker=n_markers[sampler],
        color="none",
        markerfacecolor=sampler_colors[sampler],
        markeredgecolor="black",
        linestyle="",
        markersize=7,
        linewidth=0.1,
    )
    for sampler in SAMPLERS
]
fig.legend(
    handles=legend_handles,
    labels=SAMPLERS,
    loc="upper center",
    ncol=4,
    frameon=False,
    bbox_to_anchor=(0.5, 1.05),
    handletextpad=0.0,
    columnspacing=0.2,
    title=None,
)

# Panel 3: cLUT compression ratio vs. entropy, from the C measurements.
# Kept in its own frame so the global `data`, `SAMPLERS` and `sampler_colors`
# are not overwritten. `.copy()` makes the filtered selection an independent
# frame, so adding a derived column below does not hit pandas'
# chained-assignment (SettingWithCopy) behaviour.
clut_data = pd.read_csv("data/c_measurements.csv")
clut_data = clut_data[clut_data["sampler"] == "cLUT"].copy()
clut_r = clut_data["cLUT_r"].astype(float)
# rho = 2**r / (r + 1), with r taken from the cLUT_r column.
clut_data["compression_ratio"] = 2**clut_r / (clut_r + 1)

_scatter_by_sampler(ax3, clut_data, "H", "compression_ratio")
ax3.set_xlabel("Entropy " + r"$H(\mathbf{p})$")
ax3.set_ylabel("Compression ratio " + r"$\rho$")
ax3.set_yscale("log", base=10)
ax3.tick_params(axis="y", rotation=90, pad=-1)

fig.tight_layout()
fig.savefig(
    "figures/pdf/pdf_python_combined.pdf", format="pdf", dpi=600, bbox_inches="tight"
)

plt.close(fig)
