import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
from tqdm import tqdm

from spectrum import plotting


def load_data(alphas=None, vocab_size=None):
    """Simulate GD and sign descent on real token frequencies for several vocab sizes.

    For each vocabulary size ``d`` in ``[1_000, 3_162, 10_000, 31_622]`` the
    unigram counts in ``freqs/token_freq_total_{d}.npy`` are normalized to sum
    to one. They are used both as the eigenvalues (``freqs[i]`` for row ``i``)
    and as the initial errors (``freqs[j]`` for column ``j``) of a quadratic
    problem with ``d x d`` coordinates. For odd horizons ``t`` the normalized
    loss (loss after ``t`` steps divided by the initial loss) is computed for

    * GD with step size ``1 / max(freqs)``,
    * sign descent with the step size found by bounded scalar minimization,

    together with the asymptotic predictions for both methods. Loss
    evaluations are memoized in ``cache/real_data/{d}/cache.pk``.

    Args:
        alphas: Unused; accepted for backward compatibility.
        vocab_size: Unused; every vocabulary size listed above is always
            processed.

    Returns:
        dict with keys ``"times"``, ``"results_gd"``, ``"results_sd"``,
        ``"preds_gd"``, ``"best_ss"`` and ``"preds_sd"``. Each maps a vocab
        size ``d`` to per-horizon values; ``"times"`` holds the horizons
        themselves and the other entries are lists aligned with them.
    """

    class Cache:
        """Pickle-backed memo of loss values for one vocabulary size.

        Entries live in nested dicts ``data[name][d][t][eta]``. Every key goes
        through ``fmt``, which renders numbers with 4 significant digits, so
        arguments that agree to 4 significant digits share one entry.
        """

        def __init__(self, vocab_size):
            self.d = vocab_size
            self.data = {}

        def _path(self):
            return Path(f"cache/real_data/{self.d}") / "cache.pk"

        def load(self):
            """Read the cache file if it exists (else start empty); return self."""
            path = self._path()
            if path.exists():
                # This file is written by save(); never point pickle at
                # untrusted files.
                with open(path, "rb") as f:
                    self.data = pickle.load(f)
            else:
                self.data = {}
            return self

        def save(self):
            """Write the cache to disk, creating its directory if needed."""
            path = self._path()
            path.parent.mkdir(parents=True, exist_ok=True)
            with open(path, "wb") as f:
                pickle.dump(self.data, f)

        @staticmethod
        def fmt(x):
            """Cache key for x: strings unchanged, numbers with 4 significant digits."""
            return x if isinstance(x, str) else f"{x:.4g}"

        def put(self, name, d, t, eta, val):
            """Store val under (name, d, t, eta)."""
            name, d, t, eta = map(self.fmt, (name, d, t, eta))
            self.data.setdefault(name, {}).setdefault(d, {}).setdefault(t, {})[eta] = val

        def get(self, name, d, t, eta):
            """Return the value stored under (name, d, t, eta), or None."""
            name, d, t, eta = map(self.fmt, (name, d, t, eta))
            try:
                return self.data[name][d][t][eta]
            except KeyError:
                return None

    vocab_sizes = [
        1_000,
        3_162,
        10_000,
        31_622,
    ]
    times = {}
    results_gd = {}
    results_sd = {}
    preds_gd = {}
    best_ss = {}
    preds_sd = {}

    for d in vocab_sizes:

        freqs = np.load(f"freqs/token_freq_total_{d}.npy").astype(np.float64)
        freqs /= np.sum(freqs)

        max_eig = np.max(freqs)

        d0s = freqs
        d0s2 = d0s**2

        # Coordinate (i, j) has eigenvalue freqs[i] and initial error d0s[j].
        # Every loss below is a sum over (i, j) of a row term (depending on i)
        # times a column term (depending on j), so it factorizes as
        # sum_i(row) * sum_j(col). That avoids materializing a d x d array
        # (~8 GB of float64 at d = 31_622).
        eig_mass = np.sum(freqs)
        loss0 = eig_mass * np.sum(d0s2)

        cache = Cache(d).load()

        def actual_sign_descent_manual(t, eta):
            """Step-by-step sign descent: a reference for the closed form below.

            Not called by this function; handy for cross-checking.
            """
            dts = np.copy(d0s)
            for _ in range(t):
                dts = dts - eta * np.sign(dts)
            return eig_mass * np.sum(dts**2) / loss0, dts

        def actual_sign_descent(t, eta, ret=False):
            """Normalized loss after t sign-descent steps of size eta (closed form).

            Each coordinate decreases linearly for T0 = floor(d0 / eta) steps and
            then alternates between the residual r = d0 - eta * T0 and r - eta.
            A coordinate that lands exactly on zero (d0 == 0 or r == 0) stays
            there, because sign(0) == 0.

            Returns the loss ratio, or ``(ratio, final_errors)`` when ``ret`` is
            True. A cached ratio is only used when ``ret`` is False.
            """
            if not ret:
                cached = cache.get("sign", d, t, eta)
                if cached is not None:
                    return cached

            T0s = np.floor_divide(d0s, eta).astype(int)
            r = d0s - eta * T0s
            parity = (t - T0s) % 2
            cycle = np.where((parity == 0) | (r == 0), r, r - eta)
            dts = np.where(t <= T0s, d0s - eta * t, cycle)
            dts = np.where(d0s == 0, 0, dts)

            val = eig_mass * np.sum(dts**2) / loss0
            cache.put("sign", d, t, eta, val)
            if ret:
                return val, dts
            return val

        def actual_gd(t, eta):
            """Normalized loss after t GD steps of size eta (cached)."""
            cached = cache.get("gd", d, t, eta)
            if cached is not None:
                return cached

            loss = np.sum(freqs * (1 - eta * freqs) ** (2 * t)) * np.sum(d0s2)
            val = loss / loss0
            cache.put("gd", d, t, eta, val)
            return val

        def prediction_gd(t, eta, d):
            """Asymptotic GD prediction max(0, 1 - log(2t) / log(d)); eta unused."""
            tau = np.log(2 * t) / np.log(d)
            if tau > 1:
                return 0
            return 1 - tau

        def prediction_sign(t, eta, d):
            """Asymptotic SD prediction 1 / (1 + (pi^2/6) tau^2), tau = 2t/sqrt(d); eta unused."""
            tau = 2 * t / np.sqrt(d)
            c = np.pi**2 / 6
            return 1 / (1 + c * tau**2)

        def guess_stepsize_sd(t, d):
            """Heuristic sign-descent step size 1 / (log(d) * t)."""
            return 1 / (np.log(d) * t)

        def guess_stepsize_gd(t, d):
            """GD step size 1 / largest eigenvalue."""
            return 1 / max_eig

        def with_best_stepsize_sd(t, d, ret=False):
            """Sign-descent loss at the best step size found by bounded search.

            Returns the loss ratio, or ``(ratio, step_size)`` when ``ret`` is True.
            """
            # Search roughly d-fold on either side of the heuristic step size.
            upper = (1 / (np.log(d) * t)) * d
            lower = (1 / (np.log(d) * t * d)) / d

            res = sp.optimize.minimize_scalar(
                lambda x: actual_sign_descent(t, x),
                bounds=(lower, upper),
                method="bounded",
            )
            val = actual_sign_descent(t, res.x)
            if ret:
                return val, res.x
            return val

        ts = np.logspace(0, 1 + np.log10(d), base=10, num=50).astype(int)
        ts = 2 * ts - 1  # only odd

        times[d] = ts

        results_gd[d] = [
            actual_gd(t, guess_stepsize_gd(t, d)) for t in tqdm(ts, total=len(ts))
        ]
        # Run the step-size optimization once per horizon and keep both the
        # resulting loss and the step size.
        sd_runs = [
            with_best_stepsize_sd(t, d, ret=True) for t in tqdm(ts, total=len(ts))
        ]
        results_sd[d] = [val for val, _ in sd_runs]
        best_ss[d] = [step for _, step in sd_runs]
        preds_gd[d] = [prediction_gd(t, guess_stepsize_gd(t, d), d) for t in ts]
        preds_sd[d] = [prediction_sign(t, guess_stepsize_sd(t, d), d) for t in ts]

        cache.save()

    return {
        "times": times,
        "results_gd": results_gd,
        "results_sd": results_sd,
        "preds_gd": preds_gd,
        "best_ss": best_ss,
        "preds_sd": preds_sd,
    }


