import matplotlib.pyplot as plt
import matplotlib as mpl
from eos_line_search.experiment import *
import os

# Matplotlib marker codes, assigned to runs in order (one per plotted line).
plot_markers = [
    "o",  # circle
    "v",  # triangle_down
    "^",  # triangle_up
    "8",  # octagon
    "s",  # square
    "p",  # pentagon
    "P",  # plus (filled)
    "*",  # star
    "h",  # hexagon1
    "X",  # x (filled)
    "D",  # diamond
    "d",  # thin_diamond
]


def plot_results(runs, path):
    """Plot every recorded metric of ``runs`` and save the figures via ``plot``.

    Side effects: updates the global matplotlib ``rcParams`` (hidden
    top/right spines, font, line and marker sizes), adds the derived series
    ``"Eigenvalue 1 * Final Step Size"`` to ``run.plot_data`` of every run
    that records both factors, writes files through ``plot`` and finally
    closes all open figures.

    Figures produced:
      * one per metric name found in any run's ``plot_data``: one line per
        run that recorded it, plotted against the index of each recorded
        value (axis labelled "Iteration"); the file stem is the metric name;
      * "compare-c": final training loss vs. ``optimizer.c`` for PoNoS runs
        with ``forward_option == 7``;
      * "compare-stepsize": final training loss vs. ``optimizer.step_size``
        for SAM runs.
    A comparison figure is skipped when no run qualifies for it.

    Args:
        runs: sequence of run objects exposing ``plot_data`` (mapping of
            metric name -> recorded series) plus the attributes read by
            ``setup_labels``, ``setup_title`` and ``plot``.
        path: root output directory passed on to ``plot``.
    """
    mpl.rcParams["axes.spines.right"] = False
    mpl.rcParams["axes.spines.top"] = False
    plt.rcParams.update(
        {
            "axes.titlesize": 22,
            "axes.labelsize": 19,
            "legend.fontsize": 17,
            "lines.linewidth": 3,
            "lines.markersize": 14,
            "xtick.labelsize": 15,
            "ytick.labelsize": 15,
        }
    )
    plot_settings = {
        "markevery": 50,
        "linestyle": "solid",
        "plotevery": 1,
        "ncols": 4,
    }

    # All metric names across runs, de-duplicated in first-seen order.
    metrics = list(
        dict.fromkeys(metric for run in runs for metric in run.plot_data)
    )
    _add_sharpness_step_size_metric(runs, metrics)

    # One figure per metric; figure numbers start at 1.
    for figure_num, metric in enumerate(metrics, start=1):
        plot_settings["x_metric"] = "Iteration"
        plot_settings["y_metric"] = metric
        for j, run in enumerate(runs):
            if metric not in run.plot_data:
                continue
            plot_settings["label"] = setup_labels(run)
            # Cycle through markers so more runs than markers cannot IndexError.
            plot_settings["marker"] = plot_markers[j % len(plot_markers)]
            y_data = run.plot_data[metric]
            x_data = list(range(len(y_data)))
            plot(run, x_data, y_data, plot_settings, path, figure_num, metric)

    ### additional specialized plots (figure numbers continue after metrics)
    next_figure = len(metrics) + 1

    # Final training loss vs. c for the line search.
    _plot_final_loss_sweep(
        runs,
        lambda run: run.optimizer.opt_name == "PoNoS"
        and run.optimizer.forward_option == 7,
        lambda run: run.optimizer.c,
        "c",
        plot_settings,
        path,
        next_figure,
        "compare-c",
    )

    # Final training loss vs. step size for SAM.
    _plot_final_loss_sweep(
        runs,
        lambda run: run.optimizer.opt_name == "SAM",
        lambda run: run.optimizer.step_size,
        "Step Size",
        plot_settings,
        path,
        next_figure + 1,
        "compare-stepsize",
    )

    plt.close("all")
    return


