import pickle
from functools import lru_cache
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from scipy.special import expn, gamma, gammaincc
from tqdm import tqdm

from spectrum import plotting
from spectrum.plotting import fmt_pow10


def load_data(alphas=None, cs=None):
    """Compute (or load from the disk cache) GD loss curves on an (alpha, d) grid.

    For every exponent ``alpha`` and dimension ``c`` (shown as ``d`` in the
    figure), the loss is evaluated at the 50 log-spaced times returned by
    ``define_time(c)``. Every scalar loss is pickled under ``cache/gd/``, so a
    re-run only computes points that are missing from the cache.

    Args:
        alphas: Power-law exponents to evaluate. Defaults to ``[0.5, 1.0, 2.0]``.
        cs: Dimensions to evaluate. Defaults to ``[10**2, 10**4, 10**6, 10**8]``.

    Returns:
        dict with keys
          - ``"results"``: ``results[alpha][c]`` is a list of losses, one per
            time in ``define_time(c)``;
          - ``"alphas"``, ``"cs"``: the grids that were used;
          - ``"define_time"``: the function mapping ``c`` to its time grid.
    """
    if alphas is None:
        alphas = [0.5, 1.0, 2.0]
    if cs is None:
        cs = [10**x for x in [2, 4, 6, 8]]

    # maxsize=1: all times for one (c, alpha) pair are evaluated back to back,
    # so a single cached array suffices. An unbounded cache would keep every
    # d=10**8 array (~800 MB each) alive until load_data returns.
    @lru_cache(maxsize=1)
    def compute_unnorm_pis(d, alpha):
        # Unnormalized power-law weights 1 / k**alpha for k = 1..d. Built
        # vectorized: a Python-level list over d = 10**8 entries is very slow
        # and holds ~10**8 boxed floats before conversion.
        ks = np.arange(1, d + 1, dtype=np.float64)
        return 1.0 / ks**alpha

    @lru_cache
    def compute_normalization(d, alpha):
        # Sum of the unnormalized weights; divides the weighted loss below.
        return np.sum(compute_unnorm_pis(d, alpha))

    class CachedGDResults:
        """Disk cache of individual loss values keyed on (t, c, alpha)."""

        folder = Path("cache/gd")

        def make_filename(self, t, c, alpha):
            # 5 significant digits keep float keys stable across runs.
            t, c, alpha = (f"{v:.5g}" for v in (t, c, alpha))
            self.folder.mkdir(parents=True, exist_ok=True)
            return self.folder / f"results_t={t}_c={c}_alpha={alpha}.pk"

        def is_saved(self, t, c, alpha):
            return self.make_filename(t, c, alpha).exists()

        def load(self, t, c, alpha):
            # NOTE: unpickling is only safe because these files are written
            # by save() below; never point this cache at untrusted data.
            filename = self.make_filename(t, c, alpha)
            with open(filename, "rb") as f:
                return pickle.load(f)

        def save(self, t, c, alpha, val):
            filename = self.make_filename(t, c, alpha)
            with open(filename, "wb") as f:
                pickle.dump(val, f)

        def compute_loss_at_t(self, t, c, alpha):
            """Return sum_k w_k (1 - w_k)**(2t) / sum_k w_k, using the disk cache."""
            if self.is_saved(t, c, alpha):
                return self.load(t, c, alpha)

            upis = compute_unnorm_pis(c, alpha)
            N = compute_normalization(c, alpha)
            val = np.inner(upis, (1 - upis) ** (2 * t)) / N
            self.save(t, c, alpha, val)
            return val

    def define_time(c):
        # 50 log-spaced times from 1 up to 1000 * c.
        return np.logspace(0, 3 + np.log10(c), base=10, num=50)

    cache = CachedGDResults()
    results = {
        alpha: {
            c: [
                cache.compute_loss_at_t(t, c, alpha)
                for t in tqdm(define_time(c), desc="t", leave=False)
            ]
            for c in tqdm(cs, desc="c", leave=False)
        }
        for alpha in tqdm(alphas, desc="alpha")
    }
    return {
        "results": results,
        "alphas": alphas,
        "cs": cs,
        "define_time": define_time,
    }


def postprocess(data):
    """Post-processing hook applied to loaded data; currently a no-op.

    The very same object that was passed in is handed back (no copy).
    """
    processed = data
    return processed


def settings(plt):
    """Apply the shared plotting style to ``plt`` for a three-column layout."""
    # A height_to_width_ratio=1 / 1.618 override was tried here before; it is
    # not passed, so update_style's own default ratio applies.
    layout = {"ncols": 3}
    plotting.update_style(plt, **layout)


