from functools import partial
from pathlib import Path

from matplotlib import pyplot as plt
import seaborn as sns

from experiments.fns import get_project_data
from experiments.fns import rename_text_vec

runs = get_project_data(project="lr_land", steps=[-1])
# Exclude the lr=1e-3 sweep points and the width-16 models from the figure.
runs = runs[runs["lr"] != 1e-3]
runs = runs[runs["width"] != 16]

# Raw "expr" strings of the structures to plot; each kept structure becomes one
# figure column. Commented entries are alternatives that can be toggled back in.
exprs = [
    "(0.8|0.2|0|0.2|0.8|0|0.33)",
    # "(0.5|0|0.5|0|0.5|0.5|0)",
    "(0.33|0.33|0.34|0.33|0.33|0.34|0.33)",
    "(0.5|0.5|0|0.33|0.33|0.34|0)",
    # "(0.7|0|0.3|0|0.7|0.3|0)",
]
# Take an explicit copy so the column assignment below operates on an
# independent frame rather than on a filtered slice of the previous one.
runs = runs[runs["expr"].isin(exprs)].copy()
if runs.empty:
    # DataFrame.apply(axis=1) on an empty frame may return a DataFrame rather
    # than a Series, making the assignment below fail with an unrelated-looking
    # error; fail with a clear message instead.
    raise ValueError("No runs left to plot after filtering on lr, width and expr.")
# Replace each raw "expr" with whatever rename_text_vec returns for that row
# (presumably a display label -- these values become the column titles).
runs["expr"] = runs.apply(partial(rename_text_vec, key="expr"), axis=1)

# Rows with use_wrong_mult set form the "naive" baseline (top figure row);
# the remaining rows are ours, labelled muP (bottom figure row).
wrong_mult = runs['use_wrong_mult']
naive_df = runs[wrong_mult]
ours_df = runs[~wrong_mult]
# One figure column per distinct structure present in our runs.
num_struct = ours_df["expr"].nunique(dropna=False)

def _style_lr_axis(ax, fontsize):
    """Apply the styling shared by both rows of the learning-rate figure.

    Puts both axes on a log scale, fixes the y ticks/labels at 30-60, draws
    dash-dot horizontal guide lines, and sets the tick-label font size for
    major and minor ticks on both axes.
    """
    ax.set_xscale('log')
    ax.set_yscale('log')
    ax.set_yticks([30, 40, 50, 60])
    ax.set_yticklabels([30, 40, 50, 60])
    for y in [35, 40, 45, 50, 55]:
        ax.axhline(y=y, color='grey', linestyle='-.', linewidth=2, alpha=0.7)
    ax.tick_params(axis='both', which='major', labelsize=fontsize)
    ax.tick_params(axis='both', which='minor', labelsize=fontsize)


if num_struct == 0:
    # plt.subplots would otherwise fail with an opaque "columns must be positive" error.
    raise ValueError("No structures to plot: ours_df has no 'expr' values.")

# Top row: naive runs; bottom row: our (muP) runs; one column per structure.
# squeeze=False keeps `axes` 2-D even when there is a single structure.
fig, axes = plt.subplots(2, num_struct, figsize=(num_struct * 6, 8), dpi=100, squeeze=False)
ylims = None  # y-limits for the top row; None keeps matplotlib's current limits
fontsize = 25
# seaborn normalizes a numeric hue per call, so panels whose subsets span
# different widths would color the same width differently (and disagree with
# the shared legend). Use one normalization range for every panel.
all_widths = list(ours_df['width']) + list(naive_df['width'])
hue_norm = (min(all_widths), max(all_widths))
handles, labels = [], []
for i, struct in enumerate(ours_df["expr"].unique()):
    ax = axes[0, i]
    sns.lineplot(data=naive_df[naive_df["expr"] == struct], x='lr', y='test_error', hue='width',
                 hue_norm=hue_norm, ax=ax, legend=False, linewidth=3.0)
    _style_lr_axis(ax, fontsize)
    ax.set_title(f'{struct}', fontsize=fontsize, pad=20)
    if i == 0:
        ax.set_ylabel('Test Error (Naive)', fontsize=fontsize)
    else:
        ax.set_ylabel('')
    ax.set_xlabel('')
    ax.set_ylim(ylims)

    ax = axes[1, i]
    sns.lineplot(data=ours_df[ours_df["expr"] == struct], x='lr', y='test_error', hue='width',
                 hue_norm=hue_norm, ax=ax, linewidth=3.0)
    _style_lr_axis(ax, fontsize)
    ax.set_ylim([30, 60])
    # NOTE(review): vertical reference at lr=3e-3; presumably a reference learning rate -- confirm.
    ax.axvline(x=0.003, color='grey', linestyle='-.', linewidth=2, alpha=0.7)
    if i == 0:
        ax.set_ylabel(r'Test Error ($\mu$P)', fontsize=fontsize)
    else:
        ax.set_ylabel('')
    ax.set_xlabel('Learning Rate', fontsize=fontsize)
    # Keep this panel's legend entries for the standalone legend figure, then
    # drop the in-axes legend (it may be absent, e.g. when there is no data).
    handles, labels = ax.get_legend_handles_labels()
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()

fig.tight_layout()
figures_dir = Path('./figures')
figures_dir.mkdir(parents=True, exist_ok=True)  # savefig does not create missing directories
fig.savefig(figures_dir / 'lr_landscape.pdf')
plt.show()

# Separate, axis-less figure holding only the width legend (the handles/labels
# captured from the last bottom-row panel), laid out as a single row.
legend_fig, legend_ax = plt.subplots(figsize=(8, 1))
legend_ax.set_axis_off()
legend_ax.legend(handles=handles, labels=labels, loc='center', ncol=len(labels))
legend_fig.tight_layout()
legend_fig.savefig('./figures/lr_landscape_legend.pdf', bbox_inches='tight')
