import numpy as np
from dp_accounting import pld
from dp_accounting.dp_event import GaussianDpEvent
from scipy.optimize import root_scalar
from OAGN_analysis.accounting_gaussian import delta_Gaussian_mech_compositions_analytic

def compute_accuracy_of_additive_gaussian_mechanism_over_linear_query(Sigma, method="average_frobenius_norm", implementation="numpy"):
    """
    Compute the accuracy of the additive Gaussian mechanism over a linear query.

    Validates the noise covariance and dispatches to the backend selected by
    ``implementation``.

    Parameters
    ----------
    Sigma : numpy.ndarray, shape (d, d)
        Covariance matrix of the additive Gaussian noise.
    method : str
        Accuracy measure, forwarded to the backend (see
        ``compute_accuracy_of_additive_gaussian_mechanism_over_linear_query_np``).
    implementation : str
        Backend to use. Only ``"numpy"`` is supported.

    Returns
    -------
    The accuracy value computed by the selected backend.

    Raises
    ------
    ValueError
        If Sigma is not a square 2-D matrix, or ``implementation`` / ``method``
        is unknown.
    """
    # Explicit check instead of `assert`: assertions are stripped under `python -O`.
    if Sigma.ndim != 2 or Sigma.shape[0] != Sigma.shape[1]:
        raise ValueError("The covariance matrix must be a square matrix")

    if implementation == "numpy":
        return compute_accuracy_of_additive_gaussian_mechanism_over_linear_query_np(Sigma, method)
    raise ValueError("Invalid implementation")


def compute_accuracy_of_additive_gaussian_mechanism_over_linear_query_np(Sigma, method="average_frobenius_norm"):
    """
    NumPy backend: accuracy of the additive Gaussian mechanism over a linear query.

    Parameters
    ----------
    Sigma : numpy.ndarray, shape (d, d)
        Covariance matrix of the additive Gaussian noise.
    method : str
        Accuracy measure. Only ``"average_frobenius_norm"`` is supported; it
        returns ``trace(Sigma)``, which equals the expected squared Euclidean
        norm E||z||^2 of noise z ~ N(0, Sigma).

    Returns
    -------
    The trace of Sigma.

    Raises
    ------
    ValueError
        If Sigma is not a square 2-D matrix or ``method`` is unknown.
    """
    # Explicit check instead of `assert`: assertions are stripped under `python -O`.
    if Sigma.ndim != 2 or Sigma.shape[0] != Sigma.shape[1]:
        raise ValueError("The covariance matrix must be a square matrix")

    if method == "average_frobenius_norm":
        return np.trace(Sigma)
    raise ValueError("Invalid method")

def binary_tree_total_popcount(n):
    """
    Return sum(popcount(t) for t in 1..n): the total number of set bits in 1..n.

    In the binary-tree mechanism, the prefix sum up to t is assembled from
    popcount(t) dyadic tree nodes. This total is therefore the number of noisy
    nodes summed over all n prefix queries.

    The count is computed per bit position in O(log n) time instead of
    iterating over all n integers. Among 0..n, bit k (weight h = 2**k) follows
    a period of 2h: h zeros, then h ones. So it is set
    ``((n + 1) // 2h) * h + max(0, (n + 1) % 2h - h)`` times.

    Parameters
    ----------
    n : int
        Upper end of the range. Values <= 0 give 0.

    Raises
    ------
    TypeError
        If n is not an integer (same as ``range(1, n + 1)`` would raise).
    """
    import operator

    # Reject floats like range() does, and normalize numpy ints to Python ints.
    n = operator.index(n)
    total = 0
    half = 1  # weight 2**k of the current bit position
    while half <= n:
        period = half << 1
        full_cycles, remainder = divmod(n + 1, period)
        total += full_cycles * half + max(0, remainder - half)
        half = period
    return total

