import matplotlib.pyplot as plt
import matplotlib as mpl
from eos_line_search.experiment import *
import os
import numpy as np

# Matplotlib marker codes. Not referenced by any function in this module.
# NOTE(review): confirm whether external code imports it before removing.
plot_markers = ["o", "v", "^", "s", "p", "P", "*", "h", "X", "D"]


def _series(key):
    """Return a function mapping a run to ``np.array(run.plot_data[key])``."""
    return lambda run: np.array(run.plot_data[key])


def _moving_average(values, window):
    """Mean of every length-``window`` slice of ``values``.

    Uses ``np.convolve`` in ``"valid"`` mode, so the result is ``window - 1``
    points shorter than ``values``.
    """
    return np.convolve(values, np.ones(window) / window, mode="valid")


def _sharpness_times_step_size(window):
    """Data function: moving average of (final step size * sharpness) of a run."""

    def data(run):
        return _moving_average(
            np.array(run.plot_data["Final Step Size"])
            * np.array(run.plot_data["Sharpness"]),
            window,
        )

    return data


def _relative_error_to(approx, window):
    """Data function: moving average of ``|sharpness - approx(run)| / sharpness``."""

    def data(run):
        sharpness = np.array(run.plot_data["Sharpness"])
        return _moving_average(np.abs(sharpness - approx(run)) / sharpness, window)

    return data


def _two_over_step_size(run):
    """Return ``2 / final step size`` for every recorded step of ``run``."""
    return 2 / np.array(run.plot_data["Final Step Size"])


def _is_line_search(run):
    """True for SLS runs with any forward option except 1, and for all PoNoS runs."""
    opt = run.optimizer
    return (opt.opt_name == "SLS" and opt.forward_option != 1) or opt.opt_name == "PoNoS"


def _is_no_tune(run, opt_names=("SLS", "PoNoS")):
    """True for runs of one of ``opt_names`` using forward option 4 (the "noTune" labels)."""
    return run.optimizer.opt_name in opt_names and run.optimizer.forward_option == 4


def _is_no_tune_or_gd(run):
    """True for "noTune" SLS/PoNoS runs and for constant-step-size GD runs."""
    return _is_no_tune(run) or run.optimizer.opt_name == "constant_stepsize_GD"


def _styled(label):
    """Per-curve settings: ``label`` plus its ``select_marker``/``select_colour`` style."""
    return {
        "label": label,
        "marker": select_marker(label),
        "colour": select_colour(label),
    }


def _style_by_label(j, run):
    """Style a curve by the run's optimizer label (the run index ``j`` is unused)."""
    return _styled(setup_labels(run))


def _style_by_gd_size(j, run):
    """Like ``_style_by_label``, but GD runs at index 0/1/2 get a size suffix."""
    label = setup_labels(run)
    if run.optimizer.opt_name == "constant_stepsize_GD":
        # NOTE(review): relies on the caller putting the GD runs at positions
        # 0, 1, 2 of ``runs`` (presumably large -> small step size); a GD run
        # anywhere else keeps the plain "GD" label.
        label += {0: "-large", 1: "-medium", 2: "-small"}.get(j, "")
    return _styled(label)


def _unmarked(suffix, linestyle):
    """Style for marker-less curves labelled ``<label><suffix>``, coloured by the base label."""

    def style(j, run):
        label = setup_labels(run)
        return {
            "label": label + suffix,
            "marker": "",
            "linestyle": linestyle,
            "colour": select_colour(label),
        }

    return style


def _fixed(label, colour):
    """Style with a fixed label and colour and no marker."""
    return lambda j, run: {"label": label, "marker": "", "colour": colour}


def _plot_each(runs, keep, settings, style, data, path, metric_num, filename, hline=None):
    """Plot every run accepted by ``keep`` onto figure ``metric_num``, in order.

    ``settings`` holds the figure-wide ``plot`` settings. ``style(j, run)``
    returns the per-curve entries merged over them, where ``j`` is the run's
    index in ``runs``. ``data(run)`` returns the y-values. If ``hline`` is
    given and at least one run is selected, a dashed black horizontal line is
    drawn at that height first.
    """
    selected = [(j, run) for j, run in enumerate(runs) if keep(run)]
    if selected and hline is not None:
        # Activate this figure (same num/size that plot() uses) before drawing,
        # so the line lands here rather than on whichever figure is current.
        plt.figure(num=metric_num, figsize=(10, 5))
        plt.axhline(hline, linestyle="--", color="k")
    for j, run in selected:
        plot(run, data(run), {**settings, **style(j, run)}, path, metric_num, filename)


