import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from spectrum import plotting
from spectrum.plotting import fmt_pow10


def load_data(alphas=None):
    """Simulate sign descent on a separable problem and collect the iterates.

    Every coordinate starts at ``delta_k(0) = 1 / k**alpha`` (k = 1..d) and is
    updated with ``delta <- delta - eta * sign(delta)``. Two variants are run:

    * ``sign_descent`` -- the plain update. Once ``|delta_k| < eta`` a
      coordinate bounces back and forth across zero.
    * ``sign_approx`` -- the same update, except any coordinate that would
      become negative is reset to ``eta / 2`` (the modelling assumption).

    Args:
        alphas: Unused; accepted only for interface compatibility. The decay
            exponent is fixed to 0.5 and returned under ``"alpha"``.

    Returns:
        dict with:
          "alpha", "d", "T"        -- parameters of the small example;
          "exact", "approx"        -- (T, d) iterates of the small example;
          "d_big"                  -- dimension of the large runs;
          "T_big"                  -- number of iterations of the step-size
                                      sweep (rows of each "exact_stepsizes");
          "T_big_single"           -- number of iterations (rows) of
                                      "exact_big" / "approx_big";
          "exact_big", "approx_big"-- large single-step-size runs;
          "stepsizes"              -- step sizes of the sweep;
          "exact_stepsizes"        -- one (T_big, d_big) run per step size.
    """
    alpha = 0.5

    def initial_point(d):
        # delta_k(0) = 1 / k**alpha, evaluated element-wise as in the model.
        return np.array([1 / k**alpha for k in range(1, d + 1)])

    def sign_descent(T, eta, d):
        # Row t holds the iterate after t plain sign-descent steps.
        results = np.zeros((T, d))
        results[0, :] = initial_point(d)
        for t in range(1, T):
            results[t, :] = results[t - 1, :] - eta * np.sign(results[t - 1, :])
        return results

    def sign_approx(T, eta, d):
        # Same as sign_descent, but overshooting coordinates are clamped to
        # eta / 2 instead of going negative.
        results = np.zeros((T, d))
        results[0, :] = initial_point(d)
        for t in range(T - 1):
            step = results[t, :] - eta * np.sign(results[t, :])
            step[step < 0] = eta / 2
            results[t + 1, :] = step
        return results

    # Small example (d = 4) compared exact vs. assumption.
    d = 4
    eta = 0.291564815318473456
    T = 16
    approx = sign_approx(T, eta, d)
    exact = sign_descent(T, eta, d)

    # Large single-step-size runs (exposed as "exact_big" / "approx_big").
    d_big = 1000
    eta_big = 0.00891564815318473456
    T_big_single = 301
    approx_big = sign_approx(T_big_single, eta_big, d_big)
    exact_big = sign_descent(T_big_single, eta_big, d_big)

    # Step-size sweep on the large problem.
    stepsizes = [2**-i for i in [6, 7, 8, 9, 10]]
    T_sweep = 1500
    # NOTE(review): despite the key name "exact_stepsizes", these runs use the
    # clamped ``sign_approx`` update, not ``sign_descent``. Kept unchanged to
    # preserve the published figure -- confirm this is intended.
    exact_stepsizes = [
        sign_approx(T_sweep, stepsize, d_big) for stepsize in tqdm(stepsizes)
    ]

    return {
        "alpha": alpha,
        "d": d,
        "T": T,
        "exact": exact,
        "approx": approx,
        "d_big": d_big,
        # "T_big" historically refers to the sweep horizon; keep it that way
        # for existing consumers.
        "T_big": T_sweep,
        "T_big_single": T_big_single,
        "exact_big": exact_big,
        "approx_big": approx_big,
        "exact_stepsizes": exact_stepsizes,
        "stepsizes": stepsizes,
    }


def postprocess(data):
    """Post-processing hook applied to the output of ``load_data``.

    Currently a pass-through: the very same object is handed back unchanged.
    """
    processed = data
    return processed


def settings(plt):
    """Apply the shared project plot style, sized for a row of three panels."""
    panel_columns = 3
    # A custom aspect ratio (height_to_width_ratio=1 / 1.3) was tried before
    # but is intentionally not applied.
    plotting.update_style(plt, ncols=panel_columns)


