"""Plot the conditions for fast convergence of NGD for different cases."""

from os import makedirs, path

from matplotlib import pyplot as plt
from seaborn import color_palette

from accumulate_gram import accumulate_gram_matrix_condition
from accumulate_linear import accumulate_linearity_condition
from cases import (CASES_VARY_DATA, CIFAR10SUB400_WRN, SYNTHETIC_DEEP,
                   SYNTHETIC_LESS_DEEP, SYNTHETIC_SHALLOW)

# Absolute directory containing this script; output paths are built relative to
# it rather than to the current working directory.
HEREDIR = path.dirname(path.abspath(__file__))
# All figures produced by this module are saved into this "figs" sub-directory.
FIGDIR = path.join(HEREDIR, "figs")
# Runs at import time so the plotting functions can save without checking;
# exist_ok=True makes repeated imports/runs harmless.
makedirs(FIGDIR, exist_ok=True)


def group_plot():
    """Create and save a group plot of the convergence conditions.

    Builds a 2x3 grid of log-log axes with one column per case, in the order
    ``SYNTHETIC_SHALLOW``, ``SYNTHETIC_DEEP``, ``CIFAR10SUB400_WRN``:

    - top row: the ``"min_eigval"`` column of the frame returned by
      ``accumulate_gram_matrix_condition``,
    - bottom row: the ``"C'"`` column of the frame returned by
      ``accumulate_linearity_condition``, with a dashed horizontal line at 0.5.

    Rows of each frame are grouped by their ``"width"`` column; the per-width
    median is drawn as a line and the 0.25-0.75 quantile range as a shaded band.
    The figure is saved as ``convergence_conditions.pdf`` inside ``FIGDIR``.

    Note:
        This updates the global ``plt.rcParams`` and does not restore them.
        Text is rendered through LaTeX (``text.usetex`` is enabled).
    """
    ###############################################################################
    #                                 Plot styling                                #
    ###############################################################################
    # NeurIPS & ICLR
    WIDTH = 5.50107
    HEIGHT = 9.00177
    SCRIPTSIZE = 7  # FOOTNOTESIZE = 9

    params = {
        # Text is rendered through LaTeX; the preamble below loads amsmath and bm.
        "text.usetex": True,
        "font.size": SCRIPTSIZE,
        "font.family": "Times New Roman",
        "font.serif": ["Times New Roman"],
        "mathtext.fontset": "cm",
        "axes.linewidth": 0.5,
        "axes.titlesize": SCRIPTSIZE + 1,
        "axes.labelsize": SCRIPTSIZE + 1,
        "lines.linewidth": 1,
        "xtick.major.size": 1.5,
        "ytick.major.size": 1.5,
        "xtick.major.width": 0.5,
        "ytick.major.width": 0.5,
        "legend.fontsize": SCRIPTSIZE - 1,
        "pdf.fonttype": 42,  # Type 42 (TrueType) fonts in the PDF output
        "text.latex.preamble": r"\usepackage{amsmath,bm}",
    }
    plt.rcParams.update(params)

    fig, axs = plt.subplots(2, 3, sharex="col", sharey=False, constrained_layout=True)
    fig.set_size_inches(WIDTH, 0.3 * HEIGHT)

    # general formatting
    for ax in axs.flatten():
        ax.set_xscale("log")
        ax.set_yscale("log")
    for ax in axs[-1, :]:
        ax.set_xlabel("Width")
        ax.axhline(0.5, color="k", linestyle="--")
    axs[0, 0].set_ylabel(r"$\lambda_{\text{min}}(\mathbf{G}(0))$")
    axs[1, 0].set_ylabel(r"$C'$ (close-to-linearity)")

    axs[0, 0].set_title("Toy (theory)")
    axs[0, 1].set_title("Toy + depth")
    axs[0, 2].set_title("Real-world (practice)")

    markers = ["o", "X", "*"]
    colors = list(color_palette("colorblind"))

    ###############################################################################
    #                                  Plot data                                  #
    ###############################################################################
    cases = [SYNTHETIC_SHALLOW, SYNTHETIC_DEEP, CIFAR10SUB400_WRN]
    # One entry per subplot row: (row index, data loader, column to plot).
    panels = (
        (0, accumulate_gram_matrix_condition, "min_eigval"),
        (1, accumulate_linearity_condition, "C'"),
    )
    for idx_case, (case, marker, color) in enumerate(zip(cases, markers, colors)):
        for row, accumulate, column in panels:
            ax = axs[row, idx_case]
            frame = accumulate(case, ignore_missing=True)
            # Select the plotted column *before* aggregating: only that column is
            # reduced, and unrelated (possibly non-numeric) columns cannot make
            # median/quantile fail.
            per_width = frame.groupby("width")[column]
            median = per_width.median()

            ax.plot(median.index, median, marker=marker, color=color, markersize=3)
            ax.fill_between(
                median.index,
                per_width.quantile(0.25),
                per_width.quantile(0.75),
                color=color,
                alpha=0.25,
            )

    savepath = path.join(FIGDIR, "convergence_conditions.pdf")
    # Address the figure explicitly instead of relying on pyplot's current figure.
    fig.savefig(savepath)
    plt.close(fig)