def plot_results(runs, path):
    """Draw and save the main-paper and appendix figures for ``runs``.

    Each figure is a distinct matplotlib figure number (0-12). ``plot`` is
    called once per selected run and re-saves the figure as
    ``<path>/plots/<model_type>/<filename>.pdf`` each time. All figures are
    closed before returning.

    Args:
        runs: sequence of run objects. Each is read for
            ``optimizer.opt_name``, ``optimizer.forward_option`` and
            ``plot_data[...]``, plus whatever ``plot`` and ``setup_labels``
            read.
        path: output root directory, passed through to ``plot``.
    """
    mpl.rcParams["axes.spines.right"] = False
    mpl.rcParams["axes.spines.top"] = False
    plt.rcParams.update(
        {
            "axes.titlesize": 22,
            "axes.labelsize": 19,
            "legend.fontsize": 17,
            "lines.linewidth": 3,
            "lines.markersize": 14,
            "xtick.labelsize": 15,
            "ytick.labelsize": 15,
        }
    )
    plotevery = 10

    ### Main Paper Plots ###
    main = {"markevery": 50, "linestyle": "solid", "plotevery": plotevery}

    # Training Loss for all line searches
    _plot_each(
        runs,
        _is_line_search,
        {**main, "y_metric": "Training Loss", "ncols": 4},
        _style_by_label,
        _series("Training Loss"),
        path,
        0,
        "training_loss_0",
    )

    # Training Loss for new line searches and constant step sizes
    _plot_each(
        runs,
        _is_no_tune_or_gd,
        {**main, "y_metric": "Training Loss", "ncols": 5},
        _style_by_gd_size,
        _series("Training Loss"),
        path,
        1,
        "training_loss_1",
    )

    # Sharpness for new line searches and constant step sizes
    _plot_each(
        runs,
        _is_no_tune_or_gd,
        {**main, "y_metric": "Sharpness", "ncols": 5},
        _style_by_gd_size,
        _series("Sharpness"),
        path,
        2,
        "sharpness_0",
    )

    # Sharpness * Step Size (50-step moving average) for new line searches and
    # constant step sizes, with a dashed reference line at 2.
    _plot_each(
        runs,
        _is_no_tune_or_gd,
        {**main, "y_metric": "Sharpness * Step Size", "ncols": 5},
        _style_by_gd_size,
        _sharpness_times_step_size(50),
        path,
        3,
        "sharpness_step_size_0",
        hline=2,
    )

    # Sharpness (solid) and L_approx (dotted) for "noTune" line searches
    settings = {**main, "y_metric": "Sharpness, L_approx", "ncols": 4}
    for suffix, linestyle, data in (
        (": Sharpness", "solid", _series("Sharpness")),
        (": L_approx", "dotted", _series("a")),
    ):
        _plot_each(
            runs, _is_no_tune, settings, _unmarked(suffix, linestyle), data,
            path, 4, "sharpness_Lapprox",
        )

    # Relative error (10-step moving average) of L_BB1 (solid) and L_BB2
    # (dotted) for "noTune" line searches
    settings = {**main, "y_metric": r"|Sharpness - L_BB| $\div$ Sharpness", "ncols": 4}
    for suffix, linestyle, approx in (
        (": L_BB1", "solid", _series("Approx 9")),
        (": L_BB2", "dotted", _series("Approx 8")),
    ):
        _plot_each(
            runs, _is_no_tune, settings, _unmarked(suffix, linestyle),
            _relative_error_to(approx, 10), path, 5, "relative_error",
        )

    ### Appendix Plots ###
    appendix = {"markevery": 100, "linestyle": "solid", "plotevery": plotevery}

    # Sharpness for all line searches
    _plot_each(
        runs,
        _is_line_search,
        {**appendix, "y_metric": "Sharpness", "ncols": 4},
        _style_by_label,
        _series("Sharpness"),
        path,
        6,
        "sharpness_1",
    )

    # Sharpness * Step Size for all line searches, reference line at 2
    _plot_each(
        runs,
        _is_line_search,
        {**appendix, "y_metric": "Sharpness * Step Size", "ncols": 4},
        _style_by_label,
        _sharpness_times_step_size(50),
        path,
        7,
        "sharpness_step_size_1",
        hline=2,
    )

    # Step Size for all line searches
    _plot_each(
        runs,
        _is_line_search,
        {**appendix, "y_metric": "Step Size", "ncols": 4},
        _style_by_label,
        _series("Final Step Size"),
        path,
        8,
        "step_size_0",
    )

    # Step Size for new line searches and constant step sizes
    _plot_each(
        runs,
        _is_no_tune_or_gd,
        {**appendix, "y_metric": "Step Size", "ncols": 4},
        _style_by_gd_size,
        _series("Final Step Size"),
        path,
        9,
        "step_size_1",
    )

    # Relative error of L_approx (solid) and 2 / step size (dotted) for
    # "noTune" line searches
    settings = {**appendix, "y_metric": r"|Sharpness - approx| $\div$ Sharpness", "ncols": 4}
    for suffix, linestyle, approx in (
        (": L_approx", "solid", _series("a")),
        (r": 2 $\div$ Step Size", "dotted", _two_over_step_size),
    ):
        _plot_each(
            runs, _is_no_tune, settings, _unmarked(suffix, linestyle),
            _relative_error_to(approx, 10), path, 10, "relative_error_2",
        )

    # Relative error of L_BB1, L_BB2, L_BB3 and L_approx: one figure for
    # Armijo-noTune (SLS) and one for NLS-noTune (PoNoS)
    estimates = (
        ("L_BB1", "k", "Approx 9"),
        ("L_BB2", "tab:red", "Approx 8"),
        ("L_BB3", "k", "Approx 7"),
        ("L_approx", "k", "a"),
    )
    settings = {**appendix, "y_metric": r"|Sharpness - approx| $\div$ Sharpness", "ncols": 4}
    for metric_num, opt_name, filename in (
        (11, "SLS", "relative_error_3"),
        (12, "PoNoS", "relative_error_4"),
    ):

        def keep(run, names=(opt_name,)):
            return _is_no_tune(run, names)

        for label, colour, key in estimates:
            _plot_each(
                runs, keep, settings, _fixed(label, colour),
                _relative_error_to(_series(key), 10), path, metric_num, filename,
            )

    plt.close("all")


