import os
import matplotlib.pyplot as plt
import matplotlib_inline.backend_inline
from utils.utils import load_variable

def use_svg_display():
    """Render inline (Jupyter) matplotlib figures as SVG vector graphics."""
    inline_backend = matplotlib_inline.backend_inline
    inline_backend.set_matplotlib_formats('svg')

def set_figsize(figsize=(3.5, 2.5), fontsize=12):
    """Enable SVG inline output and set the global default figure and font sizes.

    Args:
        figsize: ``(width, height)`` stored in ``rcParams['figure.figsize']``.
        fontsize: Value stored in ``rcParams['font.size']``.
    """
    use_svg_display()
    # Assign each key individually so matplotlib validates every value.
    for rc_key, rc_value in (('figure.figsize', figsize), ('font.size', fontsize)):
        plt.rcParams[rc_key] = rc_value

# Apply the default plot styling (SVG inline output, figure size, font size)
# once at import time; plt.rcParams is global, so figures created afterwards
# in this process pick these defaults up.
set_figsize()

def _epoch_axis(epochs, series):
    """Return x values for *series*: *epochs* if it aligns one-to-one, else indices."""
    try:
        if len(epochs) == len(series):
            return epochs
    except TypeError:
        # epochs is None or not a sized sequence: fall back to sample indices,
        # which is what matplotlib uses when no x values are passed.
        pass
    return range(len(series))


def plot_training_curves(save_dir, filename, *, show=True):
    """Plot the training loss and gradient norm and save the figure to disk.

    Loads ``loss_list.pkl``, ``grad_norm_list.pkl`` and ``epoch_list.pkl``
    from *save_dir* via ``load_variable``, draws both series on a new figure
    with a log-scaled y-axis, and saves it as ``<save_dir>/figures/<filename>``
    at 500 dpi. The ``figures/`` directory is created if it does not exist.

    Args:
        save_dir: Directory holding the pickled training histories.
        filename: Name of the output image; its extension selects the format
            used by ``plt.savefig``.
        show: If True (the default), call ``plt.show()`` after saving.
    """
    loss_list = load_variable(os.path.join(save_dir, "loss_list.pkl"))
    grad_norm_list = load_variable(os.path.join(save_dir, "grad_norm_list.pkl"))
    epoch_list = load_variable(os.path.join(save_dir, "epoch_list.pkl"))

    # Draw on a fresh figure. On non-interactive backends (e.g. Agg),
    # plt.show() does not reset the current figure, so repeated calls would
    # otherwise stack curves onto the previous plot.
    plt.figure()
    # The x-axis is labelled "Epoch", so use the recorded epochs as x values
    # whenever they line up with the series being plotted.
    plt.plot(_epoch_axis(epoch_list, loss_list), loss_list,
             '-', lw=2, label='Training Loss')
    plt.plot(_epoch_axis(epoch_list, grad_norm_list), grad_norm_list,
             '-', lw=2, label='Gradient Norm')

    plt.ylabel("Loss")
    plt.xlabel("Epoch")
    plt.yscale('log')

    # Place the legend just outside the axes, to the right of the plot.
    plt.legend(bbox_to_anchor=(1.01, 1), loc='upper left')

    save_fig_dir = os.path.join(save_dir, "figures/")
    os.makedirs(save_fig_dir, exist_ok=True)
    # bbox_inches='tight' sizes the saved image so the outside legend is kept.
    plt.savefig(os.path.join(save_fig_dir, filename), dpi=500, bbox_inches='tight')

    if show:
        plt.show()