import os
import pickle
import torch
import matplotlib.pyplot as plt
import numpy as np

# Directory (not a single file) that holds the saved loss histories,
# read by view_val_loss as ``loss{number}.pkl``.
results_file = "results/loss/"

def view_val_loss(number, *, title='2D Frequency=1 Resolution=40~80',
                  save_path="results/val_loss_plot/loss_resolution40-80.jpg",
                  close=True):
    """Plot train and validation loss curves from a saved loss history.

    Reads ``os.path.join(results_file, f'loss{number}.pkl')`` with
    ``torch.load`` (mapped onto the CPU). Each record is indexed as
    ``record[0]`` = epoch, ``record[1]`` = validation loss and
    ``record[2]`` = train loss. Each value may be a Python number, a NumPy
    scalar or a one-element tensor. The number of records is printed. The
    figure, annotated with the final train and validation losses, is saved
    to *save_path*.

    Args:
        number: Suffix of the loss file to read (``loss{number}.pkl``).
        title: Plot title. Defaults to the previously hard-coded title.
        save_path: Output image path. Its parent directory is created if it
            does not exist.
        close: Close the figure after saving, so repeated calls do not
            accumulate open figures. Pass ``False`` to keep it open, for
            example for interactive display.

    Raises:
        FileNotFoundError: If the loss file does not exist.
        ValueError: If the loss file contains no records.
    """
    loss_path = os.path.join(results_file, f'loss{number}.pkl')
    with open(loss_path, 'rb') as file:
        loss = torch.load(file, map_location='cpu')
    length = len(loss)
    print(length)
    if length == 0:
        # Without this guard, train_loss[-1] below fails with an opaque IndexError.
        raise ValueError(f'{loss_path} contains no loss records')

    # float() works for Python numbers, NumPy scalars and one-element
    # tensors. The previous .item() call failed on a plain Python float.
    epochs = np.array([float(record[0]) for record in loss])
    val_loss = np.array([float(record[1]) for record in loss])
    train_loss = np.array([float(record[2]) for record in loss])

    final_train_loss = train_loss[-1]
    final_val_loss = val_loss[-1]

    fig = plt.figure(figsize=(10, 6))
    try:
        plt.plot(epochs, train_loss, label='Train Loss', linestyle='-', linewidth=2)
        plt.plot(epochs, val_loss, label='Val Loss', linestyle='--', color='orange', linewidth=2)

        # Annotations sit 100 x-units left of the last epoch, with small fixed
        # y offsets so the two labels do not overlap. 'C0' matches the default
        # colour of the first (train) line.
        plt.text(epochs[-1] - 100, final_train_loss + 0.02, f'Final Train Loss: {final_train_loss:.4f}',
                 verticalalignment='bottom', horizontalalignment='left', weight='bold', color='C0', fontsize=12)
        plt.text(epochs[-1] - 100, final_val_loss + 0.08, f'Final Val Loss: {final_val_loss:.4f}',
                 verticalalignment='top', horizontalalignment='left', weight='bold', color='orange', fontsize=12)
        plt.title(title, fontsize=16)
        plt.xlabel('Epoch', fontsize=14)
        plt.ylabel('Loss', fontsize=14)
        plt.tick_params(axis='both', labelsize=12, width=2, colors='black')
        plt.grid(True)
        plt.legend(loc='upper right', fontsize=12, frameon=True, facecolor='white',
                   edgecolor='black', framealpha=1, borderpad=1)
        plt.tight_layout()

        # savefig does not create missing directories, so create them first.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(save_path)
    finally:
        if close:
            plt.close(fig)

# Run only when executed as a script, so importing this module does not read
# files or write plots as a side effect.
if __name__ == "__main__":
    view_val_loss(3)