def plot(run, run_data, plot_settings, path, metric_num, filename):
    """Add one run's curve to figure ``metric_num`` and save that figure as a PDF.

    Args:
        run: run object; reads ``num_batches`` (for the x-axis label),
            ``model.model_type`` (output sub-directory) and whatever
            ``setup_title`` reads.
        run_data: 1-D array of y-values; the x-values are the element indices.
        plot_settings: dict with keys ``y_metric``, ``plotevery``,
            ``markevery``, ``label``, ``marker``, ``linestyle`` and ``colour``.
            A colour of ``"default"`` lets matplotlib choose the colour.
            ``ncols`` is only used by the (currently disabled) legend.
        path: output root; the figure is written to
            ``<path>/plots/<model_type>/<filename>.pdf``. Missing directories
            are created.
        metric_num: matplotlib figure number; calls with the same number draw
            onto the same figure.
        filename: PDF file name without the extension.
    """
    plt.figure(num=metric_num, figsize=(10, 5))

    y_metric = plot_settings["y_metric"]
    n_points = len(run_data)
    # Plot every ``plotevery``-th sample, starting at the metric's start index.
    window = slice(starting_point(y_metric), n_points, plot_settings["plotevery"])
    colour = plot_settings["colour"]
    plt.plot(
        np.arange(n_points)[window],
        run_data[window],
        markevery=plot_settings["markevery"],
        label=plot_settings["label"],
        marker=plot_settings["marker"],
        linestyle=plot_settings["linestyle"],
        color=None if colour == "default" else colour,
    )

    log_scale(y_metric)
    plt.ylabel(setup_ylabel(y_metric))
    x_lim(y_metric, n_points)
    y_lim(y_metric)
    plt.xlabel("Iteration" if run.num_batches == 1 else "Epoch")
    plt.title(setup_title(run))

    # Legend is currently disabled. To restore it below the axes:
    # plt.legend(
    #     bbox_to_anchor=(0.5, -0.3) if plot_settings["ncols"] != 2 else (0.5, -0.375),
    #     loc="lower center",
    #     borderaxespad=0,
    #     ncol=plot_settings["ncols"],
    # )

    out_dir = os.path.join(path, "plots", run.model.model_type)
    # savefig does not create directories; make sure the target exists.
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(os.path.join(out_dir, filename + ".pdf"), bbox_inches="tight")


