import matplotlib.pyplot as plt
import seaborn as sns
from hip.utils import setup_seaborn
import numpy as np
import os
import pandas

# Apply the project's shared seaborn/matplotlib styling before any figure is
# created. NOTE(review): `axis_below=True` presumably keeps grid/axis lines
# underneath the plotted artists -- confirm against hip.utils.setup_seaborn.
setup_seaborn(axis_below=True)

# Perplexity table. Each literal row below is one dense-layer setting and
# holds the three values that end up on the three plotted lines. The final
# transpose flips it to shape (3, 12) so that `data[i]` is one complete line.
data = np.array([
    [5.841, 5.699, 5.665],
    [5.824, 5.669, 5.666],
    [5.826, 5.686, 5.667],
    [5.785, 5.616, 5.590],
    [5.753, 5.578, 5.542],
    [5.751, 5.577, 5.540],
    [5.717, 5.532, 5.482],
    [5.692, 5.488, 5.409],
    [5.671, 5.459, 5.373],
    [5.635, 5.402, 5.312],
    [5.600, 5.355, 5.257],
    [5.597, 5.351, 5.254],
]).T

# X positions (number of dense layers) shared by every line; there is one
# entry per column of `data`.
xs = [0, 1, 2, 3, 4, 5, 10, 15, 20, 25, 30, 32]

plt.figure(figsize=(3.2, 2.0))

# One line per row of `data`, labelled in row order.
series_labels = ('$T$=4k', '$T$=8k', '$T$=12k')
for series_label, ppl_row in zip(series_labels, data):
    sns.lineplot(x=xs, y=ppl_row, label=series_label, marker='o', zorder=10)

# Text callouts sit near the bottom of the plot, anchored to the smallest
# plotted value; the middle one is lifted by 0.1 on the y axis.
ppl_floor = np.min(data)
callouts = (
    ('$l_d=0$\nFull HiP', (0.5, ppl_floor), {}),
    ('$l_d=3$\nDefault', (3.5, ppl_floor + 0.1), {}),
    ('$l_d=32$\nFull Dense', (31.5, ppl_floor), {'ha': 'right'}),
)
for text, anchor, extra_kwargs in callouts:
    plt.annotate(text, anchor, fontsize=7, zorder=20, **extra_kwargs)

# Dashed vertical guides at the three annotated settings, drawn with a low
# zorder so they stay behind the lines and the text.
guide_lines = ((0, '#888'), (3, '#89F'), (32, '#F88'))
for x_pos, colour in guide_lines:
    plt.axvline(x_pos, linestyle='--', color=colour, zorder=1)

plt.xlabel('# of Dense Layers')
plt.ylabel('PPL.↓')
plt.title('Perplexity on Wikitext2 / Number of Dense Layers ($l_d$)')

# Write the current figure as PNG and PDF into the output directory, creating
# the directory first if it does not exist yet.
root = './saves/plot_ablation_ld/'
os.makedirs(root, exist_ok=True)

png_path = os.path.join(root, 'plot_ablation_ld.png')
pdf_path = os.path.join(root, 'plot_ablation_ld.pdf')

# Only the raster output takes an explicit DPI.
plt.savefig(png_path, dpi=300, bbox_inches='tight', pad_inches=0.1)
plt.savefig(pdf_path, bbox_inches='tight', pad_inches=0.1)

# Report the written files (PDF first, then PNG).
for written_path in (pdf_path, png_path):
    print('saved', written_path)