def _initial_sq_norm(alpha, d):
    """Return ``sum_{k=1}^{d} (k**-alpha)**2``.

    This is the squared norm of the starting point ``delta_k(0) = k**-alpha``;
    dividing squared errors by it makes the relative error start at 1.
    """
    return (np.array([k**-alpha for k in range(1, d + 1)]) ** 2).sum()


def _plot_individual_directions(ax, exact, approx, d, T, colors, line_width):
    """Panel 1: per-coordinate magnitudes ``|delta_i(t)|``.

    Exact iterates are drawn thick in ``colors[i]``; the assumption's iterates
    are drawn thin in black. Two legends are attached: one listing the
    directions, and one (in the lower part of the axes) with a single
    "Assumption" entry.
    """
    steps = np.arange(T)

    direction_handles = []
    for i in reversed(range(d)):
        (handle,) = ax.plot(
            steps,
            np.abs(exact[:, i]),
            color=colors[i],
            linewidth=line_width,
            label=f"$\\delta_{{{i+1}}}$",
        )
        direction_handles.append(handle)

    assumption_handle = None
    for i in reversed(range(d)):
        # Label only one black curve so the legend shows a single entry.
        label_kw = {"label": "Assumption"} if i == 0 else {}
        (handle,) = ax.plot(
            steps,
            np.abs(approx[:, i]),
            linestyle="-",
            color="k",
            linewidth=line_width / 3,
            **label_kw,
        )
        if i == 0:
            assumption_handle = handle

    # Curves were drawn from delta_d down to delta_1; list them in natural order.
    ordered = list(reversed(direction_handles))
    directions_legend = ax.legend(
        ordered,
        [h.get_label() for h in ordered],
        ncol=2,
        loc="upper right",
        frameon=False,
        borderpad=0.3,
        borderaxespad=0.3,
        handlelength=1.5,
        handletextpad=0.4,
        labelspacing=0.2,
        fontsize=9,
        columnspacing=1.0,
    )
    # A second ax.legend() call replaces the first, so keep it as an artist.
    ax.add_artist(directions_legend)

    ax.legend(
        [assumption_handle],
        [assumption_handle.get_label()],
        ncol=2,
        loc="best",
        # Restrict placement to the lower 65% of the axes.
        bbox_to_anchor=(0.0, 0.0, 1.0, 0.65),
        frameon=False,
        borderpad=0.3,
        borderaxespad=0.3,
        handlelength=1.5,
        handletextpad=0.4,
        labelspacing=0.3,
        fontsize=9,
    )

    ax.set_title("Individual directions")
    ax.set_ylabel("$\\vert\\delta_i(t)\\vert$", labelpad=-5)
    ax.set_xlabel("$t$", labelpad=-5)


def _plot_total_error(ax, exact, approx, d, T, norm, color, line_width):
    """Panel 2: relative squared error ``sum_i delta_i(t)^2 / norm``."""
    steps = np.arange(T)
    ax.plot(
        steps,
        np.sum(exact**2, axis=1) / norm,
        color=color,
        linewidth=line_width,
        label="Exact",
    )
    ax.plot(
        steps,
        np.sum(approx**2, axis=1) / norm,
        color="black",
        linewidth=line_width / 3,
        label="Assumption",
    )
    ax.legend(
        loc="best",
        frameon=False,
        borderpad=0.3,
        borderaxespad=0.3,
        handlelength=1.5,
        handletextpad=0.4,
        labelspacing=0.3,
        fontsize=9,
    )
    ax.set_title(f"Total error ($d={d}$)")
    ax.set_ylabel("Relative error", labelpad=-5)
    ax.set_xlabel("$t$", labelpad=-5)


