from scipy.optimize import root_scalar
from OAGN_analysis.privacy_analysis import privacy_gaussian_mechanis_via_analytic_formula
from OAEGN_analysis.privacy_analysis import compute_spherical_generalized_gamma_privacy
from OAEGN_analysis.radial_rv import chi_moments, generalized_gamma_moments
from skopt import Optimizer
from skopt.space import Real
import numpy as np
from concurrent.futures import ProcessPoolExecutor
from functools import partial
from utils import red, green, yellow, blue, magenta, cyan, bold
import mpmath as mp

def gen_optimizer_params(T=None):
    """Build a fresh dictionary of default optimiser settings.

    Args:
        T: Optional dimension. When given, ``alpha_upper`` is set to ``T - 1``;
            otherwise ``alpha_upper`` is left as ``None`` so that it can be filled
            in later.

    Returns:
        A new dict with the search-space bounds for ``p`` and ``alpha``, the
        Bayesian-optimisation settings under ``"BO_params"`` and the worker count
        under ``"n_workers"``. A new dict (including a new nested ``BO_params``)
        is created on every call, so callers may modify it freely.
    """
    bayes_opt_settings = dict(
        base_estimator="gp",
        acq_func="EI",
        acq_optimizer="auto",
        batch_size=10,
        n_iter=60,
    )
    return {
        "p_lower": 1.0,
        "p_upper": 3.0,
        "alpha_lower": 0.0,
        "alpha_upper": None if T is None else T - 1,
        "BO_params": bayes_opt_settings,
        "n_workers": 96,
    }

def generalized_gamma_compute_beta_for_given_noise_budget(alpha, p, noise_budget, prec=50, atol=1e-12):
    """Solve for ``beta`` such that the generalized-gamma second moment equals ``noise_budget``.

    Bisects ``generalized_gamma_moments(alpha, beta, p, moment=2, prec=prec) - noise_budget``
    over ``beta`` in ``[1e-10, 1e4]``.

    Args:
        alpha, p: Generalized-gamma shape parameters (held fixed).
        noise_budget: Target value of the second moment.
        prec: Precision forwarded to ``generalized_gamma_moments``.
        atol: Absolute ``xtol`` for the bisection.

    Returns:
        The root ``beta`` if the solver reports convergence, otherwise ``None``.

    Raises:
        ValueError: from ``root_scalar`` when the residual has the same sign at
            both bracket endpoints (no root inside ``[1e-10, 1e4]``).
    """
    def _second_moment_residual(beta):
        return generalized_gamma_moments(alpha, beta, p, moment=2, prec=prec) - noise_budget

    solution = root_scalar(_second_moment_residual, bracket=[1e-10, 1e4], method='bisect', xtol=atol)
    if not solution.converged:
        return None
    return solution.root


def gneeralized_gamma_compute_delta_for_given_noise_budget(alpha, p, noise_budget, T, epsilon, mu_0, mu_1, method='exact', prec=50, atol=1e-12):
    """Return the privacy ``delta`` of the spherical generalized-gamma mechanism at a fixed noise budget.

    The misspelled public name ("gneeralized") is kept so that existing callers
    keep working.

    First solves for the ``beta`` whose second moment matches ``noise_budget``
    (see ``generalized_gamma_compute_beta_for_given_noise_budget``). It then
    evaluates ``compute_spherical_generalized_gamma_privacy`` at
    ``(epsilon, alpha, beta, p, T, mu_0, mu_1)``.

    Args:
        alpha, p: Generalized-gamma shape parameters.
        noise_budget: Target second moment used to fix ``beta``.
        T, epsilon, mu_0, mu_1: Forwarded unchanged to the privacy computation.
        method: Privacy computation method, forwarded unchanged.
        prec: Precision used by both the beta solve and the privacy computation.
        atol: Absolute tolerance of the beta bisection.

    Returns:
        Whatever ``compute_spherical_generalized_gamma_privacy`` returns (the delta).
        NOTE(review): if the beta solve does not converge, ``None`` is passed on as beta.
    """
    rate = generalized_gamma_compute_beta_for_given_noise_budget(
        alpha, p, noise_budget, prec=prec, atol=atol
    )
    return compute_spherical_generalized_gamma_privacy(
        epsilon, alpha, rate, p, T, mu_0, mu_1, prec=prec, method=method
    )


def _generalized_gamma_optimization_wrapper(x, noise_budget, T, epsilon, mu_0, mu_1, method, prec, tol):
    """Evaluate the objective at one optimiser point ``x = [alpha, p]``.

    Defined at module level (not as a lambda or closure) so that it can be
    pickled and sent to ``ProcessPoolExecutor`` workers via ``functools.partial``.
    ``tol`` is forwarded as the ``atol`` of the beta solve.
    """
    alpha, p = x[0], x[1]
    return gneeralized_gamma_compute_delta_for_given_noise_budget(
        alpha, p, noise_budget, T, epsilon, mu_0, mu_1,
        method=method, prec=prec, atol=tol,
    )