def tree_pld_epsilon(database_size, database_sensitivity, sigma, target_delta):
    """
    Compute the epsilon of the binary-tree mechanism at a fixed delta via the PLD accountant.

    The mechanism runs over an array of ``database_size`` entries and adds
    Gaussian noise with standard deviation ``sigma`` to every tree node. A
    single record appears in one node per tree level, so privacy is accounted
    as that many compositions of a Gaussian mechanism.

    Parameters
    ----------
    database_size : int
        Number of leaves / time steps n (must be >= 1).
    database_sensitivity : float
        Sensitivity of a single tree node to one record. ``sigma`` is divided
        by it to form the dp_accounting noise multiplier.
    sigma : float
        Standard deviation of the Gaussian noise added to each node, in the
        same units as ``database_sensitivity`` (not a multiplier).
    target_delta : float
        Delta at which epsilon is reported.

    Returns
    -------
    float
        The epsilon reported by the PLD accountant for ``target_delta``.

    Raises
    ------
    ValueError
        If ``database_size < 1``.
    """
    if database_size < 1:
        raise ValueError("database_size must be at least 1")

    # A record at position i is contained in one dyadic node per level whose
    # interval fits inside [1, n]. The worst-case record (e.g. i = 1) lies in
    # floor(log2 n) + 1 = n.bit_length() nodes. ceil(log2 n) undercounts by one
    # whenever n is a power of two (and yields 0 for n = 1), which would
    # under-report epsilon.
    num_events = int(database_size).bit_length()

    accountant = pld.PLDAccountant()
    # num_events independent Gaussian mechanisms, each with noise multiplier
    # sigma / sensitivity.
    accountant.compose(
        GaussianDpEvent(noise_multiplier=sigma / database_sensitivity),
        count=num_events,
    )
    return accountant.get_epsilon(target_delta)

def tree_pld_find_sigma_for_epsilon(database_size, database_sensitivity, epsilon_target, delta, tol=1e-5, min_sigma=1e-5, max_sigma=1e5):
    """
    Find, by bisection, the smallest per-node noise std sigma with epsilon <= epsilon_target.

    Epsilon is evaluated with ``tree_pld_epsilon`` at fixed ``delta``. The
    search assumes epsilon is non-increasing in sigma. The returned value always
    satisfies the target and lies within ``tol`` above the threshold (or near
    ``min_sigma`` if that already satisfies it).

    Parameters
    ----------
    database_size, database_sensitivity :
        Forwarded to ``tree_pld_epsilon``.
    epsilon_target : float
        Required upper bound on epsilon.
    delta : float
        Fixed delta.
    tol : float
        Absolute width of the final search interval (must be > 0).
    min_sigma, max_sigma : float
        Initial search bracket (must satisfy min_sigma < max_sigma).

    Returns
    -------
    float
        A sigma achieving epsilon <= epsilon_target.

    Raises
    ------
    ValueError
        If ``tol <= 0``, if the bracket is empty, or if even ``max_sigma`` does
        not reach ``epsilon_target``.
    """
    # tol <= 0 can make the loop spin forever once lo and hi are adjacent floats.
    if tol <= 0:
        raise ValueError("tol must be positive")
    if not min_sigma < max_sigma:
        raise ValueError("min_sigma must be smaller than max_sigma")

    lo, hi = min_sigma, max_sigma
    # Returning hi unchecked would silently hand back a sigma that violates the target.
    if tree_pld_epsilon(database_size, database_sensitivity, hi, delta) > epsilon_target:
        raise ValueError(
            f"epsilon_target={epsilon_target} is not reachable with sigma <= max_sigma={max_sigma}"
        )

    # Invariant: hi always satisfies the target; lo either fails it or is min_sigma.
    while hi - lo > tol:
        mid = (lo + hi) / 2
        if tree_pld_epsilon(database_size, database_sensitivity, mid, delta) > epsilon_target:
            lo = mid
        else:
            hi = mid
    return hi