def combined_plot():
    """Create and save a combined plot for the convergence conditions.

    Builds a 1x2 row of log-log axes and overlays four cases
    (``SYNTHETIC_SHALLOW``, ``SYNTHETIC_LESS_DEEP``, ``SYNTHETIC_DEEP``,
    ``CIFAR10SUB400_WRN``) in each of them:

    - left: the ``"min_eigval"`` column of the frame returned by
      ``accumulate_gram_matrix_condition``,
    - right: the ``"C'"`` column of the frame returned by
      ``accumulate_linearity_condition``, with a dashed horizontal line at 0.5
      and the legend.

    Rows of each frame are grouped by their ``"width"`` column; the per-width
    median is drawn as a line and the 0.25-0.75 quantile range as a shaded band.
    The figure is saved as ``convergence_conditions_combined.pdf`` inside
    ``FIGDIR``.

    Note:
        This updates the global ``plt.rcParams`` and does not restore them.
        Text is rendered through LaTeX (``text.usetex`` is enabled).
    """
    ###############################################################################
    #                                 Plot styling                                #
    ###############################################################################
    # NeurIPS & ICLR
    WIDTH = 5.50107
    HEIGHT = 9.00177
    SCRIPTSIZE = 7  # FOOTNOTESIZE = 9

    params = {
        # Text is rendered through LaTeX; the preamble below loads amsmath and bm.
        "text.usetex": True,
        "font.size": SCRIPTSIZE,
        "font.family": "Times New Roman",
        "font.serif": ["Times New Roman"],
        "mathtext.fontset": "cm",
        "axes.linewidth": 0.5,
        "axes.titlesize": SCRIPTSIZE + 1,
        "axes.labelsize": SCRIPTSIZE + 1,
        "lines.linewidth": 1,
        "xtick.major.size": 1.5,
        "ytick.major.size": 1.5,
        "xtick.major.width": 0.5,
        "ytick.major.width": 0.5,
        "legend.fontsize": SCRIPTSIZE - 1,
        "pdf.fonttype": 42,  # Type 42 (TrueType) fonts in the PDF output
        "text.latex.preamble": r"\usepackage{amsmath,bm}",
    }
    plt.rcParams.update(params)

    fig, axs = plt.subplots(1, 2, constrained_layout=True)
    fig.set_size_inches(WIDTH, 0.2 * HEIGHT)

    # general formatting
    for ax in axs:
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.set_xlabel("Width")
    axs[1].axhline(0.5, color="k", linestyle="--")
    axs[0].set_ylabel(r"$\lambda_{\text{min}}(\mathbf{G}(0))$")
    axs[1].set_ylabel(r"$C'$ (close-to-linearity)")

    labels = ["Toy (theory)", "Toy, $L=3$", "Toy, $L=5$", "Real-world (practice)"]
    markers = ["o", "X", "*", "d"]
    colors = list(color_palette("colorblind"))

    ###############################################################################
    #                                  Plot data                                  #
    ###############################################################################
    cases = [SYNTHETIC_SHALLOW, SYNTHETIC_LESS_DEEP, SYNTHETIC_DEEP, CIFAR10SUB400_WRN]
    # One entry per panel: (axes, data loader, column to plot, label the lines?).
    # Only the right panel hosts the legend, so only its lines get labels.
    panels = (
        (axs[0], accumulate_gram_matrix_condition, "min_eigval", False),
        (axs[1], accumulate_linearity_condition, "C'", True),
    )
    for case, color, label, marker in zip(cases, colors, labels, markers):
        for ax, accumulate, column, in_legend in panels:
            frame = accumulate(case, ignore_missing=True)
            # Select the plotted column *before* aggregating: only that column is
            # reduced, and unrelated (possibly non-numeric) columns cannot make
            # median/quantile fail.
            per_width = frame.groupby("width")[column]
            median = per_width.median()

            legend_kwargs = {"label": label} if in_legend else {}
            ax.plot(
                median.index,
                median,
                marker=marker,
                color=color,
                markersize=3,
                **legend_kwargs,
            )
            ax.fill_between(
                median.index,
                per_width.quantile(0.25),
                per_width.quantile(0.75),
                color=color,
                alpha=0.25,
            )

    # Explicitly target the labelled (right) axes rather than pyplot's current axes.
    axs[1].legend()

    savepath = path.join(FIGDIR, "convergence_conditions_combined.pdf")
    # Address the figure explicitly instead of relying on pyplot's current figure.
    fig.savefig(savepath)
    plt.close(fig)


