import mpmath as mp
import numpy as np
from scipy.optimize import root_scalar
from scipy.special import gamma, gammaln
from scipy.stats import norm

def gamma_moments(k, theta, moment = 1):
    """Return the raw moment E[X^moment] of X ~ Gamma(shape=k, scale=theta).

    Only orders 1 and 2 are supported:
        E[X]   = k * theta
        E[X^2] = k * (k + 1) * theta^2

    Raises:
        ValueError: for any other moment order.
    """
    if moment != 1 and moment != 2:
        raise ValueError(f"Moment {moment} is not implemented yet")
    if moment == 2:
        return k * (k + 1) * theta**2
    return k * theta

def truncated_gaussian_moments(m, sigma, moment = 1):
    """Raw moments of N(m, sigma^2) truncated to the positive half-line (X > 0).

    With alpha = m / sigma and the inverse Mills ratio lam = phi(alpha) / Phi(alpha):
        E[X]   = m + sigma * lam
        E[X^2] = m^2 + sigma^2 + m * sigma * lam

    Args:
        m: Location of the untruncated Gaussian (may be negative).
        sigma: Scale of the untruncated Gaussian.
        moment: 1 or 2.

    Returns:
        The requested raw moment.

    Raises:
        ValueError: if ``moment`` is not 1 or 2.
    """
    alpha = m / sigma
    # Evaluate phi/Phi in log space: for very negative alpha (about -38 and below)
    # both the pdf and cdf underflow to 0 and the direct quotient becomes 0/0 = nan.
    mills = np.exp(norm.logpdf(alpha) - norm.logcdf(alpha))

    if moment == 1:
        return m + sigma * mills
    elif moment == 2:
        return m**2 + sigma**2 + m * sigma * mills
    else:
        raise ValueError(f"Moment {moment} is not implemented yet")

def gamma_ratio_mp(a, b):
    """Return Gamma(a) / Gamma(b) as an mpmath number.

    The quotient is formed as exp(loggamma(a) - loggamma(b)), so large arguments
    do not overflow the individual Gamma values. It runs at whatever mpmath
    precision is currently set.
    """
    log_quotient = mp.loggamma(a) - mp.loggamma(b)
    return mp.exp(log_quotient)

def generalized_gamma_moments(alpha, beta, p, moment = 1, prec=50):
    """Raw moment E[R^moment] for GG(alpha, beta, p), evaluated with mpmath.

    Uses the closed form
        E[R^k] = beta^(-k/p) * Gamma((alpha + 1 + k) / p) / Gamma((alpha + 1) / p),
    which for k = 1 and k = 2 is exactly the formula previously hard-coded.
    NOTE(review): this is the k-th moment of a density proportional to
    r^alpha * exp(-beta * r^p) on r > 0. That parameterization is inferred from
    the formulas; confirm against the caller's definition of GG.

    Args:
        alpha, beta, p: Generalized-gamma parameters (converted to mpf).
        moment: Positive integer order k. Values equal to an integer, e.g. 2.0,
            are accepted.
        prec: Decimal digits of working precision. This is written to mpmath's
            global ``mp.mp.dps`` as a side effect.

    Returns:
        The moment as a Python float.

    Raises:
        ValueError: if ``moment`` is not a positive integer.
    """
    mp.mp.dps = prec

    alpha = mp.mpf(alpha)
    beta = mp.mpf(beta)
    p = mp.mpf(p)

    try:
        k = int(moment)
    except (TypeError, ValueError, OverflowError):
        k = None
    if k is None or k != moment or k < 1:
        raise ValueError(f"Moment {moment} is not implemented yet")

    # (1 + k) is added as an exact integer first, so that moments 1 and 2 evaluate
    # exactly as (alpha + 2) / p and (alpha + 3) / p did before.
    return float(beta**(-k / p) * gamma_ratio_mp((alpha + (1 + k)) / p, (alpha + 1) / p))

def chi_moments(sigma, k, moment = 1):
    """Raw moments of a scaled chi variable R = sigma * chi_k.

        E[R]   = sigma * sqrt(2) * Gamma((k + 1) / 2) / Gamma(k / 2)
        E[R^2] = sigma^2 * k

    Args:
        sigma: Scale factor.
        k: Degrees of freedom.
        moment: 1 or 2.

    Returns:
        The requested raw moment.

    Raises:
        ValueError: if ``moment`` is not 1 or 2.
    """
    if moment == 1:
        # Take the Gamma ratio through log-gamma. gamma() overflows to inf for
        # arguments above ~171.6, so from k of about 343 upwards the direct
        # quotient is inf (and then inf/inf = nan).
        return sigma * np.sqrt(2) * np.exp(gammaln((k + 1) / 2) - gammaln(k / 2))
    elif moment == 2:
        return sigma**2 * k
    else:
        raise ValueError(f"Moment {moment} is not implemented yet")


