import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D
import matplotlib as mpl
import matplotlib.gridspec as gridspec
from c_measurements import SAMPLERS, sampler_colors, n_markers, IMAGE_WIDTH, GOLDEN_RATIO


# Raw C benchmark measurements; pivoted into one column per sampler below.
data = pd.read_csv("data/c_measurements.csv")

# Start from the paper style, then additionally draw the top/right spines,
# put ticks on the top and right edges, and show minor ticks on both axes.
plt.style.use("paper.mplstyle")
mpl.rcParams.update({
    "axes.spines.right": True,
    "axes.spines.top": True,
    "ytick.right": True,
    "xtick.top": True,
    "xtick.minor.visible": True,
    "ytick.minor.visible": True,
})

# Keep only the columns needed for the break-even analysis.
subset = data[[
    "dist_file", "sampler",
    "preprocess_time_warm", "sample_time",
    "preprocess_warm-energy-pkg-0", "sample-energy-pkg-0",
    "N", "Z", "H"
]]

# Wide table: one row per (dist_file, N, Z, H) key and one flat
# "<measurement>_<sampler>" column per measurement/sampler pair.
C = subset.pivot(index=["dist_file", "N", "Z", "H"], columns="sampler")
C.columns = ["_".join(map(str, key)) for key in C.columns]
C = C.reset_index()

# Break-even point n* = (preprocess_alias - preprocess_cLUT)
#                     / (sample_cLUT - sample_alias),
# i.e. the sample count at which both samplers' totals are equal
# (assuming the sample_* columns hold per-sample costs -- TODO confirm).
C = C.assign(
    break_even_time=(
        (C["preprocess_time_warm_Alias method"] - C["preprocess_time_warm_cLUT"])
        / (C["sample_time_cLUT"] - C["sample_time_Alias method"])
    ),
    break_even_energy_pkg=(
        (C["preprocess_warm-energy-pkg-0_Alias method"] - C["preprocess_warm-energy-pkg-0_cLUT"])
        / (C["sample-energy-pkg-0_cLUT"] - C["sample-energy-pkg-0_Alias method"])
    ),
)


# The Python benchmark table (data/python_results.csv) is loaded, and its
# break-even point computed, in the "SAMPLE TIME (PYTHON)" section where it
# is plotted. Nothing between here and there reads `python`, `data` or `subset`.

# One row of three panels with unequal widths (the last column is widest).
fig = plt.figure(figsize=(2 * IMAGE_WIDTH, 0.8 * IMAGE_WIDTH / GOLDEN_RATIO))
gs = gridspec.GridSpec(1, 3, width_ratios=[1.2, 1, 1.6], wspace=0.35)

# Create the panels left to right (grid columns 0, 1, 2).
ax1, ax2, ax3 = (fig.add_subplot(gs[0, column]) for column in range(3))

# `axs` is ordered by role, not by position:
#   axs[0] -> left panel   (peak memory per sampler)
#   axs[1] -> right panel  (break-even points; widest column)
#   axs[2] -> middle panel (cLUT table memory coloured by entropy)
axs = [ax1, ax3, ax2]

# ----> PEAK MEMORY <-----
# Peak memory of each sampler against the number of outcomes n.
memory = pd.read_csv("data/memory.csv")

# One scatter call per sampler rather than one per CSV row: this gives a
# handful of artists instead of one PathCollection per point (a smaller,
# faster-rendering PDF) and a single label per sampler.
# Points are therefore drawn grouped by sampler, in first-appearance order.
for sampler_name, group in memory.groupby("sampler", sort=False):
    axs[0].scatter(
        group["N"],
        group["memory"],
        color=sampler_colors[sampler_name],
        marker=n_markers[sampler_name],
        s=10,
        label=f"{sampler_name}",
        edgecolors="black",
        linewidth=0.1,
    )

axs[0].set_xlabel("Number of outcomes " + r"$n$")
axs[0].set_ylabel("Memory [bytes]")
axs[0].grid(True, which="both", alpha=0.25)
axs[0].set_xscale("log", base=10)
axs[0].set_yscale("log", base=10)
axs[0].set_xticks([1e4, 1e6, 1e8])
axs[0].minorticks_on()
axs[0].tick_params(axis='y', rotation=90, pad=-1)


# ----> BREAK EVEN ANALYSIS <----
# The break-even series reuse the first three sampler colours/markers:
# index 0 -> sample time (C), 1 -> energy (C), 2 -> sample time (Python).
Colors = [sampler_colors[c] for c in SAMPLERS]
Markers = [n_markers[c] for c in SAMPLERS]

# SAMPLE TIME (C)
axs[1].grid(True, which="both", alpha=0.25)
# One vectorised scatter per series: a single artist with a single label,
# instead of one PathCollection (and a duplicate label) per DataFrame row.
# Non-positive or non-finite n* values cannot be drawn on the log y axis.
axs[1].scatter(
    C["N"],
    C["break_even_time"],
    color=Colors[0],
    edgecolors="black",
    marker=Markers[0],
    s=10,
    linewidth=0.1,
    label="sample time (C)",
)

axs[1].set_xscale("log", base=10)
axs[1].set_yscale("log", base=10)
axs[1].minorticks_on()
axs[1].set_xlabel("Number of outcomes " + r"$n$")
axs[1].set_ylabel("Break even point " + r"$n^*$")
axs[1].set_xticks([1e4, 1e6, 1e8])
axs[1].tick_params(axis='y', rotation=90, pad=-1)