def compute_error_of_binary_Gaussian_mechanism_over_prefix_sum_query(database_size, epsilon, delta, database_sensitivity, method="average_frobenius_norm", noise_accounting_method="tree_pld", num_parameters=1):
    """
    Expected squared error of the binary-tree Gaussian mechanism over prefix-sum queries.

    The prefix sum up to t (t = 1..database_size) is assembled from popcount(t)
    noisy tree nodes, each with variance ``node_variance``. Summing over all
    prefix queries gives ``binary_tree_total_popcount(database_size) *
    node_variance``, which is then scaled by ``num_parameters``.
    NOTE(review): despite the method name "average_frobenius_norm", this is a
    sum over all prefix queries, not an average; confirm intent.

    Parameters
    ----------
    database_size : int
        Number of entries n (prefix queries t = 1..n).
    epsilon, delta : float
        Target (epsilon, delta)-DP guarantee.
    database_sensitivity : float
        Sensitivity of a single tree node to one record.
    method : str
        Error measure; only ``"average_frobenius_norm"`` is supported.
    noise_accounting_method : str
        ``"tree_pld"``: sigma is calibrated with ``tree_pld_find_sigma_for_epsilon``.
        ``"vanilla_gaussian"``: closed-form variance
        ``2 * L * sensitivity**2 * ln(1/delta) / epsilon**2``, with L the
        number of nodes one record touches.
    num_parameters : int
        Multiplier on the error (presumably the number of independent
        coordinates released; confirm with callers).

    Raises
    ------
    ValueError
        If ``method`` or ``noise_accounting_method`` is unknown.
    """
    # Validate the cheap option first so a typo does not trigger the costly PLD search.
    if method != "average_frobenius_norm":
        raise ValueError("Invalid method")

    if noise_accounting_method == "tree_pld":
        sigma = tree_pld_find_sigma_for_epsilon(database_size, database_sensitivity, epsilon, delta)
        node_variance = sigma**2
    elif noise_accounting_method == "vanilla_gaussian":
        # A record touches one node per tree level: floor(log2 n) + 1 = n.bit_length()
        # nodes. ceil(log2 n) undercounts at powers of two and gives 0 for n = 1.
        num_levels = int(database_size).bit_length()
        # The Gaussian-mechanism variance scales with the *squared* sensitivity.
        # NOTE(review): the classical bound uses ln(1.25/delta); ln(1/delta) is kept as before.
        node_variance = 2 * num_levels * database_sensitivity**2 * np.log(1 / delta) / (epsilon**2)
    else:
        raise ValueError("Invalid noise accounting method")

    popcount = binary_tree_total_popcount(database_size)
    return popcount * node_variance * num_parameters


def noise_budget_for_k_gaussian(epsilon, delta, k, dim, sensitivity=1.0):
    """
    Calibrate Gaussian noise for k composed releases and return one release's total variance.

    Finds the sigma in [1e-4, 1e4] at which
    ``delta_Gaussian_mech_compositions_analytic(epsilon, sigma, k, s=sensitivity)``
    equals ``delta``, using scipy bisection with ``xtol=1e-12``. scipy raises
    if the two bracket ends do not give opposite signs. The returned value is
    ``dim * sigma**2``.
    Assumed (confirm against the helper): the helper returns the delta of a
    k-fold composition of Gaussian mechanisms with noise std sigma, and
    ``dim`` is the number of independently noised coordinates.

    Returns
    -------
    The per-release budget; multiply by k for the sum over all k releases.
    """
    def delta_gap(noise_std):
        # Root of this function is the sigma at which the composed delta hits the target.
        return delta_Gaussian_mech_compositions_analytic(epsilon, noise_std, k, s=sensitivity) - delta

    result = root_scalar(delta_gap, bracket=[1e-4, 1e4], method="bisect", xtol=1e-12)
    noise_std = result.root
    return dim * noise_std**2



