import pickle
from functools import lru_cache
from pathlib import Path

import numpy as np
from tqdm import tqdm

from spectrum import gdlogreg, sign


class ResultsCache:
    """Pickle-backed cache of normalized losses for the sign experiments.

    Results live in a nested dict ``results[alpha][d][t][phi_power]`` whose
    keys are the ``.4g``-formatted strings produced by ``fmt``. The dict is
    persisted to ``<results_folder>/results.pk`` by ``save`` and read back by
    ``load``.

    NOTE: keys keep only 4 significant digits, so numbers that agree to 4
    significant digits share a single entry (e.g. ``10000`` and ``10001`` both
    become ``"1e+04"``). Use grids whose points are distinguishable at that
    precision.
    """

    def __init__(self, results_folder=None):
        """Create the results folder if needed and start with an empty cache.

        Args:
            results_folder: Directory holding ``results.pk``. Defaults to
                ``cache/sign-confirm-scaling-3`` relative to the working
                directory.
        """
        if results_folder is None:
            results_folder = Path("cache") / "sign-confirm-scaling-3"
        self.results_folder = Path(results_folder)
        self.results_file = self.results_folder / "results.pk"
        self.results_folder.mkdir(parents=True, exist_ok=True)
        # Initialise here so get()/put()/save() work even if load() is never called.
        self.results = {}

    def load(self):
        """Replace the in-memory results with the contents of ``results.pk``.

        A missing file yields an empty cache.

        Returns:
            ``self``, so construction and loading can be chained.
        """
        try:
            with open(self.results_file, "rb") as f:
                # This file is written by save(); never point it at untrusted
                # data, because unpickling can execute arbitrary code.
                self.results = pickle.load(f)
        except FileNotFoundError:
            self.results = {}
        return self

    def save(self):
        """Pickle the results to ``results.pk`` without risking a torn file."""
        tmp_file = self.results_file.with_name(self.results_file.name + ".tmp")
        with open(tmp_file, "wb") as f:
            pickle.dump(self.results, f)
        # Write-then-rename: an interrupted save leaves the previous
        # results.pk intact instead of a truncated pickle.
        tmp_file.replace(self.results_file)

    @staticmethod
    @lru_cache(maxsize=1024)
    def fmt(x: float) -> str:
        """Return the cache-key string for ``x`` (4 significant digits)."""
        # Static so the LRU cache does not hold references to instances.
        return f"{x:.4g}"

    def fmts(self, *xs: float):
        """Return a tuple with ``fmt`` applied to each argument, in order."""
        return tuple(map(self.fmt, xs))

    def get(self, alpha, d, t, phi_power):
        """Return the cached value for these numeric coordinates, or ``None``."""
        return self.get_key(*self.fmts(alpha, d, t, phi_power))

    def get_key(self, alpha, d, t, phi_power):
        """Look up already-formatted string keys; ``None`` if any level is missing."""
        try:
            return self.results[alpha][d][t][phi_power]
        except KeyError:
            return None

    def put(self, alpha, d, t, phi_power, value):
        """Store ``value`` under the formatted keys, creating levels as needed."""
        alpha, d, t, phi_power = self.fmts(alpha, d, t, phi_power)
        by_t = self.results.setdefault(alpha, {}).setdefault(d, {}).setdefault(t, {})
        by_t[phi_power] = value

    def load_data_as_array(self, alphas, ts, ds, logd_phis):
        """Gather cached results into a dense array.

        Note that the argument order is ``(alphas, ts, ds, logd_phis)``, while
        the axis order of the result is ``(alpha, d, t, phi_power)``.

        Returns:
            ``np.ndarray`` of shape ``(len(alphas), len(ds), len(ts), len(logd_phis))``.

        Raises:
            KeyError: if any requested combination is not in the cache.
        """
        n_alpha, n_ds, n_ts, n_phis = len(alphas), len(ds), len(ts), len(logd_phis)
        results_array = np.zeros((n_alpha, n_ds, n_ts, n_phis))
        for i, alpha in tqdm(enumerate(alphas), total=n_alpha, leave=False):
            for k, t in tqdm(enumerate(ts), total=n_ts, leave=False):
                for j, d in tqdm(enumerate(ds), total=n_ds, leave=False):
                    for l, phi_power in tqdm(
                        enumerate(logd_phis), total=n_phis, leave=False
                    ):
                        res = self.get(alpha, d, t, phi_power)
                        if res is None:
                            raise KeyError(
                                f"Missing result for alpha={alpha}, d={d}, t={t}, phi_power={phi_power}"
                            )
                        results_array[i, j, k, l] = res
        return results_array

    def update_cache(self, alphas, ts, ds, logd_phis):
        """Compute and store every missing ``(alpha, t, phi_power, d)`` result.

        For a missing entry the learning rate is
        ``eta = 1 / (z * t * (d ** phi_power) ** alpha)`` with
        ``z = sign.compute_z(d, alpha)``. The stored value is
        ``sign.loss_at(d0s, t, eta) / sign.init_loss(d0s)``, where
        ``d0s = sign.compute_d0s(d, alpha)``. The cache is saved to disk
        after each ``(alpha, t)`` pair that produced new results.
        """
        for alpha in tqdm(alphas, total=len(alphas), leave=True, desc="alpha"):
            for t in tqdm(ts, total=len(ts), leave=False, desc="t"):
                updated_something = False
                for phi_power in tqdm(
                    logd_phis, total=len(logd_phis), leave=False, desc="phi"
                ):
                    for d in tqdm(ds, total=len(ds), leave=False, desc="d"):
                        if self.get(alpha, d, t, phi_power) is not None:
                            continue

                        z = sign.compute_z(d, alpha)
                        d0s = sign.compute_d0s(d, alpha)
                        init_loss_ = sign.init_loss(d0s)
                        # Only this one scalar is needed, so compute it directly
                        # instead of exponentiating the whole logd_phis vector.
                        phi = float(d) ** phi_power
                        eta = 1 / (z * t * phi**alpha)
                        self.put(
                            alpha,
                            d,
                            t,
                            phi_power,
                            sign.loss_at(d0s, t, eta) / init_loss_,
                        )
                        updated_something = True
                # Checkpoint per (alpha, t) so long runs lose little on interruption.
                if updated_something:
                    self.save()