def _add_sharpness_step_size_metric(runs, metrics, eigenvalue_every=100):
    """Store "Eigenvalue 1 * Final Step Size" on runs that record both factors.

    Runs missing either "Eigenvalue 1" or "Final Step Size" are left
    untouched. When the two series differ in length, the step sizes are
    subsampled at indices ``0, eigenvalue_every, 2 * eigenvalue_every, ...``
    and their last value is appended before the element-wise product.
    NOTE(review): this presumably mirrors the eigenvalue logging cadence
    (every ``eigenvalue_every`` iterations plus once at the end) -- confirm
    against the experiment logger; mismatched lengths still raise in numpy.

    The metric name is appended to ``metrics`` (in place) when any run
    receives the series and it is not listed yet.
    """
    import numpy as np

    name = "Eigenvalue 1 * Final Step Size"
    for run in runs:
        data = run.plot_data
        if "Eigenvalue 1" not in data or "Final Step Size" not in data:
            continue
        eigenvalues = data["Eigenvalue 1"]
        step_sizes = data["Final Step Size"]
        if len(eigenvalues) != len(step_sizes):
            sampled = [
                step
                for i, step in enumerate(step_sizes)
                if i % eigenvalue_every == 0
            ]
            sampled.append(step_sizes[-1])
            step_sizes = sampled
        data[name] = np.array(eigenvalues) * np.array(step_sizes)
        if name not in metrics:
            metrics.append(name)


def _plot_final_loss_sweep(
    runs, is_selected, x_value, x_metric, base_settings, path, figure_num, filename
):
    """Plot each selected run's final training loss against ``x_value(run)``.

    Points are sorted by x so the line is drawn left to right, and the
    figure is drawn once with all points. Label and title come from the last
    selected run. Nothing is plotted when no run is selected.
    """
    points = []
    last_run = None
    for run in runs:
        if is_selected(run):
            points.append((x_value(run), run.plot_data["Training Loss"][-1]))
            last_run = run
    if last_run is None:
        return

    points.sort(key=lambda point: point[0])
    settings = dict(base_settings)
    settings.update(
        {
            "x_metric": x_metric,
            "y_metric": "Training Loss",
            "markevery": 1,
            "label": setup_labels(last_run),
            "marker": plot_markers[0],
        }
    )
    x_data = [x for x, _ in points]
    y_data = [y for _, y in points]
    plot(last_run, x_data, y_data, settings, path, figure_num, filename)


def plot(run, x_data, y_data, plot_settings, path, j, filename):
    """Add one line to matplotlib figure ``j`` and save that figure as a PDF.

    Calls with the same ``j`` draw onto the same figure, and each call
    re-saves it, so the file on disk holds every line drawn so far.

    Args:
        run: run object; ``setup_title(run)`` gives the title and
            ``run.model.model_type`` names the output subdirectory.
        x_data: x coordinates of the line.
        y_data: y coordinates of the line.
        plot_settings: dict with keys "markevery", "label", "marker",
            "linestyle", "x_metric", "y_metric" and "ncols".
        path: root output directory; the file is written to
            ``<path>/plots/<model_type>/<filename>.pdf`` (the directory is
            created if missing).
        j: matplotlib figure number.
        filename: file stem for the saved PDF.
    """
    plt.figure(j, figsize=(10, 5))
    plt.plot(
        x_data,
        y_data,
        markevery=plot_settings["markevery"],
        label=plot_settings["label"],
        marker=plot_settings["marker"],
        linestyle=plot_settings["linestyle"],
    )
    plt.xlabel(plot_settings["x_metric"])
    log_scale(plot_settings["x_metric"], "x")
    plt.ylabel(setup_ylabel(plot_settings["y_metric"]))
    log_scale(plot_settings["y_metric"], "y")
    # The [0, len(x_data)] window only makes sense for index-like x values on
    # a linear axis; on a log axis (e.g. c / step-size sweeps) a left limit of
    # 0 is invalid and an index-based right limit would clip the data.
    if plt.gca().get_xscale() == "linear":
        plt.xlim(left=0, right=len(x_data))
    plt.title(setup_title(run))

    plt.legend(
        bbox_to_anchor=(0.5, -0.2),
        loc="lower center",
        borderaxespad=0,
        ncol=plot_settings["ncols"],
    )

    out_dir = os.path.join(path, "plots", run.model.model_type)
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(
        os.path.join(out_dir, filename + ".pdf"),
        bbox_inches="tight",
    )
    return


