from OAEGN_analysis.spherical_generalized_gamma import gen_optimizer_params, compute_empirical_optimal_generalized_gamma_parameters_under_privacy_constraint, compute_spherical_generalized_gamma_privacy, generalized_gamma_compute_beta_for_given_noise_budget
from OAEGN_analysis.radial_rv import generalized_gamma_moments
from scipy.optimize import root_scalar
from OAEGN_analysis.spherical_generalized_gamma import privacy_gaussian_mechanis_via_analytic_formula
from datetime import datetime
import os
import sys
import time
from utils import setup_logging
import numpy as np

# Project root is taken as the parent of the *current working directory*
# (not of this file), so the script is expected to be launched from a
# subdirectory of the project.
project_dir = os.path.abspath(os.path.join(os.getcwd(), '../'))
src_dir, data_dir, fig_dir, log_dir = (
    os.path.join(project_dir, sub_name) for sub_name in ('src', 'data', 'fig', 'log')
)

# Create output directories up front (no-op if they already exist).
for _output_dir in (fig_dir, data_dir, log_dir):
    os.makedirs(_output_dir, exist_ok=True)
del _output_dir

# One log file per run, named by the launch time.
timestamp = f"{datetime.now():%Y%m%d_%H%M%S}"
log_file = os.path.join(log_dir, f'find_best_adv_{timestamp}.log')
logger = setup_logging(log_file)

def binary_search_ell2_error_for_delta(epsilon, T, target_delta, database_sensitivity=1.0, 
                                        error_lower=1e-6, error_upper=1e6, 
                                        tolerance=1e-10, max_iterations=50):
    """
    Binary search for the smallest ell2_error whose delta reaches target_delta.

    The mechanism searched over is the spherical generalized gamma with
    alpha = T - 1 and p = 1. For each candidate error budget, beta is obtained
    from ``generalized_gamma_compute_beta_for_given_noise_budget`` and the
    resulting delta from ``compute_spherical_generalized_gamma_privacy``.
    The search assumes delta decreases as ell2_error increases (more noise,
    more privacy).

    Parameters
    ----------
    epsilon : float
        Privacy parameter epsilon.
    T : int
        Dimensionality.
    target_delta : float
        Target delta value to achieve.
    database_sensitivity : float
        Database sensitivity (default 1.0).
    error_lower : float
        Lower bound of the search interval for ell2_error.
    error_upper : float
        Upper bound of the search interval for ell2_error.
    tolerance : float
        Absolute tolerance on |delta - target_delta| for convergence (default 1e-10).
    max_iterations : int
        Maximum number of bisection iterations.

    Returns
    -------
    dict with keys:
        - 'ell2_error': the midpoint at which |delta - target_delta| < tolerance
          if the search converged; otherwise the smallest evaluated error with
          delta <= target_delta; ``np.inf`` if no evaluated error satisfied it.
        - 'delta': the delta achieved at that error (for the ``np.inf``
          fallback: the delta of the last evaluated error, NaN if none).
        - 'beta': the beta used at that error (``None`` if nothing was evaluated).
        - 'iterations': iteration count when the result was recorded
          (for the fallback: total iterations run).
        - 'converged': whether the delta tolerance was met.

    Notes
    -----
    The search also stops early once the relative bracket width
    (upper - lower) / upper drops below 0.5%.
    """
    # Points passed to the privacy computation; presumably the query answers on
    # two neighboring databases, database_sensitivity apart — confirm.
    mu_0 = np.array([0])
    mu_1 = np.array([database_sensitivity])

    # Use alpha = T-1 and p = 1 as specified
    alpha = T - 1
    p = 1

    ell2_error_lower = error_lower
    ell2_error_upper = error_upper

    best_result = None
    # Last evaluated point, used by the fallback result when nothing feasible is
    # found; initialized so the fallback is well-defined even if no iteration runs.
    delta_achieved = np.nan
    beta = None
    iterations_run = 0

    for iteration in range(max_iterations):
        iterations_run = iteration + 1
        # Try midpoint
        ell2_error_mid = (ell2_error_lower + ell2_error_upper) / 2

        # Compute beta for this error budget, then the resulting delta
        beta = generalized_gamma_compute_beta_for_given_noise_budget(alpha, p, ell2_error_mid)
        delta_achieved = compute_spherical_generalized_gamma_privacy(
            epsilon, alpha, beta, p, T, mu_0, mu_1
        )

        logger.info(f"Iteration {iteration + 1}: ell2_error={ell2_error_mid:.10e}, ell2_error_lower={ell2_error_lower:.10e}, ell2_error_upper={ell2_error_upper:.10e}, "
              f"delta_achieved={delta_achieved:.10e}, target_delta={target_delta:.10e}, "
              f"diff={abs(delta_achieved - target_delta):.10e}")

        converged = abs(delta_achieved - target_delta) < tolerance

        # Record this midpoint when it meets the privacy target, or when it is
        # within tolerance of it (in which case delta may exceed target_delta by
        # less than `tolerance`). The converged flag always describes the
        # recorded point itself.
        if converged or delta_achieved <= target_delta:
            best_result = {
                'ell2_error': ell2_error_mid,
                'delta': delta_achieved,
                'beta': beta,
                'iterations': iteration + 1,
                'converged': converged
            }

        if converged:
            logger.info(f"\nConverged after {iteration + 1} iterations!")
            break

        # Update bounds
        # If delta_achieved > target_delta, we have too little privacy (too little noise)
        # So we need to increase the error (move right)
        if delta_achieved > target_delta:
            ell2_error_lower = ell2_error_mid
        else:
            # delta_achieved <= target_delta, we have enough privacy (possibly too much noise)
            # So we try a smaller error (move left)
            ell2_error_upper = ell2_error_mid

        # Stop once the bracket is narrow relative to its upper end; a
        # non-positive upper bound would make the relative width undefined.
        if ell2_error_upper <= 0 or (ell2_error_upper - ell2_error_lower) / ell2_error_upper < 0.005:
            logger.info(f"\nStopping: Search bounds too close at iteration {iteration + 1}")
            break

    logger.info("\nBinary search complete.")

    if best_result is None:
        # No evaluated error achieved target_delta within [error_lower, error_upper].
        best_result = {
                'ell2_error': np.inf,
                'delta': delta_achieved,
                'beta': beta,
                'iterations': iterations_run,
                'converged': False
            }

    if not best_result['converged']:
        logger.info(f"\nWarning: Did not converge within {max_iterations} iterations")
        logger.info(f"Best result achieved delta={best_result['delta']:.10e} "
              f"(diff from target: {abs(best_result['delta'] - target_delta):.10e})")

    return best_result

