import torch
import numpy as np
from matplotlib import pyplot as plt
import matplotlib
import os

def plot_acc(num_epoch, name: str,
             train_acc=None, test_acc=None,
             train_loss=None, test_loss=None,
             root="./figures"):
    """Plot per-epoch accuracy/loss curves and save them as ``<root>/<name>.png``.

    Each of ``train_acc``, ``test_acc``, ``train_loss`` and ``test_loss`` is an
    optional sequence holding one value per epoch. A series that is ``None``
    or empty is skipped. The figure is saved at 300 dpi and always closed
    afterwards, even if validation fails.

    Args:
        num_epoch: Number of epochs; every supplied series must have this length.
        name: Plot title, also used as the output file stem.
        train_acc: Optional per-epoch training accuracy values.
        test_acc: Optional per-epoch test accuracy values.
        train_loss: Optional per-epoch training loss values.
        test_loss: Optional per-epoch test loss values.
        root: Output directory; created (including parents) if missing.

    Raises:
        AssertionError: If a supplied, non-empty series' length differs
            from ``num_epoch``.
    """
    matplotlib.rcParams['axes.unicode_minus'] = False
    # makedirs handles nested paths and a directory created concurrently.
    os.makedirs(root, exist_ok=True)

    series = [
        (train_acc, "train acc"),
        (test_acc, "test acc"),
        (train_loss, "train loss"),
        (test_loss, "test loss"),
    ]
    epochs = list(range(num_epoch))

    # Use a dedicated figure so we never draw onto a figure the caller has open.
    fig, ax = plt.subplots()
    try:
        for values, label in series:
            # Explicit None/len checks instead of truthiness: numpy arrays and
            # torch tensors raise on bool() when they have more than one element.
            if values is None or len(values) == 0:
                continue
            # Explicit raise (not `assert`) so validation survives `python -O`;
            # each series is checked against its own length.
            if len(values) != num_epoch:
                raise AssertionError(f"length of {label} does not equal to num_epochs")
            ax.plot(epochs, values, label=label)
        # Aim for ~10 ticks; clamp the step to 1 so fewer than 10 epochs
        # doesn't produce a zero step.
        ax.set_xticks(list(range(0, num_epoch, max(1, num_epoch // 10))))
        ax.set_xlabel("Epoch")
        ax.set_title(name)
        ax.legend()
        fig.savefig(os.path.join(root, f"{name}.png"), dpi=300)
    finally:
        plt.close(fig)