def _plot_stepsize_sweep(ax, runs, stepsizes, T, norm):
    """Panel 3: relative squared error for several step sizes on a log y-axis.

    ``runs[i]`` holds the iterates for ``stepsizes[i]``. Each is presumably a
    (T, d) array as built by ``load_data`` -- TODO confirm for other callers.
    A dashed reference curve ``min(1, 30 / t**2)`` is overlaid.
    """
    cmap = plt.get_cmap("YlOrBr")
    colors = [cmap(x) for x in np.linspace(0.4, 0.8, len(stepsizes))]

    steps = np.arange(T)
    for i, stepsize in enumerate(stepsizes):
        # Step sizes are powers of two; label them as 2^{exponent}.
        exponent = np.log2(stepsize)
        ax.plot(
            steps,
            np.sum(runs[i] ** 2, axis=1) / norm,
            label=f"$\\eta = 2^{{{exponent:.0f}}}$",
            color=colors[i],
        )
    ax.set_yscale("log")

    # Reference envelope 30 / t^2 capped at 1; t starts at 1 to avoid 1 / 0.
    envelope_t = np.arange(1, T + 1)
    ax.plot(
        envelope_t,
        np.minimum(30 / envelope_t**2, 1),
        linestyle="--",
        dashes=(3, 2),
        color="k",
    )

    ax.set_yticks([10**-i for i in [5, 4, 3, 2, 1, 0]])
    # Only the extreme ticks are meant to be labelled; fmt_pow10 is a project
    # helper presumably rendering None as an empty label -- TODO confirm.
    ax.set_yticklabels(
        [fmt_pow10(v) for v in [10**-5, None, None, None, None, 1]]
    )
    # Raw string: "\e" is not a valid Python escape sequence.
    ax.set_title(r"Budget $T$ vs. step-size $\eta$")
    ax.set_ylabel("Relative error", labelpad=-10)
    ax.legend(
        loc="upper right",
        frameon=True,
        fancybox=False,
        edgecolor="0.5",
        framealpha=1.0,
        borderpad=0.3,
        borderaxespad=0.0,
        handlelength=1.5,
        handletextpad=0.4,
        labelspacing=0.1,
        fontsize=7,
    )
    ax.set_xlabel("$t$", labelpad=-5)


def make_figure(fig, data, logy=True):
    """Draw the three-panel sign-descent figure into ``fig``.

    Panels: (1) per-direction magnitudes for the small example, (2) total
    relative error for the small example, (3) relative error of the step-size
    sweep on the large problem.

    Args:
        fig: matplotlib Figure to draw on; three subplots are added to it.
        data: dict as returned by ``load_data``. Reads "exact", "approx", "d",
            "T", "alpha", "stepsizes", "exact_stepsizes", "T_big" and "d_big".
        logy: Currently ignored and kept only for interface compatibility.
            The third panel always uses a logarithmic y-axis.
    """
    axes = [
        fig.add_subplot(131),
        fig.add_subplot(132),
        fig.add_subplot(133),
    ]

    exact, approx = data["exact"], data["approx"]
    d, T = data["d"], data["T"]
    alpha = data["alpha"]
    line_width = 3

    cmap = plt.get_cmap("YlOrBr")
    colors = [cmap(x) for x in np.linspace(0.4, 0.8, d)]

    _plot_individual_directions(axes[0], exact, approx, d, T, colors, line_width)
    _plot_total_error(
        axes[1],
        exact,
        approx,
        d,
        T,
        _initial_sq_norm(alpha, d),
        colors[d // 2],
        line_width,
    )
    _plot_stepsize_sweep(
        axes[2],
        data["exact_stepsizes"],
        data["stepsizes"],
        data["T_big"],
        _initial_sq_norm(alpha, data["d_big"]),
    )

    # The two linear panels share a [0, 1] y-range with only 0 and 1 labelled.
    for ax in axes[:-1]:
        ax.set_ylim([0, 1.0])
        ax.set_yticks([0, 0.5, 1.0])
        ax.set_yticklabels([0, "", 1.0])

    fig.tight_layout(pad=0.1)


if __name__ == "__main__":
    from pathlib import Path

    # Simulate, post-process, style, plot, and save the figure.
    data = load_data()
    data = postprocess(data)
    settings(plt)
    fig = plt.figure()
    make_figure(fig, data)

    out_path = Path("figs") / "sign-assumption.pdf"
    # savefig does not create missing directories, so make sure figs/ exists.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_path)