def postprocess(data):
    """Hook for transforming loaded results before plotting.

    Currently the identity: the very same ``data`` object is handed back.
    """
    processed = data
    return processed


def settings(plt):
    """Apply the shared plot style for a three-column, full-width figure."""
    # A custom height_to_width_ratio (e.g. 1 / 1.618) could be added here.
    style_kwargs = {"ncols": 3, "rel_width": 1}
    plotting.update_style(plt, **style_kwargs)


def make_figure(fig, data, use_logy=True):
    """Draw the three-panel real-data figure onto ``fig``.

    Panels, left to right:
      1. normalized loss vs. horizon for GD and sign descent (SD) at the
         largest vocabulary size, with the asymptotic predictions (dashed);
      2. time for GD to reach loss <= epsilon vs. vocabulary size, with a
         dashed reference line proportional to d**(1 - epsilon);
      3. the same for SD, with a dashed reference line proportional to d**(1/2).

    A vocabulary size whose loss never drops to epsilon is plotted as NaN,
    i.e. omitted.

    Args:
        fig: matplotlib figure to draw into.
        data: dict produced by ``load_data`` (optionally via ``postprocess``).
        use_logy: Unused; accepted for backward compatibility.
    """
    legend_common = dict(
        frameon=False,
        borderpad=0.25,
        borderaxespad=0.25,
        handlelength=1.5,
        labelspacing=0.2,
    )

    ROW = 1
    COL = 3
    axes = [
        [fig.add_subplot(ROW, COL, COL * r + c + 1) for c in range(COL)]
        for r in range(ROW)
    ]
    ax_loss, ax_gd, ax_sd = axes[0]

    cmap = plt.get_cmap("YlOrBr")
    color_sd = cmap(0.65)

    # Panel 1: loss curves at the largest vocabulary size.
    d = max(data["times"])
    ts = data["times"][d]
    ax_loss.plot(ts, data["results_gd"][d], label="GD", color="k")
    ax_loss.plot(ts, data["results_sd"][d], label="SD", color=color_sd)
    ax_loss.plot(ts, data["preds_gd"][d], label="asym.", linestyle="--", color="k")
    ax_loss.plot(
        ts, data["preds_sd"][d], label="asym.", linestyle="--", color=color_sd
    )
    ax_loss.legend(**legend_common, handletextpad=0.4, fontsize=8)
    ax_loss.set_xscale("log")
    ax_loss.set_ylim([0, 1])
    ax_loss.set_title("Real data")
    ax_loss.set_xlabel("Optimization horizon $T$")
    ax_loss.set_ylabel("Normalized loss")

    # Panels 2-3: time to reach epsilon as a function of vocabulary size.
    def time_to_eps(losses_key, epsilon):
        """First horizon with loss <= epsilon, per vocab size (NaN if never)."""
        out = []
        for size, horizons in data["times"].items():
            hits = np.flatnonzero(np.asarray(data[losses_key][size]) <= epsilon)
            out.append(horizons[hits[0]] if hits.size else np.nan)
        return np.array(out, dtype=float)

    # Same order as time_to_eps iterates (dict insertion order).
    ds = np.array(list(data["times"]), dtype=float)
    ref_idx = 2  # vocab size at which the dashed reference lines are anchored

    epsilons = [0.75, 0.5, 0.1]
    epsilon_text = {
        0.75: "\\frac{3}{4}",
        0.5: "\\frac{1}{2}",
        0.1: "\\frac{1}{10}",
    }
    colors = [cmap(0.3), cmap(0.5), cmap(0.7)]

    for epsilon, color in zip(epsilons, colors):
        label = f"$\\varepsilon\\,{{=}}\\,{epsilon_text[epsilon]}$"
        panels = (
            (ax_gd, time_to_eps("results_gd", epsilon), 1 - epsilon),
            (ax_sd, time_to_eps("results_sd", epsilon), 1 / 2),
        )
        for ax, vs, exponent in panels:
            ax.plot(
                ds,
                vs[ref_idx] * (ds / ds[ref_idx]) ** exponent,
                linestyle="--",
                dashes=(3, 2),
                linewidth=1,
                color="k",
                alpha=1.0,
            )
            ax.plot(
                ds,
                vs,
                marker="o",
                markersize=4,
                linestyle="",
                label=label,
                color=color,
            )

    for ax, title in ((ax_gd, "Gradient descent"), (ax_sd, "Sign descent")):
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.set_xlabel("Vocab. size $d$")
        ax.set_ylabel("Time to $\\varepsilon$")
        ax.set_title(title)
        ax.set_xlim([10**2.75, 10**4.75])
        ax.set_ylim([10**0, 10**3.5])
        ax.set_yticks([1, 10, 100, 1000])

    ax_sd.legend(
        **legend_common,
        ncol=3,
        columnspacing=0.2,
        fontsize=7,
        handletextpad=0.3,
    )

    fig.tight_layout(pad=0.2)


if __name__ == "__main__":
    # Compute (or load cached) results for all vocab sizes and draw figure 1.
    data = postprocess(load_data())
    settings(plt)
    fig = plt.figure()
    make_figure(fig, data)
    # savefig does not create missing directories.
    Path("figs").mkdir(parents=True, exist_ok=True)
    fig.savefig("figs/real-data_fig1.pdf")
    plt.close(fig)
