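"""
Final plots for the scaling of sign descent: relative error and empirically
best step size versus rescaled time, for several dimensions d and exponents alpha.
"""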

import matplotlib.pyplot as plt
import numpy as np

from spectrum import sign
from spectrum.cache import ResultsCache
from spectrum.plotting import fmt_pow10, hide_frame, normalize_y_axis, update_style
from spectrum.utils import nice_logspace


def load_data(alphas=None):
    """Load the sign-descent results grid, populating the cache on a miss.

    Args:
        alphas: exponents to load; defaults to ``[0.25, 0.5, 1.0]``.

    Returns:
        dict with keys ``"alphas"``, ``"ds"`` (as float), ``"ts"`` (as float),
        ``"logd_phis"`` and ``"results_array"`` (whatever
        ``ResultsCache.load_data_as_array`` returns for this grid).
    """
    if alphas is None:
        alphas = [0.25, 0.5, 1.0]

    grid_ts = nice_logspace(0, 5, density=2)
    grid_ds = np.array([10**exponent for exponent in (2, 4, 6, 8)]).astype(int)
    grid_logd_phis = nice_logspace(-10, 0, base=10, density=5)
    grid = (alphas, grid_ts, grid_ds, grid_logd_phis)

    cache = ResultsCache().load()
    try:
        results = cache.load_data_as_array(*grid)
    except KeyError:
        # Some grid point is not cached yet: fill the cache for this grid,
        # then read it back.
        cache.update_cache(*grid)
        results = cache.load_data_as_array(*grid)

    return {
        "alphas": alphas,
        "ds": grid_ds.astype(float),
        "ts": grid_ts.astype(float),
        "logd_phis": grid_logd_phis,
        "results_array": results,
    }


def postprocess(data):
    """Hook for transforming loaded data before plotting; currently a no-op.

    Returns the very same object that was passed in.
    """
    processed = data
    return processed


def settings(plt):
    """Configure ``plt`` through ``update_style`` for a three-column layout."""
    style_options = {"ncols": 3}
    update_style(plt, **style_options)


def make_figure(fig, data, make_losses=True):
    """Draw the sign-descent scaling figure into ``fig``.

    Adds one subplot per value in ``data["alphas"]``. Each subplot has one
    colored curve per dimension in ``data["ds"]``, plotted against rescaled
    time (``sign.rescaled_time``), plus a dashed black reference curve from
    ``sign``.

    Args:
        fig: matplotlib figure to draw into; subplots are added to it.
        data: dict as produced by ``load_data``/``postprocess``. It must
            contain ``"alphas"``, ``"ds"``, ``"ts"``, ``"logd_phis"`` and
            ``"results_array"``.
        make_losses: if True, plot the minimal empirical loss ("Relative
            error") against ``sign.predicted_loss``. Otherwise, plot the
            rescaled loss-minimizing step size against ``sign.predicted_phi``.

    Returns:
        The same ``fig``.
    """
    alphas = data["alphas"]
    ds = data["ds"]
    ts = data["ts"]
    logd_phis = data["logd_phis"]
    results_array = data["results_array"]

    COLS = len(alphas)
    axes = [[fig.add_subplot(1, COLS, i + 1) for i in range(COLS)]]

    # One color per dimension d, sampled from the middle of the colormap.
    cmap = plt.get_cmap("YlOrBr")
    colors = [cmap(i) for i in np.linspace(0.3, 0.8, len(ds))]

    LW = 2

    def format_pow_10(power_of_10):
        """Format a power of ten as a LaTeX exponent, e.g. 100 -> '10^{2}'."""
        power = np.log10(power_of_10)
        return f"10^{{{power:g}}}"

    def flatten(xs):
        """Flatten one level of nesting."""
        return [item for sublist in xs for item in sublist]

    def empirical_optimum(i, j, d):
        """Return (min loss, minimizing phi) at each time for alphas[i], ds[j].

        ``results_array[i, j]`` is used as a 2-D array. Its rows are plotted
        against ``ts``. Its columns follow ``logd_phis``, since the argmin over
        axis 1 indexes ``phis = d**logd_phis``.
        """
        results = results_array[i, j]
        phis = d**logd_phis
        return np.min(results, axis=1), phis[np.argmin(results, axis=1)]

    # Per-alpha axis configuration, split into three regimes:
    # alpha < 1/2, alpha == 1/2 and alpha > 1/2.

    def ylabel_phi(alpha):
        if alpha < 1 / 2:
            return "$\\phi/d$"
        if alpha == 1 / 2:
            return "$x : \\phi = d^x$"
        if alpha > 1 / 2:
            return "$\\phi$"

    def ylabel_loss(alpha):
        return "Relative error"

    def xscale_log(alpha):
        if alpha < 1 / 2:
            return True
        if alpha == 1 / 2:
            return True
        if alpha > 1 / 2:
            return True

    def yscale_phi_log(alpha):
        if alpha < 1 / 2:
            return True
        if alpha == 1 / 2:
            return False
        if alpha > 1 / 2:
            return True

    def ylim_phi(alpha):
        if alpha < 1 / 2:
            return (10**-8, 10**0.5)
        if alpha == 1 / 2:
            return (-0.1, 1.1)
        if alpha > 1 / 2:
            return (10**-0.5, 10**4)

    def ylim_losses(alpha):
        return (10**-4, 10**0.5)

    def xlabel(alpha):
        return sign.tau_rescaling_label(alpha)

    def xlim(alpha):
        if alpha < 1 / 2:
            return (10**0, 10**4)
        if alpha == 1 / 2:
            return (10**-1, 10)
        if alpha > 1 / 2:
            return (10**-2, 10**2.2)

    def xticks(alpha):
        if alpha < 1 / 2:
            return [1, 10**1, 10**2, 10**3, 10**4]
        if alpha == 1 / 2:
            return [10**-1, 1, 10**1]
        if alpha > 1 / 2:
            return [10**-2, 10**-1, 10**0, 10**1, 10**2]

    def xticklabels(alpha):
        if alpha < 1 / 2:
            ticks = [1, "", "", "", 10**4]
        if alpha == 1 / 2:
            ticks = [10**-1, None, 10**1]
        if alpha > 1 / 2:
            ticks = [10**-2, None, None, None, 10**2]
        # Return a concrete list rather than a one-shot map iterator.
        return list(map(fmt_pow10, ticks))

    def plot_stepsizes():
        """Plot the empirical best step size per d against the predicted phi."""
        for i, alpha in enumerate(alphas):
            ax = axes[0][i]
            for j, d in enumerate(ds):
                _, emp_phis = empirical_optimum(i, j, d)
                ax.plot(
                    sign.rescaled_time(ts, d, alpha),
                    sign.rescaled_phi(emp_phis, d, alpha),
                    color=colors[j],
                    label=f"$d={format_pow_10(d)}$",
                    linewidth=LW,
                    alpha=1.0,
                )

            ax.set_title(f"$\\alpha={alpha:.3g}$", y=0.95)
            ax.set_ylabel(ylabel_phi(alpha))
            if xscale_log(alpha):
                ax.set_xscale("log")
            if yscale_phi_log(alpha):
                ax.set_yscale("log")
            ax.set_xlim(*xlim(alpha))
            ax.set_ylim(*ylim_phi(alpha))
            ax.set_xlabel(xlabel(alpha), labelpad=-5)
            ax.set_xticks(xticks(alpha))
            ax.set_xticklabels(xticklabels(alpha))

        # NOTE(review): with a single row of axes this passes only axes[0][0];
        # confirm whether all panels were meant to be normalized.
        normalize_y_axis(*[row[0] for row in axes])

        hide_frame(*flatten(axes))

        # The legend of d-values goes on the last panel. It is built on that
        # axes explicitly rather than via plt.legend, so it does not depend on
        # ``fig`` being pyplot's current figure.
        color_legend = axes[0][-1].legend(
            loc="upper right",
            frameon=False,
            borderpad=0.5,
            borderaxespad=0.5,
            handlelength=1.5,
            handletextpad=0.4,
            labelspacing=0.2,
            fontsize=8,
        )

        for i, alpha in enumerate(alphas):
            if xscale_log(alpha):
                taus = np.logspace(*map(np.log10, xlim(alpha)), base=10, num=200)
            else:
                taus = np.linspace(*xlim(alpha), num=100)

            ax = axes[0][i]
            (line_h,) = ax.plot(
                taus,
                sign.predicted_phi(taus, alpha),
                color="k",
                linestyle="--",
                dashes=(3, 2),
                linewidth=LW,
            )

            ax.legend(
                [line_h],
                [sign.phi_label(alpha)],
                loc="lower left",
                frameon=False,
                markerfirst=False,
                borderpad=0.25,
                handlelength=2.0,
                borderaxespad=0.25,
                handletextpad=0.4,
                fontsize=9,
            )
        # ax.legend above replaced the last panel's legend; re-attach the
        # d-value legend as a plain artist so both are drawn.
        axes[0][-1].add_artist(color_legend)

    def plot_losses():
        """Plot the empirical minimal loss per d against the predicted loss."""
        for i, alpha in enumerate(alphas):
            ax = axes[0][i]
            for j, d in enumerate(ds):
                emp_losses, _ = empirical_optimum(i, j, d)
                ax.plot(
                    sign.rescaled_time(ts, d, alpha),
                    emp_losses,
                    color=colors[j],
                    label=f"$d={format_pow_10(d)}$",
                    linewidth=LW,
                    alpha=1.0,
                )

            ax.set_title(f"$\\alpha={alpha:.3g}$", y=0.95)
            if xscale_log(alpha):
                ax.set_xscale("log")
            ax.set_ylabel(ylabel_loss(alpha))
            ax.set_yscale("log")
            ax.set_xlim(*xlim(alpha))
            ax.set_ylim(*ylim_losses(alpha))
            ax.set_xlabel(xlabel(alpha), labelpad=-5)
            ax.set_xticks(xticks(alpha))
            ax.set_xticklabels(xticklabels(alpha))

        # NOTE(review): with a single row of axes this passes only axes[0][0];
        # confirm whether all panels were meant to be normalized.
        normalize_y_axis(*[row[0] for row in axes])

        hide_frame(*flatten(axes))

        # d-value legend on the last panel, built on that axes explicitly.
        color_legend = axes[0][-1].legend(
            loc="lower left",
            frameon=False,
            borderpad=0.3,
            borderaxespad=0.3,
            handlelength=1.5,
            handletextpad=0.4,
            labelspacing=0.2,
            fontsize=8,
        )
        # Keep it as a plain artist so the per-panel ax.legend calls below
        # (which replace the axes' legend) do not drop it.
        axes[0][-1].add_artist(color_legend)

        for i, alpha in enumerate(alphas):
            if xscale_log(alpha):
                taus = np.logspace(*map(np.log10, xlim(alpha)), base=10, num=100)
            else:
                taus = np.linspace(*xlim(alpha), num=100)
                # Extra samples just below tau = 1 (10**-0.01 .. 10**0).
                taus = np.concatenate([taus, np.logspace(-0.01, 0, num=100)])
                taus = np.sort(taus)

            ax = axes[0][i]
            (line_h,) = ax.plot(
                taus,
                sign.predicted_loss(taus, alpha),
                color="k",
                linestyle="--",
                linewidth=LW,
                dashes=(3, 2),
            )
            ax.legend(
                [line_h],
                [sign.loss_label(alpha)],
                bbox_to_anchor=(0.5, 0.5, 0.55, 0.5),
                loc="upper right",
                frameon=False,
                markerfirst=True,
                borderpad=0.175,
                borderaxespad=0.175,
                handlelength=2.0,
                handletextpad=0.2,
                fontsize=9,
            )

        # Every loss panel uses the same ylim_losses, so only the leftmost keeps
        # its y-label and tick labels. Loop over the remaining panels instead of
        # hard-coding indices 1 and 2, so any number of alphas works.
        for ax in axes[0][1:]:
            ax.set_ylabel("")
            ax.set_yticklabels([])

    if make_losses:
        plot_losses()
    else:
        plot_stepsizes()

    fig.tight_layout(pad=0)

    return fig


if __name__ == "__main__":
    from pathlib import Path

    alphas = [0.25, 0.5, 1.0]
    data = postprocess(load_data(alphas))
    settings(plt)

    # savefig does not create missing directories, so make sure figs/ exists.
    out_dir = Path("figs")
    out_dir.mkdir(parents=True, exist_ok=True)

    # Render the loss figure and the step-size figure from the same data.
    for make_losses, filename in (
        (True, "sign-losses.pdf"),
        (False, "sign-stepsizes.pdf"),
    ):
        fig = plt.figure()
        make_figure(fig, data, make_losses=make_losses)
        # Save the figure that was drawn, not whatever pyplot considers current.
        fig.savefig(out_dir / filename, dpi=300)
        plt.close(fig)