sens = [1]
dims = [2]
eps = [.01, .1, 1, 10]

error_metric = 'second_moment'
prec = 50

# Number of bisection steps over delta for each (sensitivity, T, epsilon) setting.
max_iterations = 20

# For each setting, bisect delta over [1e-7, 1e-1]. At each candidate delta compare:
#   - the Gaussian mechanism's error T * sigma**2, with sigma solved from the
#     analytic privacy formula,
#   - the second moment of the optimized generalized gamma mechanism,
#   - the ell2 error of the alpha = T - 1, p = 1 mechanism found by
#     binary_search_ell2_error_for_delta within [0, gaussian_error].
# Delta is increased when the generalized gamma's relative gap to the
# alpha = T - 1 mechanism is smaller than its relative gap to the Gaussian.
for database_sensitivity in sens:
    for T in dims:
        opt_params = gen_optimizer_params(T)
        for epsilon in eps:
            delta_lower = 1e-7
            delta_upper = 1e-1
            delta_mid = (delta_lower + delta_upper) / 2
            # Placeholders so the summary below is defined even if the search stops early.
            gaussian_error = np.nan
            generalized_gamma_error = np.nan
            ell2_error = np.nan
            for iteration in range(max_iterations):
                delta_mid = (delta_lower + delta_upper) / 2
                logger.info(f"Delta search iteration {iteration + 1}: delta_lower={delta_lower}, delta_upper={delta_upper}, delta_mid={delta_mid}")
                alpha_opt, p_opt, beta_opt, delta_opt = compute_empirical_optimal_generalized_gamma_parameters_under_privacy_constraint(T=T, epsilon=epsilon, delta=delta_mid, database_sensitivity=database_sensitivity, error_metric=error_metric, prec=prec, opt_params=opt_params, logger=logger)
                logger.info(f"alpha_opt: {alpha_opt}, p_opt: {p_opt}, beta_opt: {beta_opt}, delta_opt: {delta_opt}")

                result = root_scalar(lambda sigma: privacy_gaussian_mechanis_via_analytic_formula(sigma, epsilon)-delta_mid, bracket=[1e-6, 1e6], method='bisect', xtol=1e-12)
                if not result.converged:
                    # Without a Gaussian sigma there is no reference error to compare against.
                    logger.info(f"Warning: Gaussian sigma search did not converge for delta={delta_mid}; stopping delta search")
                    break
                sigma = result.root

                gaussian_error = T*(sigma**2)

                generalized_gamma_error = generalized_gamma_moments(alpha_opt, beta_opt, p_opt, moment=2)

                ell2_result = binary_search_ell2_error_for_delta(
                    epsilon=epsilon,
                    T=T,
                    target_delta=delta_mid,
                    database_sensitivity=database_sensitivity,
                    tolerance=1e-10,
                    error_lower=0.0,
                    error_upper=gaussian_error
                )
                ell2_error = ell2_result['ell2_error']

                # Relative gaps of the generalized gamma error. When no feasible
                # ell2 error was found (inf), use the limit 1.0 explicitly rather
                # than the NaN that inf/inf would produce.
                if np.isfinite(ell2_error):
                    ell2_gap = (ell2_error - generalized_gamma_error) / ell2_error
                else:
                    ell2_gap = 1.0
                gaussian_gap = (gaussian_error - generalized_gamma_error) / gaussian_error

                if ell2_gap < gaussian_gap:
                    logger.info("make delta bigger")
                    delta_lower = delta_mid
                else:
                    logger.info("make delta smaller")
                    delta_upper = delta_mid

            logger.info("====================================================================")
            logger.info(f"Database sensitivity: {database_sensitivity}, Dimension: {T}, Epsilon: {epsilon}")
            logger.info(f"best Delta: {delta_mid}")
            logger.info(f"gaussian_error: {gaussian_error}, generalized_gamma_error: {generalized_gamma_error}, ell_2 error: {ell2_error}")