from math import sqrt, pi
import numpy as np
from scipy import linalg
from activation_functions import gaussian_norm, get_kth_hermite_coef


def get_theoretical_values(env):
    """
    Get theoretical values of generalization error and nonrobustness.

    Parameters
    ----------
    env : dict
        Problem description. Keys read by this function:

        - ``"reg"``: regularization strength; only ``0`` is currently
          supported (any other value raises ``NotImplementedError``).
        - ``"regime"``: one of ``"sgd"``, ``"nt"``, ``"init"``,
          ``"lazy_nt"``, ``"rf"``, ``"lazy_rf"``.
        - ``"sigma"``: object whose attributes ``norm2``, ``norm2_grad``,
          ``lambda1``, ``lambda2``, ``lambda3`` are read by the ``"init"``
          and ``"rf"`` regimes (and hence by their lazy variants).
        - ``"m"``, ``"d"``: sizes; ``B`` and ``cov`` must be ``(d, d)``.
        - ``"B"``, ``"cov"``: ``(d, d)`` arrays.
        - ``"WWT"``: ``(m, m)`` array. If absent, it is computed as
          ``W @ W.T`` from ``"W"``.
        - ``"U"``, ``"C"``: arrays combined with ``np.eye(m)``, so they must
          be ``(m, m)``; read only by the ``"lazy_rf"`` regime.

        The dict is not modified.

    Returns
    -------
    tuple
        ``(err_test, sob2)``: the theoretical test error and the
        nonrobustness quantity ``sob2``. Both are ``nan`` when the regime
        is not recognized.

    Raises
    ------
    NotImplementedError
        If ``env["reg"] != 0``.
    """
    reg = env["reg"]
    if reg != 0:
        raise NotImplementedError
    regime = env["regime"]
    sigma = env["sigma"]
    m = env["m"]
    d = env["d"]
    B = env["B"]
    cov = env["cov"]
    # Only form W @ W.T when the caller did not supply it precomputed.
    WWT = env.get("WWT")
    if WWT is None:
        W = env["W"]
        WWT = W @ W.T
    assert B.shape == (d, d)
    assert cov.shape == (d, d)
    assert WWT.shape == (m, m)
    norm_B = linalg.norm(B, ord="fro")
    norm_cov = linalg.norm(cov, ord="fro")
    # Reference scales used to normalize the "init" and "rf" quantities.
    err_test_ref = 2 * norm_B ** 2
    sob2_ref = 2 * err_test_ref

    if regime == "sgd":
        # Fraction of B's squared spectral mass captured by the top m
        # singular values (svdvals returns them in descending order).
        svals = linalg.svdvals(B) ** 2
        sob2 = svals[:m].sum() / svals.sum()
        err_test = 1 - sob2
    elif regime == "nt":
        if reg != 0:
            raise NotImplementedError
        # beta = np.trace(B) ** 2 / (d * norm_B ** 2)
        beta = np.trace(B @ cov) ** 2 / (norm_B ** 2 * norm_cov ** 2)
        rho = m / d
        cut = max(1 - rho, 0.)
        cat = min(rho, 1.)
        err_test = cut ** 2 * (1 - beta) + cut * beta
        sob2 = (cat + cat ** 2 + (cat - cat ** 2) * beta) / 2
    elif regime == "init":
        norm2_sigma = sigma.norm2
        norm2_sigma_grad = sigma.norm2_grad
        lambda2 = sigma.lambda2
        lambda3 = sigma.lambda3
        err_test = 1 + (norm2_sigma + lambda2 ** 2 * norm_cov ** 2 / 2
                        ) / err_test_ref
        sob2 = norm2_sigma_grad + (lambda3 ** 2 + lambda2 ** 2
                                   ) * norm_cov ** 2
        sob2 /= sob2_ref
    elif regime == "lazy_nt":
        if reg != 0:
            raise NotImplementedError
        # Override "regime" in a fresh dict (placed last so it wins over
        # **env); the caller's env is left untouched. WWT is forwarded so
        # the sub-calls do not recompute it.
        nt_err_test, nt_sob2 = get_theoretical_values(
            {**env, "WWT": WWT, "regime": "nt"})
        _, init_sob2 = get_theoretical_values(
            {**env, "WWT": WWT, "regime": "init"})
        sob2 = init_sob2 + nt_sob2
        err_test = nt_err_test
    elif regime == "rf":
        norm2_sigma = sigma.norm2
        norm2_sigma_grad = sigma.norm2_grad
        lambda1 = sigma.lambda1
        lambda2 = sigma.lambda2
        lambda3 = sigma.lambda3
        lambda_bar = norm2_sigma - lambda1 ** 2
        assert lambda_bar >= 0.
        lambda_bar += reg
        kappa = lambda2 ** 2 * norm_cov ** 2 * d / 2
        tau = lambda2 * np.trace(B @ cov) * sqrt(d)
        lambda_prime_bar = norm2_sigma_grad - lambda1 ** 2
        kappa_prime = lambda3 ** 2 * norm_cov ** 2 * d / 2

        eye = np.eye(m)
        A0 = lambda_bar * eye + lambda1 ** 2 * WWT
        A0inv = linalg.inv(A0)
        psi1 = np.trace(A0inv) / d

        D0 = lambda_prime_bar * eye + (kappa_prime / d + lambda1 ** 2) * WWT
        psi2 = np.trace(A0inv @ A0inv @ D0) / d

        err_test = 1 - tau ** 2 * psi1 / (1 + kappa * psi1) / err_test_ref
        sob2 = tau ** 2 * (2 * kappa * psi1 ** 2 + psi2)
        sob2 /= (2 * kappa * psi1 + 2) ** 2
        sob2 /= norm_B ** 2
    elif regime == "lazy_rf":
        # Fresh dict for the sub-call; the caller's env is not mutated.
        rf_err_test, rf_sob2 = get_theoretical_values(
            {**env, "WWT": WWT, "regime": "rf"})

        eye = np.eye(m)
        U = env["U"]
        C = env["C"]
        # (U + reg*I)^{-1} @ U computed with a linear solve instead of an
        # explicit inverse.
        P_lambda = eye - linalg.solve(U + reg * eye, U)
        P2 = P_lambda @ P_lambda
        err_test = rf_err_test + np.trace(P2 @ U) / m
        sob2 = rf_sob2 + np.trace(P2 @ C) / m
    else:
        # Unknown regime: report NaN rather than raising.
        err_test = sob2 = np.nan
    return err_test, sob2