def chi_first_moment_from_second(E_R2, k):
    """
    Given E[R^2] and degrees of freedom k for chi-distribution (scaled),
    compute E[R] assuming R = sigma * chi_k.

    Since E[R^2] = sigma^2 * k, the scale is sigma = sqrt(E_R2 / k), and then
        E[R] = sigma * sqrt(2) * Gamma((k + 1) / 2) / Gamma(k / 2).
    """
    sigma = np.sqrt(E_R2 / k)
    # Take the Gamma ratio through log-gamma: gamma() overflows for arguments
    # above ~171.6, so the direct quotient turns into inf/nan once k reaches ~343.
    gamma_ratio = np.exp(gammaln((k + 1) / 2) - gammaln(k / 2))
    return sigma * np.sqrt(2) * gamma_ratio

def gamma_first_moment_from_second(E_R2, alpha):
    """Mean of a Gamma(alpha, theta) variable whose second moment is E_R2.

    With shape alpha and scale theta, E[R^2] = alpha * (alpha + 1) * theta^2.
    Solving this for theta and using E[R] = alpha * theta gives the result.
    """
    return alpha * np.sqrt(E_R2 / (alpha * (alpha + 1)))

def truncated_gaussian_first_moment_from_second(E_R2, sigma):
    """
    Given E[R^2] and standard deviation sigma for truncated Gaussian,
    compute E[R].

    First solves for the location m whose truncated second moment equals E_R2
    at the fixed sigma (via ``m_from_second_moment``). It then evaluates the
    first moment at that m with ``truncated_gaussian_moments``.

    Raises:
        RuntimeError: if ``m_from_second_moment`` reports non-convergence by
            returning None. Errors raised by that helper propagate unchanged.
    """
    m = m_from_second_moment(E_R2, sigma)
    if m is None:
        # Passing None on would fail later with an opaque TypeError on ``m / sigma``.
        raise RuntimeError(
            f"Root finding for m did not converge (E_R2={E_R2}, sigma={sigma})"
        )
    return truncated_gaussian_moments(m, sigma, moment = 1)

def m_from_second_moment(E_R2, sigma):
    """Find the location m of a zero-truncated Gaussian with a given second moment.

    Solves, for X ~ N(m, sigma^2) truncated to X > 0,
        E[X^2] = m^2 + sigma^2 + m * sigma * phi(m/sigma) / Phi(m/sigma) = E_R2.
    This second moment is increasing in m (the truncated family is
    stochastically increasing in m), so there is at most one root. It is found
    by bisection.

    Args:
        E_R2: Target second moment. Must be positive.
        sigma: Scale of the untruncated Gaussian (assumed positive).

    Returns:
        The root m, which is negative when E_R2 < sigma^2. Returns None if the
        bisection reports non-convergence.

    Raises:
        ValueError: if E_R2 <= 0, or if the root lies below -1e3 * sigma, which is
            too far into the tail for a reliable evaluation.
    """
    def equation(m):
        alpha = m / sigma
        # Inverse Mills ratio in log space, so that very negative alpha does not give 0/0.
        mills = np.exp(norm.logpdf(alpha) - norm.logcdf(alpha))
        return m**2 + m * sigma * mills + sigma**2 - E_R2

    if not E_R2 > 0:
        raise ValueError(f"E_R2 must be positive, got {E_R2}")

    # Every positive m has a second moment above m^2. So equation(sqrt(E_R2)) > 0 and
    # the root lies below sqrt(E_R2). An upper end of E_R2 itself lies below the
    # root whenever E_R2 < 1.
    hi = np.sqrt(E_R2)

    # equation(0) = sigma^2 - E_R2, so the root is non-positive when E_R2 <= sigma^2.
    # In that case, extend the lower bracket into negative m until the sign flips.
    lo = 1e-10
    if equation(lo) > 0:
        lo = -sigma
        while equation(lo) > 0:
            if lo < -1e3 * sigma:
                raise ValueError(
                    f"E_R2={E_R2} is too small relative to sigma={sigma}: "
                    "root lies below -1e3 * sigma"
                )
            lo *= 2

    result = root_scalar(equation, bracket=[lo, hi], method='bisect')
    return result.root if result.converged else None

def cos_angle_cdf(T, z, prec=50):
    """CDF at ``z`` of the cosine between two independent uniform directions in R^T.

    Evaluates the regularized incomplete beta function
    I_{(z+1)/2}((T-1)/2, (T-1)/2). This is the distribution of (c + 1) / 2 for
    that cosine c.

    Sets mpmath's global ``mp.mp.dps`` to ``prec`` as a side effect and returns
    a Python float.
    """
    mp.mp.dps = prec
    shape = (T - 1) / 2
    upper = (z + 1) / 2
    return float(mp.betainc(shape, shape, 0, upper, regularized=True))

