import os

import matplotlib.pyplot as plt
import torch
import numpy as np
# Plot training loss against communication rounds on WebNLG for three runs
# (labelled 'Centralized LoRA', 'HETLoRA' and 'PF2LoRA'), add a zoomed inset
# over the last rounds, and save the figure as a PDF.
plt.rcParams.update({'font.size': 15, 'font.weight': 'bold'})

# Number of consecutive raw 'Centralized LoRA' loss entries averaged into one
# plotted point.
steps_per_round = 80
# NOTE(review): this constant is subtracted from every averaged
# 'Centralized LoRA' loss before plotting. The reason is not visible in this
# script -- confirm the shift is intended and disclosed wherever the figure
# is used.
homlora_offset = 0.15
# Width (in rounds) of the zoomed inset, anchored at the final round.
zoom_rounds = 50

# NOTE(review): torch.load unpickles arbitrary Python objects, so only point
# it at trusted files. Recent PyTorch releases default to weights_only=True,
# which may reject non-tensor payloads; pass weights_only=False explicitly if
# loading fails -- TODO confirm against the installed version.
loss_ours = torch.load('./outputs/webnlg/ours/lr_0.002_lr_in_0.001_het1.0_2024-06-06-13-42.pkl')
# The file holds a sequence of tensors; stack them and move the result to a
# CPU NumPy array so matplotlib gets plain numeric data.
loss_ours = torch.stack(loss_ours).detach().cpu().numpy()

# Derive the number of rounds from the data rather than hard-coding it, so
# the x axis and the re-binned baseline below always have matching lengths.
iterations = len(loss_ours)
x = list(range(iterations))

loss_homlora = torch.load('./outputs/webnlg/homolora/lr_0.001_lr_in_0.001_het1.0_2024-06-08-17-02.pkl')
loss_homlora = np.stack(loss_homlora)
# Average each window of `steps_per_round` raw entries into a single value
# per round, then apply the offset described above.
loss_homlora_new = [
    loss_homlora[i * steps_per_round:(i + 1) * steps_per_round].mean() - homlora_offset
    for i in range(iterations)
]

# Plotted as loaded; presumably already one value per round -- TODO confirm.
loss_hetlora = torch.load('./outputs/webnlg/hetlora/lr_0.002_lr_in_0.001_het1.0_2024-06-08-01-41.pkl')

fig, ax1 = plt.subplots()
# Plot order fixes both the legend order and the default colour assigned to
# each curve; the inset below repeats the same order to match colours.
ax1.plot(x, loss_homlora_new, label='Centralized LoRA', linewidth=2.0)
ax1.plot(x, loss_hetlora, label='HETLoRA', linewidth=2.0)
ax1.plot(x, loss_ours, label='PF2LoRA', linewidth=2.0)

ax1.legend()
ax1.set_xlabel('Rounds', fontweight='bold', fontsize=20)
ax1.set_ylabel('Loss', fontweight='bold', fontsize=20)
ax1.tick_params(axis='both', labelsize=18)
ax1.set_title('Training loss vs. communication rounds', fontweight='bold', fontsize=18)
ax1.grid(linestyle='-.')

# Range shown in the zoomed inset: the last `zoom_rounds` rounds.
zoom_ylim = (1.5, 2.5)
zoom_xlim = (max(iterations - zoom_rounds, 0), iterations)

# Create the inset axes for the zoomed view.
ax_zoom = fig.add_axes([0.56, 0.3, 0.3, 0.3])  # [left, bottom, width, height]
for series in (loss_homlora_new, loss_hetlora, loss_ours):
    ax_zoom.plot(x, series, linewidth=2.0)
ax_zoom.set_xlim(zoom_xlim)
ax_zoom.set_ylim(zoom_ylim)

# Make sure the output directory exists so saving does not fail on a fresh
# checkout.
output_path = 'figures/loss_webnlg.pdf'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
fig.savefig(output_path, transparent=True, bbox_inches="tight")