import os
import pickle
import matplotlib.pyplot as plt
import numpy as np

# Experiment identifiers; together they select the output directory to read from.
method = 'dekr'
dataset = 'coco'
save_path = f"output/{dataset}/{method}_{dataset}"

# The pickle holds six sequences, in this order: L1 loss, order loss and
# accuracy for training, followed by the same three for testing.
# NOTE(review): pickle.load executes arbitrary code from the file; only point
# this at logs produced by a trusted run.
history_file = os.path.join(save_path, 'loss_acc.pkl')
with open(history_file, "rb") as f:
    history = pickle.load(f)
(train_l1_losses, train_order_losses, train_acc,
 test_l1_losses, test_order_losses, test_acc) = history

# Number of logged entries, used as the epoch count on the x-axis.
epochs = len(train_l1_losses)

fig, ax1 = plt.subplots(figsize=(6, 5))

# Epoch numbers are 1-based on the x-axis; build them once for all four curves.
epoch_numbers = list(range(1, epochs + 1))

# Line/marker styling shared by the train curves and by the test curves.
train_style = dict(ls='-', marker='*', ms=4)
test_style = dict(ls='--', marker='o', ms=4)

# Left y-axis: rank loss.
ax1.set_xlabel('Epoch', size=10)
ax1.set_ylabel('Rank Loss')
ax1.plot(epoch_numbers, train_order_losses, color='red', label='train rank loss', **train_style)
ax1.plot(epoch_numbers, test_order_losses, color='blue', label='test rank loss', **test_style)

# Right y-axis (sharing the same x-axis): rank accuracy.
ax2 = ax1.twinx()
ax2.set_ylabel('Rank Acc.')
ax2.plot(epoch_numbers, train_acc, color='green', label='train rank acc.', **train_style)
ax2.plot(epoch_numbers, test_acc, color='purple', label='test rank acc.', **test_style)


# pyplot's grid() acts on the current axes (presumably ax2, the axes created
# last by twinx() -- confirm if the grid should follow the loss ticks instead).
plt.grid(True)

# Each axes only knows its own curves, so collect the handles and labels from
# both and draw a single legend that covers all four lines.
lines1, labels1 = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()
lines = lines1 + lines2
labels = labels1 + labels2
ax1.legend(lines, labels, prop={'size': 20}, loc='center right')

# Name the image after the configured dataset instead of a hard-coded 'coco',
# so the filename stays in step with save_path when `dataset` is changed.
fig.savefig(os.path.join(save_path, f'{method}_{dataset}.png'), dpi=fig.dpi, bbox_inches='tight')