def _compute_empirical_optimal_generalized_gamma_parameters_under_error_constraint(
    T, epsilon, noise_budget, database_sensitivity, error_metric='second_moment', method='exact', prec=50, tol=1e-12, logger=None,
    opt_params=None, delta_goal=None
    ):
    """Minimise delta over ``(alpha, p)`` at a fixed noise budget using batched Bayesian optimisation.

    Each candidate ``(alpha, p)`` is evaluated by
    ``gneeralized_gamma_compute_delta_for_given_noise_budget``: beta is solved so
    that the second moment equals ``noise_budget``, and the privacy delta is then
    computed for the two means ``mu_0 = [0]`` and ``mu_1 = [database_sensitivity]``.

    Args:
        T: Dimension parameter (must be > 1). Also sets the default ``alpha`` upper bound ``T - 1``.
        epsilon: Privacy epsilon (> 0).
        noise_budget: Fixed second-moment budget (> 0).
        database_sensitivity: Offset between the two means (> 0).
        error_metric: Only ``'second_moment'`` is supported.
        method: ``'exact'`` or ``'approximate'``, forwarded to the privacy computation.
        prec: Numerical precision (coerced to ``int``).
        tol: Tolerance for the beta solve and for the early-stopping comparison.
        logger: Optional logger; every evaluated point is logged.
        opt_params: Settings dict shaped like ``gen_optimizer_params()``. ``None``
            uses those defaults. The dict is copied, so the caller's object is
            never modified.
        delta_goal: If given, stop after the first batch whose best delta is
            ``<= delta_goal - tol``.

    Returns:
        ``(alpha_opt, p_opt, delta_opt)``: the point with the lowest delta
        observed across all evaluated batches.
    """
    assert error_metric in ['second_moment'], "invalid error metric"
    assert T > 1, "T must be greater than 1"
    assert epsilon > 0, "epsilon must be greater than 0"
    assert noise_budget > 0, "noise_budget must be greater than 0"
    assert database_sensitivity > 0, "database_sensitivity must be greater than 0"
    assert method in ['exact', 'approximate'], "invalid method"
    assert prec > 0, "prec must be greater than 0"
    prec = int(prec)

    # Work on a shallow copy so filling in alpha_upper never leaks into the
    # caller's dict; fall back to the module defaults when nothing is supplied.
    opt_params = dict(gen_optimizer_params(T) if opt_params is None else opt_params)
    if opt_params["alpha_upper"] is None:
        opt_params["alpha_upper"] = T - 1

    mu_0 = np.array([0])
    mu_1 = np.array([database_sensitivity])

    space = [
        Real(opt_params["alpha_lower"], opt_params["alpha_upper"], name="alpha"),
        Real(opt_params["p_lower"], opt_params["p_upper"], name="p")
    ]
    bo_params = opt_params["BO_params"]
    batch_size = bo_params["batch_size"]
    n_iter = bo_params["n_iter"]
    n_workers = opt_params["n_workers"]
    # Run at least one batch: with zero batches opt.yi stays empty and the
    # argmin below would raise.
    n_batches = max(1, n_iter // batch_size)

    opt = Optimizer(
        dimensions=space,
        base_estimator=bo_params["base_estimator"],
        acq_func=bo_params["acq_func"],
        acq_optimizer=bo_params["acq_optimizer"],
    )

    # Use functools.partial to create a picklable function
    f = partial(_generalized_gamma_optimization_wrapper,
                noise_budget=noise_budget, T=T, epsilon=epsilon,
                mu_0=mu_0, mu_1=mu_1, method=method, prec=prec, tol=tol)

    # One worker pool for all batches; spawning fresh processes per batch is pure overhead.
    with ProcessPoolExecutor(max_workers=min(n_workers, batch_size)) as executor:
        for i in range(n_batches):
            # 1) ask for a batch of points, e.g. [[alpha1, p1], [alpha2, p2], ...]
            batch = opt.ask(n_points=batch_size)
            ys = list(executor.map(f, batch))

            # 2) feed back all at once
            opt.tell(batch, ys)

            if logger is not None:
                for (alpha, p), y in zip(batch, ys):
                    logger.info(f"Noise budget {noise_budget} iteration {i}: alpha={alpha:.4f}, p={p:.4f}, delta={y:.16f}")

            if delta_goal is not None and np.min(opt.yi) <= delta_goal - tol:
                if logger is not None:
                    logger.info(f"Early stopping at iteration {i} as delta_goal {delta_goal} is reached with delta {np.min(opt.yi)}")
                break

    idx_best = np.argmin(opt.yi)         # index of lowest observed delta
    alpha_opt, p_opt = opt.Xi[idx_best]  # the (alpha, p) achieving it
    delta_opt = opt.yi[idx_best]

    return alpha_opt, p_opt, delta_opt

def compute_empirical_optimal_generalized_gamma_parameters_under_privacy_constraint(
    T, epsilon, delta, database_sensitivity, error_metric='second_moment', method='exact', prec=50, tol=1e-12, logger=None,
    opt_params=None, max_iters=50, noise_tol=0.05
    ):
    """Find the smallest noise budget at which an optimised generalized-gamma mechanism meets ``delta``.

    Procedure:
      1. Calibrate ``sigma`` by bisection so that
         ``privacy_gaussian_mechanis_via_analytic_formula(sigma, epsilon, database_sensitivity) == delta``.
         Then take ``c_upper = chi_moments(sigma, T, moment=2)`` as the starting
         noise budget (presumably the Gaussian mechanism's second moment; confirm
         against ``chi_moments``).
      2. Optimise ``(alpha, p)`` at ``c_upper`` with
         ``_compute_empirical_optimal_generalized_gamma_parameters_under_error_constraint``.
      3. Bisect the budget on ``[0, c_upper]``. A midpoint whose optimised delta is
         ``<= delta`` becomes the new upper bound and the current best result;
         otherwise it becomes the new lower bound. The search stops when the delta
         is within ``tol`` of the target, when the relative bracket width drops
         below ``noise_tol``, or after ``max_iters`` steps.

    Args:
        T, epsilon, delta, database_sensitivity: Problem parameters (all validated).
        error_metric: Only ``'second_moment'`` is supported.
        method, prec, tol: Forwarded to the optimisation and the beta solve.
        logger: Optional logger for progress messages.
        opt_params: Settings dict shaped like ``gen_optimizer_params()``;
            ``None`` uses ``gen_optimizer_params(T)``.
        max_iters: Maximum number of bisection steps.
        noise_tol: Relative bracket width ``(c_upper - c_lower) / c_upper`` at which to stop.

    Returns:
        ``(alpha_opt, p_opt, beta_opt, delta_opt)`` for the smallest feasible
        budget found. If no bisection step is feasible, this is the result at the
        initial ``c_upper``, whose ``delta_opt`` may exceed ``delta``.

    Raises:
        ValueError: if ``sigma`` cannot be calibrated. This includes the error
            ``root_scalar`` raises when its bracket shows no sign change.
    """
    assert error_metric in ['second_moment'], "invalid error metric"
    assert T > 1, "T must be greater than 1"
    assert epsilon > 0, "epsilon must be greater than 0"
    assert delta > 0, "delta must be greater than 0"
    assert database_sensitivity > 0, "database_sensitivity must be greater than 0"
    assert method in ['exact', 'approximate'], "invalid method"
    assert prec > 0, "prec must be greater than 0"
    prec = int(prec)
    if opt_params is None:
        opt_params = gen_optimizer_params(T)

    result = root_scalar(lambda sigma: privacy_gaussian_mechanis_via_analytic_formula(sigma, epsilon, database_sensitivity)-delta, bracket=[1e-6, 1e6], method='bisect', xtol=tol)
    if not result.converged:
        raise ValueError("Failed to find a solution for sigma")
    sigma = result.root

    # error_metric is 'second_moment' (the only value accepted above).
    c_upper = chi_moments(sigma, T, moment=2)

    if logger is not None:
        logger.info(f"Computed noise budget upper bound: {c_upper} for epsilon={epsilon}, delta={delta}, T={T}, database_sensitivity={database_sensitivity}, error_metric={error_metric}")

    c_lower = 0.0
    c_opt = c_upper

    # Arguments shared by every optimisation run; only the noise budget varies.
    search_kwargs = dict(
        T=T, epsilon=epsilon, database_sensitivity=database_sensitivity,
        error_metric=error_metric, method=method, prec=prec, tol=tol,
        logger=logger, opt_params=opt_params, delta_goal=delta,
    )

    alpha_opt, p_opt, delta_opt = _compute_empirical_optimal_generalized_gamma_parameters_under_error_constraint(
        noise_budget=c_upper, **search_kwargs
    )
    # Recover beta with the same precision/tolerance the optimiser used when it
    # evaluated delta, so the returned beta matches the reported delta.
    beta_opt = generalized_gamma_compute_beta_for_given_noise_budget(alpha_opt, p_opt, noise_budget=c_upper, prec=prec, atol=tol)
    if logger is not None:
        initial_text = f"Initial result: alpha={alpha_opt:.6f}, p={p_opt:.6f}, delta={delta_opt:.6e}"
        logger.info(f"{red(initial_text)}")

    # Binary search to find optimal noise budget
    for iteration in range(max_iters):
        c_mid = (c_lower + c_upper) / 2

        if logger is not None:
            iteration_text = f"Binary search iteration {iteration + 1}/{max_iters}:"
            c_lower_text = f"{c_lower:.6f}"
            c_upper_text = f"{c_upper:.6f}"
            c_mid_text = f"{c_mid:.6f}"
            logger.info(f"{cyan(iteration_text)} c_lower={yellow(c_lower_text)}, c_upper={yellow(c_upper_text)}, c_mid={bold(c_mid_text)}")

        # Compute optimal parameters for current noise budget
        tmp_alpha_opt, tmp_p_opt, tmp_delta_opt = _compute_empirical_optimal_generalized_gamma_parameters_under_error_constraint(
            noise_budget=c_mid, **search_kwargs
        )

        if logger is not None:
            iteration_text = f"Iteration {iteration + 1}/{max_iters}'s delta_opt:"
            delta_opt_text = f"{tmp_delta_opt:.6e}"
            target_delta_text = f"{delta:.6e}"
            difference_text = f"{abs(tmp_delta_opt - delta):.6e}"
            error_text = f"Error: {c_mid:.6f}"
            p_text = f"p: {tmp_p_opt:.6f}"
            logger.info(f"{blue(iteration_text)} {green(delta_opt_text)}, {magenta('target delta:')} {green(target_delta_text)}, {red('difference:')} {bold(difference_text)}, {green(error_text)}, {yellow(p_text)}")

        # Too little noise -> raise the lower bound; otherwise accept and shrink from above.
        if tmp_delta_opt > delta:
            c_lower = c_mid
        else:
            alpha_opt = tmp_alpha_opt
            p_opt = tmp_p_opt
            delta_opt = tmp_delta_opt
            beta_opt = generalized_gamma_compute_beta_for_given_noise_budget(alpha_opt, p_opt, noise_budget=c_mid, prec=prec, atol=tol)
            c_upper = c_mid
            c_opt = c_mid

        # Check stopping conditions
        if abs(tmp_delta_opt - delta) < tol or (c_upper - c_lower)/c_upper < noise_tol:
            break
    else:
        # Loop completed without breaking (hit max_iters). The else branch also
        # covers max_iters == 0 and never fires when the final iteration converged.
        if logger is not None:
            convergence_text = f"Binary search completed after {max_iters} iterations without convergence"
            logger.info(f"{blue(convergence_text)}")

    if logger is not None:
        final_text = f"Final result: error={c_opt:.6f}, alpha={alpha_opt:.6f}, p={p_opt:.6f}, beta={beta_opt:.6f}, delta={delta_opt:.6e}"
        logger.info(f"{red(final_text)}")

    return alpha_opt, p_opt, beta_opt, delta_opt


def generalized_gamma_pdf_mp(r, alpha, beta, p):
    """Evaluate the generalized-gamma density at ``r`` in mpmath arithmetic.

    Returns ``p * beta**((alpha+1)/p) / Gamma((alpha+1)/p) * r**alpha * exp(-beta * r**p)``
    for ``r >= 0``, and ``0`` for ``r < 0``.

    All inputs are converted with ``mp.mpmathify`` before any power is taken.
    The result is therefore computed at the current mpmath working precision.
    Large normalisers such as ``beta ** ((alpha + 1) / p)`` (e.g. beta=1e4 with a
    large alpha) do not raise the ``OverflowError`` that Python float
    exponentiation would.
    """
    if r < 0:
        return 0
    r, alpha, beta, p = (mp.mpmathify(v) for v in (r, alpha, beta, p))
    shape = (alpha + 1) / p
    norm = p * beta ** shape / mp.gamma(shape)
    return norm * r ** alpha * mp.exp(-beta * r ** p)

def cos_angle_pdf_mp(w, T):
    """Evaluate ``Gamma(T/2) / (sqrt(pi) * Gamma((T-1)/2)) * (1 - w**2) ** ((T-3)/2)``.

    The gamma functions and ``sqrt(pi)`` use mpmath. The result is returned for
    ``-1 <= w <= 1``; ``0`` is returned otherwise, including when the range test
    itself is false (e.g. ``w`` is NaN).

    Presumably this is the density of the cosine between a uniformly random
    direction in ``T`` dimensions and a fixed axis; confirm against the callers.
    """
    if -1 <= w <= 1:
        coefficient = mp.gamma(T / 2) / (mp.sqrt(mp.pi) * mp.gamma((T - 1) / 2))
        return coefficient * (1 - w ** 2) ** ((T - 3) / 2)
    return 0