import matplotlib.pyplot as plt
import os
import torch
torch.set_printoptions(precision=8)

# Compare training-loss curves of the NGD, SGD and Adam runs stored in the
# "result" folder next to this script, save the figure there, and print the
# final L2 errors of the three runs.

# All inputs and the output image live in <script dir>/result.
result_dir = os.path.join(
    os.path.dirname(os.path.abspath(os.path.realpath(__file__))), "result"
)
path_ngd = os.path.join(result_dir, "20250323002504_ngd.pth")
path_sgd = os.path.join(result_dir, "20250322235303_sgd.pth")
path_adam = os.path.join(result_dir, "20250322233347_adam.pth")
path_save = os.path.join(result_dir, "loss_heat_1d.png")

# Number of epochs between consecutive entries of "loss_history".
# NOTE(review): inferred from the original hard-coded ranges (step 10,
# starting at epoch 10) — confirm against the training script.
LOG_INTERVAL = 10


def _load_run(path):
    """Load one saved run and return ``(state_dict, loss_history, l2_error)``.

    ``map_location="cpu"`` lets runs saved from a GPU load on a CPU-only
    machine and yields CPU tensors that matplotlib can plot.
    """
    state = torch.load(path, map_location="cpu")
    return state, state["loss_history"], state["l2_error"]


def _epoch_axis(loss_history, interval=LOG_INTERVAL):
    """Return x-coordinates ``[interval, 2*interval, ...]``, one per loss entry.

    Deriving the axis from the history length (instead of hard-coding it)
    keeps x and y the same size for runs of any length.
    """
    return list(range(interval, interval * (len(loss_history) + 1), interval))


static_dict_ngd, loss_history_ngd, l2_error_ngd = _load_run(path_ngd)
static_dict_sgd, loss_history_sgd, l2_error_sgd = _load_run(path_sgd)
static_dict_adam, loss_history_adam, l2_error_adam = _load_run(path_adam)

x_range_ngd = _epoch_axis(loss_history_ngd)
x_range_sgd = _epoch_axis(loss_history_sgd)
x_range_adam = _epoch_axis(loss_history_adam)

plt.plot(x_range_ngd, loss_history_ngd, label="NGD")
plt.plot(x_range_sgd, loss_history_sgd, label="SGD")
plt.plot(x_range_adam, loss_history_adam, label="Adam")
# Set the log scale first, then the explicit limits.
plt.yscale("log")
plt.ylim(1e-7, 1e-1)
plt.legend()
plt.xlabel("epochs")
plt.ylabel("loss")
plt.savefig(path_save)
plt.show()

print(l2_error_ngd, l2_error_sgd, l2_error_adam)