import math
from functools import lru_cache

import numpy as np
from scipy.special import zeta


def rescaled_time(ts, d, alpha):
    """Map raw times ``ts`` onto the rescaled time axis ``tau`` for ``alpha``.

    * ``alpha > 0.5``: ``tau = 2 * ts / sqrt(d)`` (i.e. ``4T^2 = tau^2 d``).
    * ``alpha == 0.5``: ``tau = log(4 ts^2) / log(d)`` (i.e. ``4T^2 = d^tau``).
    * ``alpha < 0.5``: ``ts`` is returned unchanged (the same object).
    * NaN ``alpha`` (no comparison holds): ``ts * 0``.
    """
    if alpha > 0.5:
        # Covers alpha == 1.0 as well; the old special case used the same formula.
        return 2 * ts / np.sqrt(d)
    if alpha == 0.5:
        return np.log(4 * ts**2) / np.log(d)
    if alpha < 0.5:
        return ts
    return ts * 0


def rescaled_phi(phi, d, alpha):
    """Rescale ``phi`` for plotting against the rescaled time axis.

    * ``alpha < 0.5``: ``phi / d``.
    * ``alpha == 0.5``: the exponent ``x`` with ``phi = d**x``,
      i.e. ``log(phi) / log(d)``.
    * otherwise (``alpha > 0.5``, or NaN): ``phi`` unchanged (same object).
    """
    if alpha == 0.5:
        return np.log(phi) / np.log(d)
    if alpha < 0.5:
        return phi / d
    # alpha > 0.5, and the NaN fall-through, both leave phi untouched.
    return phi


def tau_expr(alpha):
    """Return the LaTeX expression (without ``$``) defining tau for ``alpha``.

    Returns ``"???"`` when no branch applies (NaN ``alpha``).
    """
    if alpha < 0.5:
        return "T"
    if alpha == 0.5:
        return "\\log(4T^2)/\\log(d)"
    if alpha > 0.5:
        return "2T/\\sqrt{d}"
    return "???"


def y_expr(alpha):
    """Return the LaTeX y-axis label for the rescaled phi at ``alpha``.

    Returns ``"???"`` when no branch applies (NaN ``alpha``).
    """
    if alpha > 0.5:
        return "$\\phi$"
    if alpha == 0.5:
        # Backslash escaped explicitly: a bare "\p" is an invalid escape
        # sequence (compile-time warning). The resulting string is unchanged.
        return "$x : \\phi = d^x$"
    if alpha < 0.5:
        return "$\\phi/d$"
    return "???"


def best_phi_given_tau(tau, alpha):
    """Return ``(1 + 1/tau^2) ** min(1, 1/alpha)``.

    For ``alpha < 1`` the exponent is 1, so the base is returned directly.
    Returns ``None`` if ``alpha`` is NaN (neither comparison holds).
    """
    if alpha >= 1:
        return (1 + 1 / tau**2) ** (1 / alpha)
    if alpha < 1:
        return 1 + 1 / tau**2
    return None


def approx_loss(taus, alpha):
    """Approximate loss ``1 / (1 + zeta(2*alpha) * tau^2)``.

    Works element-wise on arrays as well as on scalars. This is the same
    expression as the ``alpha > 0.5`` branch of ``predicted_loss``.
    """
    # No optimal phi is computed here: the formula does not depend on it, and
    # computing it needlessly raised ZeroDivisionError for a scalar tau == 0.
    return 1 / (1 + zeta(2 * alpha) * (taus**2))


def tau_rescaling_label(alpha):
    """LaTeX legend text stating how tau is defined from T for ``alpha``.

    Returns ``None`` for NaN ``alpha``.
    """
    if alpha > 0.5:
        return "$\\tau : 4T^2 = \\tau^2 d$"
    if alpha == 0.5:
        return "$\\tau : 4T^2 = d^{\\tau}$"
    if alpha < 0.5:
        return "$\\tau = T$"
    return None


def predicted_phi(taus, alpha):
    """Predicted phi as a function of rescaled time ``taus``.

    * ``alpha < 0.5``: ``1 / tau^2``.
    * ``alpha == 0.5``: ``1 - tau`` clipped below at 0; values above 1
      (i.e. ``tau < 0``) become NaN. Scalars and integer arrays are accepted;
      the result is a float ndarray (0-d for scalar input).
    * ``alpha > 0.5``: ``(1 + 1/tau^2) ** min(1, 1/alpha)``.

    Returns ``None`` for NaN ``alpha``.

    NOTE(review): ``phi_label`` shows different expressions for the
    ``alpha < 0.5`` and ``alpha > 0.5`` cases; confirm which is intended.
    """
    if alpha < 0.5:
        return 1 / taus**2
    if alpha == 0.5:
        # Work on a float array: the previous in-place masked assignment failed
        # for Python scalars and could not store NaN in integer arrays.
        out = 1.0 - np.asarray(taus, dtype=float)
        out = np.where(out < 0, 0.0, out)
        out = np.where(out > 1, np.nan, out)
        return out
    if alpha > 0.5:
        return (1 + 1 / taus**2) ** min([1, 1 / alpha])


