import matplotlib.pyplot as plt
import os
import json

# Plot per-checkpoint eval accuracy and training loss for one fine-tuning run
# side by side, and save the figure as <output_dir>/acc_loss.png.
model = 'llama2'

# Layout read by this script:
#   <run_dir>/trainer_log.jsonl                                  (one JSON record per line)
#   <run_dir>/eval_results/checkpoint-<N>/predict_results.json
run_dir = f'ckpt/gsm8k-0805/{model}-0_shot'
output_dir = f'{run_dir}/eval_results'
loss_file = f'{run_dir}/trainer_log.jsonl'

# Maximum number of trainer-log records plotted on the loss curve.
max_loss_points = 50

# --- Accuracy: one point per evaluated checkpoint, in numeric order ----------
# Only directories named "checkpoint-<integer>" are used. Anything else in
# eval_results (stray folders, the saved PNG) is skipped instead of crashing
# the integer parse or pointing at a non-existent checkpoint directory.
checkpoints = []
for item in os.listdir(output_dir):
    prefix, _, suffix = item.partition('-')
    if prefix != 'checkpoint' or not suffix.isdecimal():
        continue
    if os.path.isdir(os.path.join(output_dir, item)):
        # Keep the real directory name so we open exactly what we found.
        checkpoints.append((int(suffix), item))
checkpoints.sort()
ids = [ckpt_id for ckpt_id, _ in checkpoints]

y1 = []
for ckpt_id, ckpt_dir in checkpoints:
    pred_path = os.path.join(output_dir, ckpt_dir, 'predict_results.json')
    with open(pred_path, encoding='utf-8') as f:
        pred_res = json.load(f)
    y1.append(pred_res['predict_accuracy'])
# x follows the number of checkpoints actually found; a hard-coded
# range(1, 51) made ax.plot raise unless there were exactly 50.
x1 = list(range(1, len(y1) + 1))

# --- Loss: the first `max_loss_points` records of the trainer log ------------
y2 = []
with open(loss_file, 'r', encoding='utf-8') as f:
    for line in f:
        if len(y2) >= max_loss_points:
            break
        if not line.strip():
            continue  # tolerate blank lines instead of failing in json.loads
        y2.append(json.loads(line)['loss'])
# Same reasoning as x1: the log may hold fewer than max_loss_points records.
x2 = list(range(1, len(y2) + 1))

# --- Plotting ----------------------------------------------------------------
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

# Plot 1: accuracy
ax1.plot(x1, y1, marker='o', linestyle='-', color='b')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Accuracy')
ax1.grid(True)

# Plot 2: loss
ax2.plot(x2, y2, marker='o', linestyle='-', color='r')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Loss')
ax2.grid(True)

# Human-readable figure titles. Unknown model names fall back to the raw name
# rather than being mislabelled as "Llama 3".
model_titles = {'llama2': 'Llama 2', 'llama3': 'Llama 3'}
title = model_titles.get(model, model)

fig.suptitle(title, fontsize=16)

# Save the figure next to the evaluation results.
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'acc_loss.png'))