"""
Appendix figures: comparison of linear and sublinear rates.
"""

import matplotlib.pyplot as plt
import numpy as np

from spectrum.plotting import fmt_pow10, update_style
from spectrum.utils import nice_logspace


def load_data(alphas=None):
    """Evaluate the linear and sublinear reference rates on a time grid.

    Args:
        alphas: Currently unused; kept for signature compatibility.

    Returns:
        Dict with keys
        - ``"ds"``: the dimensions ``10**2, 10**4, 10**8, 10**16`` as floats,
        - ``"ts"``: the time grid from ``nice_logspace``, as floats,
        - ``"results"``: ``results[d]["lin"]`` and ``results[d]["sub"]``,
          the linear and sublinear rates evaluated at every entry of ``ts``.
    """

    def rate_linear(d, ts):
        # (1 - 1/d) ** t, evaluated as exp(t * log1p(-1/d)). Forming 1 - 1/d
        # directly loses precision for large d: at d = 1e16 it rounds to
        # 1 - 2**-53, overstating the per-step decay by roughly 11%.
        return np.exp(ts * np.log1p(-1 / d))

    def rate_sublinear(d, ts):
        # 2 * d / (log(d) * t)
        return 2 * (d / np.log(d)) / ts

    ds = np.array([10**i for i in [2, 4, 8, 16]])
    # NOTE(review): assumes nice_logspace takes log10 endpoints, i.e. the grid
    # runs two decades past the largest d — confirm against spectrum.utils.
    ts = nice_logspace(0, 2 + int(np.log10(np.max(ds))), density=4)

    results = {
        d: {"lin": rate_linear(d, ts), "sub": rate_sublinear(d, ts)} for d in ds
    }

    return {
        "ds": ds.astype(float),
        "ts": ts.astype(float),
        "results": results,
    }


def postprocess(data):
    """Pass *data* through untouched.

    Placeholder hook that runs between ``load_data`` and ``make_figure``.
    """
    processed = data
    return processed


def settings(plt):
    """Configure *plt* through the project's ``update_style`` with ``ncols=2``."""
    n_columns = 2
    update_style(plt, ncols=n_columns)


def make_figure(fig, data):
    """Draw the linear-vs-sublinear rate comparison into *fig*.

    Adds two side-by-side panels that show the same curves against
    ``tau = log(t) / log(d)``. The left panel uses a linear y-axis and the
    right panel a log y-axis. Solid lines are the linear rate and dotted
    lines the sublinear rate, with one colour per ``d``. The dashed black
    line is the reference ``1 - tau``.

    Args:
        fig: Matplotlib figure to draw into; two subplots are added to it.
        data: Dict with keys ``"ds"``, ``"ts"`` and ``"results"`` (as built by
            ``load_data``); ``results[d]`` holds ``"lin"`` and ``"sub"`` arrays
            evaluated at ``ts``.

    Returns:
        The same ``fig`` after ``tight_layout``.
    """
    axes = [
        fig.add_subplot(121),
        fig.add_subplot(122),
    ]

    results = data["results"]
    ds = data["ds"]
    ts = data["ts"]

    def to_tau(ts, d):
        # Rescaled time tau = log(t) / log(d).
        return np.log(ts) / np.log(d)

    cmap = plt.get_cmap("YlOrBr")
    colors = [cmap(x) for x in np.linspace(0.3, 0.8, len(ds))]

    # Keep one handle list per panel. legend() zips handles with labels and
    # truncates to the shorter one. Passing the lines of both panels
    # interleaved would therefore pair each label with the wrong colour.
    lines_lin = [[] for _ in axes]
    lines_sub = [[] for _ in axes]
    for color, d in zip(colors, ds):
        tau = to_tau(ts, d)
        for j, ax in enumerate(axes):
            (line,) = ax.plot(
                tau,
                results[d]["lin"],
                linestyle="-",
                linewidth=2,
                color=color,
            )
            lines_lin[j].append(line)
            (line,) = ax.plot(
                tau,
                results[d]["sub"],
                linestyle=(0, (1, 1)),  # dotted: 1pt on, 1pt off
                linewidth=2,
                color=color,
            )
            lines_sub[j].append(line)

    # Extra samples just below tau = 1 (10**-0.01 .. 1). Presumably this lets
    # the log-scale panel resolve 1 - tau as it drops toward 0.
    taus = np.unique(
        np.concatenate([np.linspace(0, 1, 100), np.logspace(-0.01, 0, 200)])
    )

    lines_ref = []
    for ax in axes:
        (line,) = ax.plot(
            taus,
            1 - taus,
            linestyle="--",
            dashes=(3, 2),
            linewidth=1,
            color="k",
        )
        lines_ref.append(line)

    legend_style = dict(
        frameon=False,
        borderpad=0.5,
        borderaxespad=0.5,
        handlelength=1.5,
        handletextpad=0.4,
        labelspacing=0.2,
        fontsize=8,
    )
    # Raw f-string: "\," is a LaTeX thin space, not a Python escape.
    d_labels = [rf"$d = \,${fmt_pow10(d)}" for d in ds]

    axes[0].legend(
        [lines_ref[0]],
        [r"$1 - \tau$"],
        loc="lower left",
        **legend_style,
    )

    # Both legends on the right panel use that panel's own line handles.
    legend_lin = axes[1].legend(
        lines_lin[1],
        d_labels,
        loc="lower left",
        title=r"Linear",
        **legend_style,
    )
    axes[1].legend(
        lines_sub[1],
        d_labels,
        loc="lower center",
        title=r"Sublinear",
        **legend_style,
    )
    # A second legend() call replaces the first; re-attach it as an artist.
    axes[1].add_artist(legend_lin)

    axes[0].set_ylim([0, 1.15])
    axes[0].set_xlim([0, 1.15])

    axes[1].set_yscale("log")
    axes[1].set_ylim([10**-3, 10**0.5])
    axes[1].set_xlim([0, 1.3])

    axes[0].set_yticks([0, 0.5, 1])
    axes[0].set_yticklabels([0, "", 1])
    axes[1].set_yticks([10**-3, 10**-2, 10**-1, 1])
    # NOTE(review): fmt_pow10 is also applied to the "" placeholders;
    # presumably it passes them through unchanged — confirm.
    axes[1].set_yticklabels([fmt_pow10(v) for v in [10**-3, "", "", 1]])

    for ax in axes:
        ax.set_xticks([0, 0.25, 0.5, 0.75, 1])
        ax.set_xticklabels([0, "", "", "", 1])
        ax.set_xlabel(
            r"$\tau = \frac{\log(t)}{\log(d)}$",
            labelpad=-5,
        )
        ax.set_ylabel(r"Relative error", labelpad=-5)

    axes[0].set_title(r"Standard rates")
    axes[1].set_title(r"Standard rates (log scale)")

    fig.tight_layout(pad=0.1)

    return fig


if __name__ == "__main__":
    from pathlib import Path

    out_path = Path("figs") / "compare-rates.pdf"

    data = load_data()
    data = postprocess(data)
    settings(plt)
    fig = plt.figure()
    make_figure(fig, data)
    # savefig does not create missing parent directories.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # Save this figure explicitly rather than whichever figure is "current".
    fig.savefig(out_path, dpi=300)
    plt.close(fig)
