import matplotlib.pyplot as plt
from matplotlib.lines import Line2D


def plot_curve(x_values, y_values, y_values2=None, xlabel="x Value", ylabel="y Value"):
    """Plot one curve, and optionally a second one, against shared x values.

    Args:
        x_values: x coordinates used for every curve.
        y_values: y coordinates of the primary curve (solid blue line).
        y_values2: optional y coordinates of a second curve (dashed blue
            line). ``None`` or an empty sequence means "no second curve".
        xlabel: label of the x axis.
        ylabel: label of the y axis.

    Returns:
        The matplotlib figure holding the plot.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.plot(
        x_values,
        y_values,
        ls="-",
        marker=".",
        color="blue",
        alpha=0.8,
        lw=2,
    )
    # Test for None/emptiness explicitly: the truth value of a multi-element
    # NumPy array (or pandas Series) is ambiguous and raises ValueError.
    if y_values2 is not None and len(y_values2) > 0:
        ax.plot(
            x_values,
            y_values2,
            ls="--",
            marker=".",
            color="blue",
            alpha=0.8,
            lw=2,
        )
    ax.set_xlabel(xlabel, fontsize=18)
    ax.set_ylabel(ylabel, fontsize=18)

    ax.grid(True)

    fig.tight_layout()

    return fig


def plot_loss_curve(
    epochs,
    train_loss,
    test_loss,
    train_metric=None,
    test_metric=None,
    ylabel="Loss",
    metric_name="Accuracy",
):
    """Plot train/test loss per epoch, optionally with a metric on a twin axis.

    The train loss is drawn dashed and the test loss solid. When both
    ``train_metric`` and ``test_metric`` are supplied, they are drawn in red
    on a secondary y axis and the loss curves switch from black to blue so
    the two quantities can be told apart.

    Args:
        epochs: x coordinates shared by every curve.
        train_loss: loss values for the training split.
        test_loss: loss values for the test split.
        train_metric: optional metric values for the training split.
        test_metric: optional metric values for the test split.
        ylabel: label of the primary (loss) y axis and its legend entry.
        metric_name: label of the secondary y axis and its legend entry.

    Returns:
        The matplotlib figure holding the plot.
    """
    # Test for None/emptiness explicitly: the truth value of a multi-element
    # NumPy array (or pandas Series) is ambiguous and raises ValueError.
    has_train_metric = train_metric is not None and len(train_metric) > 0
    has_test_metric = test_metric is not None and len(test_metric) > 0
    show_metric = has_train_metric and has_test_metric

    # Colour only carries meaning when a second quantity is actually drawn,
    # so it is keyed on the same condition as the twin axis.
    color_1 = "blue" if show_metric else "k"
    color_2 = "red"

    fig, ax = plt.subplots(figsize=(12, 8))
    ax.plot(
        epochs,
        train_loss,
        ls="--",
        color=color_1,
        alpha=0.4,
        lw=2,
    )
    ax.plot(
        epochs,
        test_loss,
        ls="-",
        color=color_1,
        alpha=0.8,
        lw=2,
    )
    ax.set_xlabel("Epochs", fontsize=18)
    ax.set_ylabel(ylabel, fontsize=18)

    # Proxy handles for the legend; the line styles must mirror the curves
    # drawn above (train is the broken line, test the solid one).
    legend_handles = [
        Line2D([0], [0], ls="--", color="k", lw=1),
        Line2D([0], [0], ls="-", color="k", lw=1),
    ]
    legend_labels = ["Train", "Test"]

    if show_metric:
        ax2 = ax.twinx()
        ax2.plot(
            epochs,
            train_metric,
            ls="-.",
            color=color_2,
            alpha=0.4,
            lw=2,
        )
        ax2.plot(
            epochs,
            test_metric,
            ls="-",
            color=color_2,
            alpha=0.8,
            lw=2,
        )
        ax2.set_ylabel(metric_name, fontsize=18)

        # Colour swatches telling the loss curves from the metric curves.
        legend_handles.extend(
            [
                Line2D([0], [0], color=color_1, lw=4),
                Line2D([0], [0], color=color_2, lw=4),
            ]
        )
        legend_labels.extend([ylabel, metric_name])

    # NOTE: pyplot acts on the current axes, so when a twin axis was created
    # above the grid is drawn on that axis rather than on the loss axis.
    plt.grid(True)
    ax.legend(legend_handles, legend_labels, fontsize=18)
    plt.tight_layout()

    return fig


def plot_aux_curve(
    epochs,
    aux_loss,
    ylabel="Aux Loss",
):
    """Plot an auxiliary loss as a dashed red line over the epochs.

    The first entry of ``epochs`` is skipped: ``aux_loss`` is plotted against
    ``epochs[1:]`` (presumably the auxiliary loss has no value for the first
    epoch -- confirm against the caller).

    Args:
        epochs: sliceable sequence of epoch values.
        aux_loss: auxiliary loss values, plotted against ``epochs[1:]``.
        ylabel: label of the y axis.

    Returns:
        The matplotlib figure holding the plot.
    """
    label_size = 18
    curve_style = {"ls": "--", "color": "red", "alpha": 0.4, "lw": 2}

    figure, axes = plt.subplots(figsize=(12, 8))
    axes.plot(epochs[1:], aux_loss, **curve_style)

    axes.set_xlabel("Epochs", fontsize=label_size)
    axes.set_ylabel(ylabel, fontsize=label_size)
    axes.grid(True)

    figure.tight_layout()
    return figure