# ENERGY (C)
axs[1].scatter(
    C["N"],
    C["break_even_energy_pkg"],
    color=Colors[1],
    marker=Markers[1],
    edgecolors="black",
    s=10,
    linewidth=0.1,
    label="energy (C)",
)

# SAMPLE TIME (PYTHON)
data = pd.read_csv("data/python_results.csv")
subset = data[[
    "dist", "sampler",
    "preprocessing_time", "sampling_time",
    "N", "Z", "entropy", "samples_generated"
]]
# Wide table: one row per (dist, N, Z, entropy, samples_generated) key and
# one flat "<measurement>_<sampler>" column per pair (e.g. ..._NumPy, ..._cLUT).
python = subset.pivot(index=["dist", "N", "Z", "entropy", "samples_generated"], columns="sampler")
python.columns = [f"{col[0]}_{col[1]}" for col in python.columns]
python = python.reset_index()
# Break-even n* = delta preprocessing / delta sampling, scaled by samples_generated.
# NOTE(review): presumably sampling_time is the total over samples_generated
# draws, so the scaling turns it into a per-sample cost -- confirm.
python["break_even_time"] = (
    (python["preprocessing_time_NumPy"] - python["preprocessing_time_cLUT"])
    / (python["sampling_time_cLUT"] - python["sampling_time_NumPy"])
) * python["samples_generated"]

# Single vectorised scatter (one artist), labelled like the C series.
axs[1].scatter(
    python["N"],
    python["break_even_time"],
    color=Colors[2],
    marker=Markers[2],
    s=10,
    edgecolors="black",
    linewidth=0.1,
    label="sample time (Python)",
)
# ----> cLUT TABLE MEMORY <-----
# cLUT rows only, coloured by the distribution's entropy H.
memory = pd.read_csv("data/c_measurements.csv")
memory = memory[memory["sampler"] == "cLUT"]

entropy = memory["H"]
cmap = plt.cm.Blues_r
norm = mpl.colors.Normalize(vmin=entropy.min(), vmax=entropy.max())

# y value is (cLUT_r + 1) * 2**cLUT_c, shown under the "Memory [bytes]" label.
# NOTE(review): confirm this counts bytes rather than table entries.
sc = axs[2].scatter(
    memory["N"],
    (memory["cLUT_r"].astype(int) + 1) * 2 ** memory["cLUT_c"].astype(int),
    c=entropy,
    cmap=cmap,
    norm=norm,
    marker=n_markers["cLUT"],
    s=10,
    edgecolors="black",
    linewidth=0.1,
)

# Horizontal entropy colorbar attached above this panel.
cbar = fig.colorbar(sc, ax=axs[2], orientation="horizontal", location="top", anchor=(0.5, 1.2))
cbar.set_label("Entropy (H)")

panel = axs[2]
panel.set_xlabel("Number of outcomes " + r"$n$")
panel.set_ylabel("Memory [bytes]")
panel.grid(True, which="both", alpha=0.25)
panel.set_xscale("log", base=10)
panel.set_yscale("log", base=10)
panel.set_xticks([1e4, 1e6, 1e8])
panel.minorticks_on()
panel.tick_params(axis='y', rotation=90, pad=-1)


def _marker_legend_handles(markers, colors):
    """Return marker-only ``Line2D`` proxies for a figure legend.

    ``markers`` and ``colors`` are paired element-wise; each proxy shows one
    black-edged marker with the given face colour and no line.
    """
    return [
        Line2D(
            [0],
            [0],
            marker=mk,
            color="none",
            markerfacecolor=col,
            markeredgecolor="black",
            linestyle="",
            markersize=4,
            # NOTE(review): with linestyle="" this has no visible effect;
            # markeredgewidth may have been intended for thin marker edges.
            linewidth=0.05,
        )
        for mk, col in zip(markers, colors)
    ]


# Legend for the break-even series. The indices pick the series styling
# defined for the break-even panel (0 = time C, 2 = time Python, 1 = energy C)
# so each label lines up with its marker/colour.
labels = ["Time (C)", "Time (Python)", "Energy (C)"]
markers = [Markers[n] for n in [0, 2, 1]]
colors = [Colors[n] for n in [0, 2, 1]]
legend_handles = _marker_legend_handles(markers, colors)

fig.legend(
    handles=legend_handles,
    labels=labels,
    loc="upper left",
    ncol=2,
    frameon=False,
    bbox_to_anchor=(0.55, 1.2),
    handletextpad=0.0,
    columnspacing=0.1,
    title=None,
)

# Legend for the samplers themselves (used by the peak-memory panel).
labels = list(SAMPLERS)
markers = [n_markers[x] for x in labels]
colors = [sampler_colors[x] for x in labels]
legend_handles = _marker_legend_handles(markers, colors)

fig.legend(
    handles=legend_handles,
    labels=labels,
    loc="upper left",
    ncol=2,
    frameon=False,
    bbox_to_anchor=(0.05, 1.2),
    handletextpad=0.0,
    columnspacing=0.1,
    title=None,
)

# Final layout pass on the figure built above, then write the PDF.
# bbox_inches="tight" makes the saved bounds include the figure-level
# legends anchored above the axes.
fig.tight_layout()
fig.savefig(
    "figures/pdf/pdf_memory_break_even.pdf",
    format="pdf",
    dpi=600,
    bbox_inches="tight",
)