def setup_labels(run):
    """Return the legend label for ``run`` based on its optimizer.

    Only ``run.plot_metrics.label == "Optimizer"`` is supported. Known
    optimizer/forward-option combinations map to fixed names. Constant-step
    GD and CDAT include their step size / ``c`` in the label. Anything else
    falls back to the raw optimizer name.

    Raises:
        ValueError: if ``run.plot_metrics.label`` is not "Optimizer".
    """
    if run.plot_metrics.label != "Optimizer":
        raise ValueError("Not a valid label for plot")

    optimizer = run.optimizer
    name = optimizer.opt_name
    option = optimizer.forward_option

    if name == "SLS":
        if option == 1 or option == 2:
            return "Armijo"
        if option == 4:
            return "Armijo-free"
    elif name == "PoNoS":
        if option == 0:
            return "PoNLS"
        if option == 4:
            return "NLS-free"
        if option == 7:
            return "NLS-new"
    elif name == "constant_stepsize_GD":
        return "GD-" + f"{optimizer.step_size:.3f}"
    elif name == "CDAT":
        return "CDAT-" + f"{optimizer.c:.1f}"
    elif name == "SAM":
        return "SAM"

    # Unrecognised optimizer / option combination: use the raw name.
    return name


def setup_ylabel(ymetric):
    """Return the y-axis label to display for the metric named ``ymetric``.

    A few internal metric keys are renamed for display; every other name is
    returned unchanged.
    """
    display_names = {
        "a": "L_approx",
        "Approx 9": "L_BB1",
        "Approx 8": "L_BB2",
        "Final Step Size": "Step Size",
    }
    return display_names.get(ymetric, ymetric)


def setup_title(run, include_model=True, include_dataset=True):
    """Build a plot title of the form "<dataset> - <model>" for ``run``.

    Known dataset names (the synthetic regression variants) and model types
    (linear / logistic regression) are mapped to readable names. Any other
    value is used verbatim.

    Args:
        run: object exposing ``run.dataset.name`` and ``run.model.model_type``.
        include_model: include the model part of the title.
        include_dataset: include the dataset part of the title.

    Returns:
        The included parts joined by " - ". If only one part is included,
        that part alone is returned; if neither is included, the result is
        an empty string.
    """
    parts = []

    if include_dataset:
        if run.dataset.name in (
            "synthetic_regression_interpolate",
            "synthetic_regression",
        ):
            parts.append("Synthetic Regression Data")
        else:
            parts.append(run.dataset.name)

    if include_model:
        readable_models = {
            "linear_regression": "Linear Regression",
            "logistic_regression": "Logistic Regression",
        }
        model_type = run.model.model_type
        parts.append(readable_models.get(model_type, model_type))

    return " - ".join(parts)


def log_scale(metric, axis):
    """Switch the current axes' ``axis`` ("x" or "y") to log scale for certain metrics.

    Only "Training Loss", "Gradient Norm", "Step Size" and "c" are plotted
    on a log axis. For any other metric this is a no-op, and ``axis`` is not
    validated.

    Raises:
        ValueError: if ``metric`` is log-scaled and ``axis`` is neither
            "x" nor "y".
    """
    log_metrics = ("Training Loss", "Gradient Norm", "Step Size", "c")
    if metric not in log_metrics:
        return

    if axis == "x":
        plt.xscale("log")
    elif axis == "y":
        plt.yscale("log")
    else:
        raise ValueError("Not a valid axis for plot")
    return