def predicted_loss(taus, alpha):
    """Predicted loss as a function of rescaled time ``taus``.

    * ``alpha < 0.5``: with ``c1 = 1 - 1/(2 alpha)`` and
      ``c2 = alpha / (1 - alpha)``, the loss is
      ``(c1 + 4 c2 tau^2)^(2 alpha) / (4 tau^2)`` for
      ``tau >= threshold = sqrt((1 - c1) / (4 c2))``, and the constant
      ``2 alpha c2`` below it. The constant matches the formula at the
      threshold, so the curve is continuous there. Returns a float ndarray
      (0-d for scalar input).
    * ``alpha == 0.5``: ``1 - tau``.
    * ``alpha > 0.5``: ``1 / (1 + zeta(2 alpha) tau^2)``.

    Returns ``None`` for NaN ``alpha``.
    """
    if alpha < 0.5:
        c1 = 1 - 1 / (2 * alpha)
        c2 = alpha / (1 - alpha)
        threshold = math.sqrt((1 - c1) / (4 * c2))
        taus_arr = np.asarray(taus, dtype=float)
        results = np.full(taus_arr.shape, 2 * alpha * c2)
        # Evaluate the power law only where it is used. Below the threshold the
        # base can be negative (c1 < 0) and tau can be 0, which previously
        # produced NaN/inf plus RuntimeWarnings before being overwritten.
        # NaN taus fall in `above` and still yield NaN, as before.
        above = ~(taus_arr < threshold)
        t = taus_arr[above]
        results[above] = (c1 + c2 * 4 * t**2) ** (2 * alpha) / (4 * t**2)
        return results
    if alpha == 0.5:
        return 1 - taus
    if alpha > 0.5:
        return 1 / (1 + zeta(2 * alpha) * taus**2)


def phi_label(alpha):
    """LaTeX legend label for the predicted-phi curve at ``alpha``.

    Returns ``None`` for NaN ``alpha``.
    """
    if alpha > 0.5:
        return "$\\frac{1}{1+\\tau^2}$"
    if alpha == 0.5:
        return "$1-\\tau$"
    if alpha < 0.5:
        return "$\\frac{1}{1+\\frac{\\alpha}{1-\\alpha}4\\tau^2}$"
    return None


def loss_label(alpha):
    """LaTeX legend label for the predicted-loss curve at ``alpha``.

    Returns ``None`` for NaN ``alpha``.
    """
    if alpha > 0.5:
        return "$\\frac{1}{1+\\zeta(2\\alpha)\\tau^2}$"
    if alpha == 0.5:
        return "$1-\\tau$"
    if alpha < 0.5:
        return "$\\frac{c_2^{2\\alpha}}{(2\\tau)^{2-4\\alpha}}$"
    return None


@lru_cache(maxsize=100)
def compute_z(d, alpha):
    """Return ``sum_{k=1}^{d} k**(-alpha)`` (memoised per ``(d, alpha)``).

    ``d`` must be an int (it is passed to ``range``); ``d == 0`` gives 0.0.
    """
    terms = [k ** -alpha for k in range(1, d + 1)]
    return np.array(terms).sum()


@lru_cache(maxsize=100)
def compute_d0s(d, alpha):
    """Return the normalised profile ``k**(-alpha) / sum_j j**(-alpha)``, k=1..d.

    Results are memoised per ``(d, alpha)``, so every caller with the same
    arguments receives the *same* array object. That array is returned
    read-only; copy it (``arr.copy()``) before modifying it.
    """
    distances = np.array([k**-alpha for k in range(1, d + 1)])
    z = np.sum(distances)
    d0s = distances / z
    # Freeze the cached array: an in-place edit by one caller would otherwise
    # silently change the value seen by every later caller.
    d0s.flags.writeable = False
    return d0s


def init_loss(d0s):
    """Return the sum of squared entries of ``d0s``."""
    squared = d0s**2
    return np.sum(squared)


def dkt(d0s, t, eta):
    """Return per-entry values of ``d0s`` after ``t`` steps of size ``eta``.

    Entries where ``d0s - t * eta > 0`` are reduced by ``t * eta``. Every
    other entry is set to ``eta / 2``. The output has the dtype of ``d0s``
    (via ``np.zeros_like``).
    """
    step = t * eta
    remaining = d0s - step
    still_positive = remaining > 0
    out = np.zeros_like(d0s)
    out[still_positive] = remaining[still_positive]
    out[~still_positive] = eta / 2
    return out


def loss_at(d0s, t, eta):
    """Return the sum of squares of ``dkt(d0s, t, eta)``."""
    residuals = dkt(d0s, t, eta)
    return np.sum(residuals**2)
