

import matplotlib.pyplot as plt


def save_loss_plot(losses, plots_path, client):
    """Plot training (and, if recorded, validation) loss and save it as a PNG.

    The curves are drawn on pyplot's current figure, written to
    ``<plots_path>/<client>_loss.png``, and the figure is cleared afterwards.

    Args:
        losses (dict): Loss series to plot. ``"tloss_e"`` is the required
            sequence of training losses. ``"vloss_e"`` is the sequence of
            validation losses; it is skipped when missing or empty.
        plots_path (str): Directory in which the image is saved.
        client: Identifier converted with ``str()`` and used as the prefix
            of the output file name.
    """
    # Function-scope import keeps this block self-contained.
    import os

    train_losses = losses["tloss_e"]
    val_losses = losses.get("vloss_e", [])

    try:
        plt.plot(list(range(len(train_losses))), train_losses,
                 c='r', label="Training")
        title = "Training"

        if len(val_losses) >= 1:
            # Validation loss may be recorded less often than training loss,
            # so stretch its indices by the length ratio to place both
            # curves on the same epoch scale.
            beta = len(train_losses) / len(val_losses)
            val_x = [beta * i for i in range(len(val_losses))]
            plt.plot(val_x, val_losses, c='b', label="Validation")
            title += " and Validation"

        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.legend(loc="upper right")
        plt.title(title + " Loss", fontsize=12)
        plt.savefig(os.path.join(plots_path, str(client) + "_loss.png"))
    finally:
        # Always clear the current figure, even when plotting or saving
        # fails, so curves from this call cannot leak into the next plot
        # drawn on the shared pyplot figure.
        plt.clf()