def vary_data_plot():
    """Create & save combined plot for the convergence conditions with varying data.

    Draws a single log-log axes with a dashed horizontal line at 0.5 and, for
    every case in ``CASES_VARY_DATA``, the ``"C'"`` column of the frame returned
    by ``accumulate_linearity_condition``. Rows are grouped by their ``"width"``
    column; the per-width median is drawn as a line and the 0.25-0.75 quantile
    range as a shaded band. Each line is labelled ``N=<suffix>`` where
    ``<suffix>`` is the part of the case's ``"data_name"`` after the last
    underscore. The figure is saved as ``convergence_conditions_vary_data.pdf``
    inside ``FIGDIR``.

    Note:
        This updates the global ``plt.rcParams`` and does not restore them.
        Text is rendered through LaTeX (``text.usetex`` is enabled).
    """
    ###############################################################################
    #                                 Plot styling                                #
    ###############################################################################
    # NeurIPS & ICLR
    WIDTH = 5.50107
    HEIGHT = 9.00177
    SCRIPTSIZE = 7  # FOOTNOTESIZE = 9

    params = {
        # Text is rendered through LaTeX; the preamble below loads amsmath and bm.
        "text.usetex": True,
        "font.size": SCRIPTSIZE,
        "font.family": "Times New Roman",
        "font.serif": ["Times New Roman"],
        "mathtext.fontset": "cm",
        "axes.linewidth": 0.5,
        "axes.titlesize": SCRIPTSIZE + 1,
        "axes.labelsize": SCRIPTSIZE + 1,
        "lines.linewidth": 1,
        "xtick.major.size": 1.5,
        "ytick.major.size": 1.5,
        "xtick.major.width": 0.5,
        "ytick.major.width": 0.5,
        "legend.fontsize": SCRIPTSIZE - 1,
        "pdf.fonttype": 42,  # Type 42 (TrueType) fonts in the PDF output
        "text.latex.preamble": r"\usepackage{amsmath,bm}",
    }
    plt.rcParams.update(params)

    fig, ax = plt.subplots(1, 1, constrained_layout=True)
    fig.set_size_inches(0.5 * WIDTH, 0.18 * HEIGHT)

    # general formatting
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlabel("Width")
    ax.axhline(0.5, color="k", linestyle="--")
    ax.set_ylabel(r"$C'$ (close-to-linearity)")

    labels = [f"$N={case['data_name'].split('_')[-1]}$" for case in CASES_VARY_DATA]
    colors = list(color_palette("colorblind"))

    ###############################################################################
    #                                  Plot data                                  #
    ###############################################################################
    # Every case uses the same "d" marker. It is passed as a constant instead of
    # zipping a fixed-length marker list, which would silently drop every case
    # beyond the list's length.
    for case, color, label in zip(CASES_VARY_DATA, colors, labels):
        frame = accumulate_linearity_condition(case, ignore_missing=True)
        # Select the plotted column *before* aggregating: only that column is
        # reduced, and unrelated (possibly non-numeric) columns cannot make
        # median/quantile fail.
        per_width = frame.groupby("width")["C'"]
        median = per_width.median()

        ax.plot(
            median.index,
            median,
            marker="d",
            color=color,
            markersize=3,
            label=label,
        )
        ax.fill_between(
            median.index,
            per_width.quantile(0.25),
            per_width.quantile(0.75),
            color=color,
            alpha=0.25,
        )

    ax.legend()

    savepath = path.join(FIGDIR, "convergence_conditions_vary_data.pdf")
    # Address the figure explicitly instead of relying on pyplot's current figure.
    fig.savefig(savepath)
    plt.close(fig)


if __name__ == "__main__":
    # Regenerate every figure when run as a script: grouped, combined, vary-data.
    for make_figure in (group_plot, combined_plot, vary_data_plot):
        make_figure()