def make_figure(fig, data, logy=True):
    """Draw GD loss curves from ``data`` together with their theoretical predictions.

    One subplot per ``alpha`` is added to ``fig`` in a single row. Each
    subplot plots the loss for every dimension ``d`` in ``data["cs"]``
    against a rescaled time (see ``rescale_time``), overlaid with a dashed
    black curve from ``theory``.

    Args:
        fig: Matplotlib figure to draw into.
        data: Dict as returned by ``load_data``, with keys ``"results"``,
            ``"alphas"``, ``"cs"`` and ``"define_time"``.
        logy: If True (default), use a logarithmic y-axis; otherwise linear.
    """
    results = data["results"]
    alphas = data["alphas"]
    cs = data["cs"]

    ROW = 1
    COL = len(alphas)
    axes = [
        [fig.add_subplot(ROW, COL, COL * r + c + 1) for c in range(COL)]
        for r in range(ROW)
    ]

    def format_as_power_of_10(x):
        # Label text for a dimension; used inside $...$ math mode.
        if x == 1:
            return "1"
        if x == 10:
            return "10"
        if x > 10:
            # Braces so multi-digit exponents (e.g. 10^{10}) render correctly.
            return f"10^{{{int(np.log10(x))}}}"
        return f"{x:.2g}"

    def rescale_time(ts, alpha, c):
        # Time axis used for each regime of alpha.
        if alpha < 1:
            return ts / c**alpha
        if alpha == 1:
            return np.log(ts) / np.log(c)
        if alpha > 1:
            return ts

    def theory(taus, alpha):
        # Predicted loss as a function of rescaled time.
        if alpha < 1:
            const = (1 - alpha) / alpha
            # An exponential-integral form, const * expn(1 / alpha, 2 * taus),
            # used to sit here as unreachable code; this form is what is drawn.
            return const * np.exp(-2 * taus) / (1 + 2 * taus)
        if alpha == 1:
            return 1 - taus
        if alpha > 1:
            C = sp.special.gamma(1 - 1 / alpha) / (sp.special.zeta(alpha) * alpha)
            return C / (2 * taus) ** (1 - 1 / alpha)

    def xlabel(alpha):
        if alpha < 1:
            return "$\\tau = T/d^\\alpha$"
        if alpha == 1:
            return "$\\tau : t = d^\\tau $"
        if alpha > 1:
            return "$\\tau = t$"

    def xticks(alpha):
        if alpha < 1:
            return [10**-2, 10**-1, 10**0, 10**1]
        if alpha == 1:
            return [10**-2, 10**-1, 10**0, 10**1]
        if alpha > 1:
            return [10**0, 10**2, 10**4, 10**6, 10**8]

    def xticklabels(alpha):
        # One entry per tick from xticks(); None marks an unlabeled tick.
        if alpha < 1:
            ticks = [10**-2, None, None, 10**1]
        if alpha == 1:
            ticks = [10**-2, None, None, 10**1]
        if alpha > 1:
            ticks = [10**-0, 10**2, None, 10**6, 10**8]
        return list(map(fmt_pow10, ticks))

    def ylabel(alpha):
        return "Relative error"

    def ylim(alpha):
        return (10**-4, 10**0.5)

    def xlim(alpha):
        if alpha < 1:
            return (10**-2, 10**1)
        if alpha == 1:
            return (10**-2, 10**1)
        if alpha > 1:
            return (10**0, 10**8)

    def prediction_label(alpha):
        if alpha < 1:
            return "$\\frac{1-\\alpha}{\\alpha}\\frac{e^{-\\tau}}{1+\\tau}$"
        if alpha == 1:
            return "$1-\\tau$"
        if alpha > 1:
            return "$\\frac{C}{\\tau^{1-\\frac{1}{\\alpha}}}$"

    cmap = plt.get_cmap("YlOrBr")
    colors = [cmap(i) for i in np.linspace(0.3, 0.7, len(cs))]

    for i, alpha in enumerate(alphas):
        ax = axes[0][i]
        for j, c in enumerate(cs):
            ts = data["define_time"](c)
            rescaled_time = rescale_time(ts, alpha, c)
            ax.plot(
                rescaled_time,
                results[alpha][c],
                label=f"$d={format_as_power_of_10(c)}$",
                color=colors[j],
                alpha=1.0,
            )

        ax.set_title(f"$\\alpha = {alpha:.2g}$", y=0.95)
        ax.set_xscale("log")
        if logy:
            ax.set_yscale("log")
        ax.set_ylabel(ylabel(alpha))
        ax.set_xlabel(xlabel(alpha), labelpad=-5)
        ax.set_ylim(ylim(alpha))
        ax.set_xlim(xlim(alpha))
        ax.set_xticks(xticks(alpha))
        ax.set_xticklabels(xticklabels(alpha))

    # Legend of the per-d data curves on the last subplot of *this* figure
    # (not pyplot's "current" axes). It is built before the theory curves are
    # drawn, so it lists only data lines. It is re-attached at the end because
    # the ax.legend() call below replaces an axes' legend.
    color_legend = axes[0][-1].legend(
        loc="upper right",
        frameon=False,
        borderpad=0.3,
        borderaxespad=0.3,
        handlelength=1.5,
        handletextpad=0.4,
        labelspacing=0.2,
        fontsize=8,
    )

    LW = 2

    for i, alpha in enumerate(alphas):
        ax = axes[0][i]
        rescaled_time = np.logspace(*map(np.log10, xlim(alpha)), base=10, num=1000)
        (line_h,) = ax.plot(
            rescaled_time,
            theory(rescaled_time, alpha),
            "--",
            color="k",
            linewidth=LW,
            dashes=(3, 2),
            label=prediction_label(alpha),
        )

        ax.legend(
            [line_h],
            [prediction_label(alpha)],
            loc="lower left",
            frameon=False,
            markerfirst=False,
            borderpad=0.25,
            handlelength=2.0,
            borderaxespad=0.25,
            handletextpad=0.4,
            fontsize=9,
        )
    axes[0][-1].add_artist(color_legend)

    # Only the leftmost subplot keeps its y-axis label and tick labels; works
    # for any number of alphas instead of assuming exactly three panels.
    for ax in axes[0][1:]:
        ax.set_ylabel("")
        ax.set_yticklabels([])

    fig.tight_layout(pad=0)


if __name__ == "__main__":
    # Compute (or load cached) losses, style pyplot, draw and save the figure.
    data = load_data()
    data = postprocess(data)
    settings(plt)
    fig = plt.figure()
    make_figure(fig, data)
    # savefig does not create missing directories; make sure figs/ exists.
    out_path = Path("figs") / "gd-losses.pdf"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path)
