import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Which training run to plot (e.g. "large" or "medium"); it selects the input
# log below and is reused in the figure title and output file name.
size = "large"

# Training log for the selected run. It is expected to contain a "step" column
# plus one validation-loss column per method that is plotted further down.
path = f"./plots/train2-log-{size}.csv"
df_orig = pd.read_csv(path)

# Plot from a copy so that later in-place edits leave the raw log untouched.
df = df_orig.copy()

# Use Seaborn's default look for every figure drawn afterwards.
sns.set()

# Start from the colour-blind-safe palette, then reorder it. Series are drawn
# in order, so the k-th plotted series gets original_palette[myorder[k]].
# (A three-series variant of this figure used the order [2, 1, 0].)
original_palette = sns.color_palette("colorblind")
myorder = [2, 3, 1, 0]
shuffled_palette = [original_palette[idx] for idx in myorder]

# Install the reordered colours as the active colour cycle.
sns.set_palette(shuffled_palette)

# Step at which the SWA curve starts. Earlier SWA values are hidden, and a
# dashed marker (annotated "75% training") is drawn at this step.
SWA_START_STEP = 50000

# Blank out SWA values before SWA_START_STEP (pandas stores them as NaN) so
# the SWA line is only drawn from that step onward.
df.loc[df["step"] < SWA_START_STEP, "SWA"] = None

ax = df.plot(x="step", y=["Original", "EMA", "SWA", "LAWA (Ours)"])

# Other runs of this figure used ylim (2.85, 3.25) or (2.85, 3.5).
ax.set_ylim(2.75, 3.3)
ax.set_ylabel("Validation Loss")

# The "step" column holds raw step counts (compare the 50000 threshold above).
# Render tick labels in thousands so they match the "(K)" unit in the label.
ax.xaxis.set_major_formatter(lambda x, _pos: f"{x / 1000:g}")
ax.set_xlabel("Training steps (K)")

# NOTE(review): 355M is the parameter count of the standard GPT-2 *medium*
# model; confirm it is right for the configured `size` before publishing.
ax.set_title("(a) GPT-2 " + size + " (355M)")

ax.legend()

# Mark where SWA starts. This is drawn after legend(), so it gets no entry.
ax.axvline(x=SWA_START_STEP, color="r", linestyle="--")
ax.text(SWA_START_STEP - 1000, 3.1, "75% training",
        fontsize=12, ha="center", rotation=90)

path = "./plt_pdfs/GPT2-" + size + "-mainfig.pdf"
# savefig does not create missing directories, so make sure the target exists.
os.makedirs(os.path.dirname(path), exist_ok=True)
plt.savefig(path, dpi=600, bbox_inches="tight")
plt.show()
print("Saved: ", path)
