from pathlib import Path

from matplotlib import pyplot as plt
import pandas as pd
import seaborn as sns

# Load all learning-rate sweep runs and drop the dense baseline, which is not
# plotted in this figure.
runs = pd.read_csv("./logs/lr_land.csv")
runs = runs[runs['struct'] != 'dense']

# Split by multiplier scheme: rows with ``use_wrong_mult`` set are the naive
# runs, the remaining rows are the structure-aware ("ours") runs.
# ``.copy()`` gives each subset its own data, so the column assignments below
# modify an independent frame rather than a filtered slice of ``runs``
# (which pandas flags with SettingWithCopyWarning).
wrong_mult = runs['use_wrong_mult']
ours_df = runs[~wrong_mult].copy()
naive_df = runs[wrong_mult].copy()

# Capitalize structure names for display (e.g. 'low_rank' -> 'Low_rank').
ours_df['struct'] = ours_df['struct'].str.capitalize()
naive_df['struct'] = naive_df['struct'].str.capitalize()
# One figure column per distinct structure in the structure-aware runs.
num_struct = len(ours_df['struct'].unique())
# Shared limits and ticks for every panel. The axes are created with shared x
# and y, so all panels end up with the same tick positions and labels.
ylims = [30, 70]
xlims = [1e-4, 5e-1]
xticks = [1e-3, 1e-2, 1e-1]
xtick_labels = [r"$10^{-3}$", "", r"$10^{-1}$"]
yticks = [3e1, 4e1, 5e1, 6e1, 7e1]
# The outermost y ticks (30 and 70) are deliberately left unlabeled.
ytick_labels = ["", "40", "50", "60", ""]
# Learning rate marked by a dashed gray vertical line in every panel.
ref_lr = 3e-3
struct_map = {"Btt": "BTT", "Kron": "Kron", "Low_rank": "Low Rank", "Monarch": "Monarch"}

# Apply the seaborn theme *before* creating the figure: style rcParams such as
# the grid and axes face colour are read when an Axes is created, so calling
# sns.set() afterwards leaves already-created axes only partly restyled.
sns.set(style="whitegrid", font_scale=3, rc={"lines.linewidth": 4.0})
fig, axes = plt.subplots(
    2, num_struct, sharex="all", sharey="all", figsize=(num_struct * 4.5, 12), dpi=100
)
# With num_struct == 1, plt.subplots returns a 1-D array; force shape (2, num_struct).
axes = axes.reshape(2, -1)

# Row 0: naive runs; row 1: structure-aware runs. Each panel plots
# test_error vs. lr with one line per width. Only the structure-aware panels
# draw a legend, and only so its handles/labels can be reused for the
# standalone legend figure.
panel_rows = [(naive_df, 'Naive', False), (ours_df, 'Struct-Aware', 'auto')]
handles, labels = [], []
for col, struct in enumerate(ours_df['struct'].unique()):
    for row, (df, row_label, legend_mode) in enumerate(panel_rows):
        ax = axes[row, col]
        sns.lineplot(
            data=df[df['struct'] == struct],
            x='lr',
            y='test_error',
            hue='width',
            ax=ax,
            legend=legend_mode,
        )
        if row == 0:
            # Fall back to the raw name for structures without a display name.
            ax.set_title(struct_map.get(struct, struct))
        ax.set_xscale('log')
        ax.set_yscale('log')
        # Per-panel axis labels are replaced by figure-level text below; only
        # the first column keeps a row label.
        ax.set_xlabel('')
        ax.set_ylabel(row_label if col == 0 else '')
        ax.set_xlim(xlims)
        ax.set_ylim(ylims)
        # Fix tick positions first, then label them: fixed labels are only
        # meaningful against fixed tick locations.
        ax.set_xticks(xticks)
        ax.set_xticklabels(xtick_labels)
        ax.set_yticks(yticks)
        ax.set_yticklabels(ytick_labels)
        ax.axvline(x=ref_lr, color='gray', linestyle='--')
        legend = ax.get_legend()
        if legend is not None:
            handles, labels = ax.get_legend_handles_labels()
            legend.remove()

# Figure-level axis titles shared by all panels.
fig.text(0.45, 0.005, "Learning Rate")
fig.text(0.00, 0.45, "Test Error", rotation=90)
fig.tight_layout()
# savefig does not create missing directories.
Path('./figures').mkdir(parents=True, exist_ok=True)
fig.savefig('./figures/lr_landscape.pdf')
plt.show()

# Standalone single-row legend (one entry per hue level, i.e. per width),
# built from the handles/labels taken from a structure-aware panel and saved
# as its own PDF.
legend_fig, legend_ax = plt.subplots(figsize=(8, 1))
legend_ax.set_axis_off()
legend_ax.legend(
    handles=handles,
    labels=labels,
    loc='center',
    ncol=len(labels),
)
legend_fig.tight_layout()
legend_fig.savefig('./figures/lr_landscape_legend.pdf', bbox_inches='tight')