def power_family_log_ratio_term(w, alpha, T, s, r, prec=50):
    """Evaluate ((alpha + 1 - T) / 2) * log(1 + 2*s*w/r + (s/r)^2) with mpmath.

    Algebraically, 1 + 2*s*w/r + (s/r)^2 = (r^2 + 2*r*s*w + s^2) / r^2.

    Args:
        w, alpha, T, s: Scalars entering the expression above.
        r: Must be strictly positive (it is a divisor).
        prec: Decimal digits of working precision. This is written to mpmath's
            global ``mp.mp.dps`` as a side effect.

    Returns:
        The value as a Python float.

    Raises:
        ValueError: if ``r`` is not strictly positive (including NaN).
    """
    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if not r > 0:
        raise ValueError("r must be positive")
    mp.mp.dps = prec
    # Form the log1p argument in working precision rather than float64, so that
    # cancellation between the two terms (w < 0) does not lose digits.
    ratio = mp.mpf(s) / mp.mpf(r)
    inner = 2 * ratio * w + ratio**2

    return float((alpha + 1.0 - T) / 2.0 * mp.log1p(inner))

def find_smallest_mean_of_radial(w, alpha, T, s, epsilon, r_min=1e-6, r_max=1000):
    """
    Find the r in [r_min, r_max] at which power_family_log_ratio_term equals -epsilon.

    If the term increases with r, this root is the largest r in the interval for
    which power_family_log_ratio_term(w, alpha, T, s, r) <= -epsilon.
    NOTE(review): that monotonicity holds when alpha + 1 - T < 0 and w >= 0;
    confirm for other regimes.

    Raises:
        ValueError: if the term does not cross -epsilon inside [r_min, r_max].

    Returns:
        The bisection root, or None if root_scalar reports non-convergence.
    """
    def gap(r):
        # Zero exactly where the log-ratio term equals -epsilon.
        return power_family_log_ratio_term(w, alpha, T, s, r) + epsilon

    # A crossing requires gap <= 0 at the left end and gap >= 0 at the right end.
    if gap(r_min) > 0:
        raise ValueError("No r in the interval satisfies the condition: try a smaller r_min.")
    if gap(r_max) < 0:
        raise ValueError("No r in the interval satisfies the condition: try a larger r_max.")

    solution = root_scalar(gap, bracket=[r_min, r_max], method='bisect')
    if not solution.converged:
        return None
    return solution.root

def find_alpha_lower_bound(failure_rate, epsilon, T, s, noise_budget, prec=50):
    """
    Find the smallest alpha such that the noise budget is satisfied.

    Steps, as implemented:
      1. Solve cos_angle_cdf(T, w) = 1 - failure_rate for w on (-1, 1) using
         mpmath bisection with tolerance failure_rate**2.
      2. Define f(alpha) = r*(alpha)**2 - noise_budget, where
         r*(alpha) = find_smallest_mean_of_radial(w, alpha, T, s, epsilon).
      3. If f(T - 2) > 0 return T - 1. If f(1) < 0 return 0. Otherwise bisect f
         on [1e-10, T - 2] and return the root.

    Args:
        failure_rate: Sets the target level 1 - failure_rate for the cosine CDF
            and the findroot tolerance failure_rate**2.
        epsilon, s: Forwarded to find_smallest_mean_of_radial.
        T: Forwarded to cos_angle_cdf and find_smallest_mean_of_radial. Also
            bounds the alpha search (T - 2) and the fallback return value (T - 1).
        noise_budget: Compared against r*(alpha)**2.
        prec: mpmath decimal digits, written to the global mp.mp.dps.

    Returns:
        T - 1, 0, the bisection root, or None if root_scalar reports
        non-convergence.

    Raises:
        ValueError: propagated from find_smallest_mean_of_radial when, for some
            alpha, no crossing exists in its default [r_min, r_max].
        NOTE(review): mp.findroot may also raise if its result fails
            verification; confirm.
    """
    mp.mp.dps = prec
    # Solve for w with cos_angle_cdf(T, w) = 1 - failure_rate on (-1, 1).
    # NOTE(review): cos_angle_cdf(..., prec=100) sets mp.mp.dps = 100 on every call,
    # so this findroot effectively runs at 100 digits, not at `prec`.
    start = mp.findroot(lambda w: cos_angle_cdf(T, w, prec=100) - (1-failure_rate), (-1, 1), solver='bisect', tol=failure_rate**2)
    w = float(start)

    def f(alpha):
        # Positive when the squared threshold radius exceeds the noise budget.
        # NOTE(review): if find_smallest_mean_of_radial returns None
        # (non-convergence), `None**2` raises TypeError here.
        return find_smallest_mean_of_radial(w, alpha, T, s, epsilon)**2 - noise_budget
    
    # NOTE(review): these early exits assume f decreases in alpha; confirm.
    if f(T-2) > 0:
        return T-1
    
    # NOTE(review): this guard tests alpha = 1, but the bracket below starts at 1e-10.
    # A root in (1e-10, 1) is therefore reported as 0 rather than returned; confirm intent.
    if f(1) < 0:
        return 0

    result = root_scalar(f, bracket=[1e-10, T-2], method='bisect')
    # NOTE(review): named "upper bound", but this value is returned as the alpha lower bound.
    alpha_upper_bound = result.root if result.converged else None

    return alpha_upper_bound