"""Plotting utilities for loss logs."""
from pathlib import Path
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas


def plot_loss_log(
        loss_log: list[float],
        label: str,
        output_path: Path,
        ):
    """Plot one loss curve against its step index and save it to disk.

    Args:
        loss_log: Loss values; entry ``i`` is drawn at x = ``i``.
        label: Legend label for the curve.
        output_path: Destination handed to ``Figure.savefig``.
    """
    # Build the figure through the object-oriented API (no pyplot state) and
    # attach an Agg canvas to it explicitly.
    figure = Figure()
    FigureCanvas(figure)
    axes = figure.add_subplot()

    steps = list(range(len(loss_log)))
    axes.plot(steps, loss_log, label=label)
    axes.set_xlabel('Step')
    axes.set_ylabel('Loss')

    figure.legend()
    figure.savefig(output_path)


def plot_loss_logs(
        loss_logs: list[list[float]],
        labels: list[str],
        output_path: Path,
        *,
        xlabel: str = 'Step',
        ylabel: str = 'Loss',
        ):
    """Plot several loss curves on shared axes and save the figure.

    Each curve is drawn against its own step index (entry ``i`` at x = ``i``).

    Args:
        loss_logs: One sequence of loss values per curve.
        labels: Legend label for each curve, in the same order as
            ``loss_logs``.
        output_path: Destination handed to ``Figure.savefig``.
        xlabel: X-axis label. Defaults to ``'Step'``, matching
            ``plot_loss_log``.
        ylabel: Y-axis label. Defaults to ``'Loss'``, matching
            ``plot_loss_log``.

    Raises:
        ValueError: If ``loss_logs`` and ``labels`` differ in length.
    """
    # zip() would silently drop the unmatched tail; fail loudly instead so a
    # missing curve or label is noticed.
    if len(loss_logs) != len(labels):
        raise ValueError(
            f'Got {len(loss_logs)} loss logs but {len(labels)} labels')

    # Build the figure through the object-oriented API (no pyplot state) and
    # attach an Agg canvas to it explicitly.
    fig = Figure()
    _ = FigureCanvas(fig)
    ax = fig.add_subplot()

    # Plot each loss log against its step index
    for loss_log, label in zip(loss_logs, labels):
        x = list(range(len(loss_log)))
        ax.plot(x, loss_log, label=label)

    # x holds step indices and y holds the logged values; the previous
    # hard-coded labels had these swapped.
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    # Save figure
    fig.legend()
    fig.savefig(output_path)
