import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Logged measurements. Columns read later in this script: struct, width,
# use_wrong_mult, dh_avg_2.
df = pd.read_csv("./logs/act.csv")

# Global plot theme: white grid, large fonts and thick lines for the exported PDFs.
sns.set_theme(style="whitegrid", font_scale=3.0, rc={"lines.linewidth": 5.0})
# Explicit palette (the first seven ColorBrewer "Set1" colours). It is kept as a
# plain list of hex strings because the plotting code indexes `pal` directly.
pal = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3', '#ff7f00', '#ffff33', '#a65628']
sns.set_palette(sns.color_palette(pal))

# Distinct widths, sorted ascending. Every per-structure series below is ordered
# the same way, so the y-values line up with `x` and lines run left-to-right.
x = df["width"].drop_duplicates().sort_values().to_numpy()
# Raw `struct` values in the CSV -> display names (panel titles and dict keys).
case_map = {'btt': "BTT", 'kron': "Kron", 'low_rank': "Low Rank", 'dense': "Dense", 'monarch': 'Monarch'}
# All structured layers share pal[0]; the dense baseline uses pal[1].
color_map = {"BTT": pal[0], "Dense": pal[1], "Kron": pal[0], "Low Rank": pal[0], 'Monarch': pal[0]}
# Used as a boolean mask: True rows form the "naive" series, False rows "ours".
key = "use_wrong_mult"


def _dh_by_width(rows):
    """Return ``rows["dh_avg_2"]`` ordered by width.

    Raises:
        ValueError: if ``rows`` does not contain exactly one row per width in
            ``x``; otherwise the series would be plotted against the wrong widths.
    """
    rows = rows.sort_values("width")
    widths = list(rows["width"])
    if widths != list(x):
        raise ValueError(f"expected one row per width {list(x)}, got {widths}")
    return rows["dh_avg_2"].to_numpy()


cases, naive = {}, {}
for struct, group in df.groupby("struct", sort=False):
    name = case_map[struct]
    use_wrong = group[key]
    cases[name] = _dh_by_width(group[~use_wrong])
    # Dense has no naive variant, so its "naive" series is just its own curve.
    naive[name] = cases[name] if struct == "dense" else _dh_by_width(group[use_wrong])

# Main figure: one log-log panel per structured layer. Each panel shows that
# structure's "ours" curve (solid) and "naive" curve (dashed) in the structure's
# colour, plus the dense baseline (solid, dense colour) for reference.
figsize = (14, 7)
dpi = 100
structs = ["BTT", "Kron", "Monarch", "Low Rank"]

fig, axs = plt.subplots(nrows=1, ncols=len(structs), sharex="all", sharey="all", figsize=figsize, dpi=dpi)

for ax, struct in zip(axs, structs):
    ax.set_title(struct)
    # (y-values, colour, line style, legend label) for each curve on this panel.
    series = (
        (cases[struct], color_map[struct], "solid", "ours"),
        (naive[struct], color_map[struct], "dashed", "naive"),
        (cases["Dense"], color_map["Dense"], "solid", "dense"),
    )
    for y, color, linestyle, series_label in series:
        ax.scatter(x, y, color=color, lw=4)
        ax.plot(x, y, linestyle=linestyle, color=color, label=series_label)

# Every panel carries the same three labelled curves. Keep the last panel's
# handles/labels at module level for the stand-alone legend figure.
handles, labels = axs[-1].get_legend_handles_labels()

axs[0].set_ylabel(r'$\Delta h$')
# Set scales and limits on the axes explicitly instead of via pyplot's
# implicit "current axes" (the axes are shared, so the y-limit propagates).
for ax in axs:
    ax.set_xscale('log')
    ax.set_yscale('log')
axs[0].set_ylim(0.001, 0.15)

# Shared x-label centred under the whole row; the default ha="left" would make
# the text start at the centre instead.
fig.text(0.5, 0.01, "Width", ha="center")
fig.tight_layout()
os.makedirs("./figures", exist_ok=True)
fig.savefig('./figures/width_vs_dh.pdf')
plt.show()

# Stand-alone legend strip: an axis-less figure holding only the legend built
# from the main figure's `handles`/`labels`, with all entries in a single row.
legend_fig = plt.figure(figsize=(14, 1))
legend_ax = legend_fig.add_subplot(1, 1, 1)
legend_ax.set_axis_off()
n_entries = len(labels)
legend_ax.legend(handles=handles, labels=labels, loc='center', ncol=n_entries)
legend_fig.tight_layout()
legend_fig.savefig('./figures/width_vs_dh_legend.pdf', bbox_inches='tight')
plt.show()