def setup_labels(run):
    """Return the legend label for ``run``'s optimizer configuration.

    Known (opt_name, forward_option) pairs get a display name. Any
    ``constant_stepsize_GD`` run is "GD". Anything else falls back to the
    raw optimizer name.

    Raises:
        ValueError: if ``run.plot_metrics.label`` is not ``"Optimizer"``.
    """
    if run.plot_metrics.label != "Optimizer":
        raise ValueError("Not a valid label for plot")

    name = run.optimizer.opt_name
    option = run.optimizer.forward_option

    if name == "constant_stepsize_GD":
        return "GD"
    if name == "SLS":
        if option == 2:
            return "Armijo"
        if option == 4:
            return "Armijo-noTune"
    elif name == "PoNoS":
        if option == 0:
            return "PoNLS"
        if option == 4:
            return "NLS-noTune"
    return name


def setup_ylabel(ymetric):
    """Return the y-axis label for a metric name.

    Raw ``plot_data`` keys for the sharpness estimates ("a", "Approx 7/8/9")
    and the step size are mapped to their display names. The mapping
    matches the legend names used in ``plot_results``. Any other metric is
    used verbatim.
    """
    display_names = {
        "a": "L_approx",
        "Approx 9": "L_BB1",
        "Approx 8": "L_BB2",
        "Approx 7": "L_BB3",
        "Final Step Size": "Step Size",
    }
    return display_names.get(ymetric, ymetric)


def setup_title(run, include_model=True, include_dataset=False):
    """Build a figure title from the run's dataset and/or model names.

    The dataset part comes first. When both parts are requested they are
    joined with " - ". Requesting neither returns an empty string. Both
    flags are interpreted by truthiness.

    Args:
        run: run object; reads ``run.dataset.name`` and/or
            ``run.model.model_type``.
        include_model: include the model name (default True).
        include_dataset: include the dataset name (default False).
    """
    parts = []
    if include_dataset:
        dataset = run.dataset.name
        if dataset in ("synthetic_regression_interpolate", "synthetic_regression"):
            dataset = "Synthetic Regression Data"
        parts.append(dataset)
    if include_model:
        model_names = {
            "linear_regression": "Linear Regression",
            "logistic_regression": "Logistic Regression",
        }
        parts.append(model_names.get(run.model.model_type, run.model.model_type))
    return " - ".join(parts)


def starting_point(metric):
    """Index of the first sample to plot for ``metric``; currently 0 for every metric."""
    first_index = 0
    return first_index


def log_scale(metric):
    """Put the current y-axis on a log scale for loss, gradient-norm and step-size plots."""
    log_metrics = ("Training Loss", "Gradient Norm", "Step Size")
    if metric not in log_metrics:
        return
    plt.yscale("log")


def x_lim(metric, data_length):
    """Set the current x-axis to span ``[0, data_length]``; ``metric`` is currently unused."""
    plt.xlim(left=0, right=data_length)


def y_lim(metric):
    """Fix the y-axis range for the metrics that need it.

    Other metrics keep matplotlib's autoscaled limits. Raw strings are used
    for the ``$\\div$`` labels so the LaTeX backslash is not parsed as an
    (invalid) escape sequence; the runtime values are unchanged.
    """
    if metric == "Sharpness * Step Size":
        plt.ylim(bottom=0, top=8)
    elif metric in (
        r"|Sharpness - L_BB| $\div$ Sharpness",
        r"|Sharpness - approx| $\div$ Sharpness",
    ):
        # Relative-error plots: show only the [0, 1] range.
        plt.ylim(bottom=0, top=1)
    elif metric == "Sharpness, L_approx":
        plt.ylim(bottom=0)


def select_colour(label):
    """Map a curve label to a matplotlib colour.

    "NLS-noTune" is red, the other known optimizer labels are black, and
    anything else returns ``"default"``, which ``plot`` passes to matplotlib
    as ``color=None``.
    """
    if label == "NLS-noTune":
        return "tab:red"
    black_labels = (
        "Armijo-noTune",
        "Armijo",
        "PoNLS",
        "GD-small",
        "GD-medium",
        "GD-large",
    )
    return "k" if label in black_labels else "default"


def select_marker(label):
    """Map a curve label to a matplotlib marker code; unknown labels get no marker ("")."""
    markers_by_label = {
        "Armijo-noTune": "D",
        "NLS-noTune": "X",
        "Armijo": "s",
        "PoNLS": "P",
        "GD-small": "^",
        "GD-medium": "*",
        "GD-large": "o",
    }
    return markers_by_label.get(label, "")
