import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from spectrum import plotting


def load_data():
    """Simulate loss curves on power-law problems of increasing dimension.

    For every dimension ``d`` and exponent ``alpha`` the eigenvalues are
    ``1 / k**alpha`` for ``k = 1..d``, normalised to sum to one, and the
    initial per-coordinate errors are ``1 / k**beta``.  The loss at each step
    is the inner product of the eigenvalues with the current errors.

    Returns:
        dict with two entries:

        * ``"losses"``: ``{d: {alpha: [loss_0, ..., loss_T]}}`` produced by
          ``loss_over_time``.
        * ``"sign_losses"``: an empty dict (the sign-descent sweep is
          currently disabled).
    """

    def power_law_problem(d, alpha, beta):
        """Return ``(eigs, errors)`` for dimension ``d``; ``eigs`` sums to one."""
        # Float dtype so that any real exponent is valid; vectorised instead
        # of a Python-level loop over up to 10**6 indices.
        ks = np.arange(1, d + 1, dtype=float)
        errors = 1 / ks**beta
        eigs = 1 / ks**alpha
        return eigs / np.sum(eigs), errors

    def loss_over_time(T, d, alpha, beta):
        """Return the ``T + 1`` losses obtained with step size ``1 / max(eigs)``."""
        eigs, errors = power_law_problem(d, alpha, beta)
        losses = [np.inner(eigs, errors)]
        stepsize = 1 / np.max(eigs)
        # The per-coordinate contraction factor does not depend on the step,
        # so compute it once instead of on every iteration.
        contraction = (1 - stepsize * eigs) ** 2
        for _ in tqdm(range(T), total=T):
            errors = contraction * errors
            losses.append(np.inner(eigs, errors))
        return losses

    def sign_descent(T, d, alpha, beta):
        """Return the ``T + 1`` losses of the sign-based update.

        Not called at the moment; kept so the disabled sweep below can be
        re-enabled.
        """
        eigs, errors = power_law_problem(d, alpha, beta)
        losses = [np.inner(eigs, errors)]
        stepsize = 0.01 / np.sum(eigs)
        for _ in tqdm(range(T), total=T):
            errors = errors - stepsize * np.sum(errors) * np.sign(errors)
            losses.append(np.inner(eigs, errors))
        return losses

    T = 1000
    ds = [
        100,
        1_000,
        10_000,
        100_000,
        1_000_000,
    ]
    beta = 0
    alphas = [0.5, 1.0, 2.0]

    losses = {
        d: {alpha: loss_over_time(T, d, alpha, beta) for alpha in alphas} for d in ds
    }
    # The sign-descent sweep is disabled; to re-enable it, fill this dict the
    # same way as ``losses`` but with ``sign_descent``.
    sign_losses = {}

    return {
        "losses": losses,
        "sign_losses": sign_losses,
    }


def postprocess(data):
    """Return ``data`` unchanged; no post-processing is applied."""
    return data


def settings(plt):
    """Configure the plotting style through ``plotting.update_style``.

    Args:
        plt: Object forwarded as the first argument of
            ``plotting.update_style`` (the script passes the pyplot module).
    """
    style_options = {"ncols": 3, "rel_width": 1}
    # Optional override kept for reference: height_to_width_ratio=1 / 1.618
    plotting.update_style(plt, **style_options)


def make_figure(fig, data):
    """Draw one loss-versus-steps panel per exponent, one curve per dimension.

    The panel titles and x ticks are written for the three exponents
    (0.5, 1, 2) and the 1000 steps produced by ``load_data``.

    Args:
        fig: Figure to draw into; ``add_subplot`` and ``tight_layout`` are
            called on it.
        data: Mapping whose ``"losses"`` entry has the form
            ``{d: {alpha: sequence_of_losses}}``.
    """
    all_losses_gd = data["losses"]
    ds = sorted(all_losses_gd.keys())
    alphas = sorted(all_losses_gd[ds[0]].keys())

    axes = [fig.add_subplot(1, len(alphas), 1 + i) for i in range(len(alphas))]
    legend_ax = axes[-1]

    def format_10p(x):
        """Format ``x`` as a LaTeX power of ten, padded so labels line up."""
        power = int(np.log10(x))
        # Raw strings: ``\h`` and ``\,`` are LaTeX commands, not Python
        # escape sequences.
        if power == 0:
            return r"1^{\hphantom{0}} "
        if power == 1:
            return r"10^{\,} \,\,"
        return "10^{" + f"{power}" + "}"

    def make_label(d):
        """Return the legend label for dimension ``d``."""
        # Only the 10^6 entry spells out "d ="; the others show the value only.
        if d == 10**6:
            return f"$d = {format_10p(d)}$"
        return f"${format_10p(d)}$"

    colormap = plt.get_cmap("YlOrBr")
    colors = [colormap(i) for i in np.linspace(0.3, 0.8, len(ds))]

    lines_to_legend = []
    for i, d in enumerate(ds):
        for ax, alpha in zip(axes, alphas):
            (line,) = ax.plot(
                all_losses_gd[d][alpha],
                label=make_label(d),
                color=colors[i],
                linewidth=2,
            )
            # Every panel shares the same colours, so one set of handles
            # (taken from the last panel) is enough for the legend.
            if ax is legend_ax:
                lines_to_legend.append(line)

    for ax in axes:
        ax.axhline(0, linestyle="-", color="k", alpha=0.5)

    titles = [
        r"$\pi_k \propto \frac{1}{k^{1/2}}$",
        r"$\pi_k \propto \frac{1}{k}$",
        r"$\pi_k \propto \frac{1}{k^{2}}$",
    ]
    for ax, title in zip(axes, titles):
        ax.set_title(title, y=0.9)

    # Manual toggles for switching either axis to a log scale.
    use_log_y = False
    use_log_x = False

    axes[0].set_ylabel("Relative error")
    for ax in axes:
        if use_log_y:
            ax.set_yscale("log")
            ax.set_ylim([10**-2, 10**0.1])
        else:
            ax.set_ylim([-0.0, 1.0])
            ax.set_yticks([0, 0.5, 1])

        if use_log_x:
            ax.set_xscale("log")
        else:
            ax.set_xticks([0, 250, 500, 750, 1000])
            # The middle tick label is blanked; the x label is pulled up
            # towards the axis with a negative labelpad.
            ax.set_xticklabels([0, 250, "", 750, 1000])
            ax.set_xlabel("Steps", labelpad=-5)

    # Curves were added in increasing d; list the largest dimension first.
    lines_to_legend = lines_to_legend[::-1]

    legend_ax.legend(
        lines_to_legend,
        [line.get_label() for line in lines_to_legend],
        loc="lower right",
        frameon=False,
        markerfirst=False,
        borderpad=0.25,
        borderaxespad=0.25,
        handlelength=1.5,
        handletextpad=0.4,
        labelspacing=0.2,
        fontsize=8,
    )

    # Only the first panel keeps its y tick labels.
    for ax in axes[1:]:
        ax.set_yticklabels([])

    fig.tight_layout(pad=0.2)


if __name__ == "__main__":
    from pathlib import Path

    output_path = Path("figs/bad-scaling.pdf")
    # Create the output directory up front so that a missing ``figs/`` folder
    # cannot fail the save after the simulation has already run.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    settings(plt)
    fig = plt.figure()
    data = load_data()
    make_figure(fig, data)
    fig.savefig(output_path)
    plt.close(fig)