class GDLogRegResultsCache(ResultsCache):
    """Results cache for gradient-descent logistic-regression runs.

    Entries are stored as ``results[alpha][d][t][ss]`` (formatted keys, via
    the inherited ``put``/``get``), where ``ss`` is a step size. The cache
    file is ``cache/gd-logreg/results.pk``.
    """

    def __init__(self):
        # NOTE(review): super().__init__() also creates the parent class's
        # default cache folder as a side effect before the attributes below
        # redirect this instance to cache/gd-logreg.
        super().__init__()
        self.results_folder = Path("cache") / "gd-logreg"
        self.results_file = self.results_folder / "results.pk"
        self.results_folder.mkdir(parents=True, exist_ok=True)
        self.results = {}

    def update_cache(self, alphas, ts, ds, sss):
        """Fill in missing results for every ``(alpha, ss, d)`` combination.

        If any ``t`` in ``ts`` is missing for a given ``(alpha, ss, d)``, this
        computes one loss trajectory up to ``max(ts)`` with
        ``gdlogreg.compute_losses_over_time``. It then (re)writes every ``t``
        as ``(loss[t] - entropy(pis)) / (loss_at_0(pis) - entropy(pis))``.
        The cache is saved after each ``(alpha, ss)`` pair that added results.

        Losses are indexed as ``losses_over_time[t]``, so ``ts`` presumably
        must be integer time steps within the trajectory's length.
        TODO confirm against gdlogreg.
        """
        # With an empty ts nothing is ever missing, so max_t is never used.
        max_t = max(ts, default=None)
        for alpha in tqdm(alphas, total=len(alphas), leave=True, desc="alpha"):
            for ss in tqdm(sss, total=len(sss), leave=False, desc="ss"):
                updated_something = False
                for d in tqdm(ds, total=len(ds), leave=False, desc="d"):
                    if all(self.get(alpha, d, t, ss) is not None for t in ts):
                        continue

                    pis = gdlogreg.compute_pis(d, alpha)
                    # One trajectory covers every requested t at once.
                    losses_over_time = gdlogreg.compute_losses_over_time(
                        pis, T=max_t, ss=ss
                    )
                    min_loss = gdlogreg.entropy(pis)
                    excess_at_init = gdlogreg.loss_at_0(pis) - min_loss

                    for t in ts:
                        self.put(
                            alpha,
                            d,
                            t,
                            ss,
                            (losses_over_time[t] - min_loss) / excess_at_init,
                        )
                    updated_something = True

                if updated_something:
                    self.save()
