import numpy as np
from numpy.linalg import inv as mat_inv
from scipy.linalg import sqrtm
from scipy.special import erf, log_ndtr, ndtr

def compute_gaussian_mechanism_privacy_over_query(epsilon, sigma, query1, query2):
    """
    Compute the privacy of the additive Gaussian mechanism with isotropic noise N(0, sigma^2 * I).

    With t = ||query1 - query2||_2 / sigma, this returns

        delta = Phi(t / 2 - epsilon / t) - exp(epsilon) * Phi(-t / 2 - epsilon / t),

    clipped below at 0, where Phi is the standard normal CDF.  This is algebraically the same as
    (1 - e^eps) / 2 + 0.5 * (erf(-eps/(sqrt(2) t) + t/sqrt(8)) - e^eps * erf(-eps/(sqrt(2) t) - t/sqrt(8))),
    but is evaluated in a form that does not cancel catastrophically for large epsilon.

    Args:
        epsilon: Privacy parameter epsilon (non-negative).
        sigma: Standard deviation of the Gaussian noise (positive).
        query1, query2: 1-D arrays of equal length holding the query results on two adjacent datasets.

    Returns:
        The delta (float >= 0) for this epsilon and pair of query results; 0.0 when the two query
        results coincide.
    """
    # Check ndim first so a 0-d input fails the assertion instead of raising IndexError on shape[0].
    assert query1.ndim == 1 and query2.ndim == 1, "The query functions must be vectors"
    assert query1.shape[0] == query2.shape[0], "The dimension of the two query functions must be the same"
    assert epsilon >= 0, "The privacy parameter epsilon must be non-negative"
    assert sigma > 0, "The standard deviation sigma must be positive"

    t = np.linalg.norm(query1 - query2, ord=2) / sigma
    if t == 0:
        # Identical query results give identical output distributions, so delta is 0 for every
        # epsilon; returning early also avoids dividing by t below.
        return 0.0

    ratio = epsilon / t
    # exp(epsilon) is folded into the log of the Gaussian tail so it never multiplies an
    # erf value that has already rounded to -1 (which used to erase the whole tail term).
    delta = ndtr(t / 2 - ratio) - np.exp(epsilon + log_ndtr(-t / 2 - ratio))
    return float(max(0.0, delta))

def compute_gaussians_privacy_over_query(epsilon, Sigma, query1, query2):
    """
    Compute the privacy of a pair of Gaussians N(query1, Sigma) and N(query2, Sigma) that share a
    covariance matrix but have different means.

    Args:
        epsilon: Privacy parameter epsilon (non-negative).
        Sigma: Square covariance matrix of the Gaussian noise; assumed symmetric positive definite.
        query1, query2: 1-D query results (vectors) on two adjacent datasets.

    Returns:
        The delta for this epsilon, obtained from the Mahalanobis norm
        sqrt((query1 - query2)^T Sigma^{-1} (query1 - query2)); 0.0 when the query results coincide.

    Raises:
        numpy.linalg.LinAlgError: If Sigma is singular.
    """
    assert query1.ndim == 1 and query2.ndim == 1, "The query functions must be vectors"
    assert query1.shape[0] == query2.shape[0], "The dimension of the two query functions must be the same"
    assert Sigma.ndim == 2, "The covariance matrix must be a matrix"
    assert Sigma.shape[0] == Sigma.shape[1], "The covariance matrix must be square"
    assert Sigma.shape[0] == query1.shape[0], "The dimension of the covariance matrix must be the same as the dimension of the query functions"
    assert epsilon >= 0, "The privacy parameter epsilon must be non-negative"

    mu = query1 - query2
    # mu^T Sigma^{-1} mu via a linear solve: equal to ||Sigma^{-1/2} mu||^2 for SPD Sigma, but with
    # no explicit inverse and no matrix square root (sqrtm may return complex values).
    squared_norm = float(mu @ np.linalg.solve(Sigma, mu))
    # Clip tiny negative round-off before taking the square root.
    Mahalanobis_norm_of_difference_vector = np.sqrt(max(squared_norm, 0.0))

    if Mahalanobis_norm_of_difference_vector == 0:
        # Identical output distributions: delta is 0 for every epsilon.
        return 0.0

    return additive_gaussian_privacy_function(Mahalanobis_norm_of_difference_vector, epsilon)


def compute_gaussians_privacy_of_linear_query_over_adjacent_datasets(epsilon, Sigma, linear_query, D1, D2):
    """
    Compute the privacy of a linear query over a pair of adjacent datasets when the released value
    is ``linear_query @ D`` plus Gaussian noise with covariance matrix Sigma.

    Args:
        epsilon: Privacy parameter epsilon (non-negative).
        Sigma: Square covariance matrix whose size equals the number of entries of ``linear_query @ D1``.
        linear_query: Query matrix (or vector) whose last dimension equals the number of rows of the datasets.
        D1, D2: Two adjacent datasets given as 2-D arrays of identical shape.

    Returns:
        The delta for this epsilon, computed on the column-stacked query results.
    """
    assert D1.ndim == 2 and D2.ndim == 2, "The datasets must be matrices"
    assert D1.shape[0] == D2.shape[0], "The number of rows of the two datasets must be the same"
    assert D1.shape[1] == D2.shape[1], "The number of columns of the two datasets must be the same"
    # `linear_query @ D` contracts the query's last axis with the dataset's rows.
    assert linear_query.shape[-1] == D1.shape[0], "The number of columns of the linear query must be the same as the number of rows of the datasets"
    assert Sigma.ndim == 2, "The covariance matrix must be a matrix"
    assert Sigma.shape[0] == Sigma.shape[1], "The covariance matrix must be square"

    # `.T.flatten()` stacks the columns of the query result (column-major vectorisation).
    mu1 = (linear_query @ D1).T.flatten()
    mu2 = (linear_query @ D2).T.flatten()

    assert Sigma.shape[0] == len(mu1), "The dimension of the covariance matrix must be the same as the dimension of the linear query result"

    return compute_gaussians_privacy_over_query(epsilon, Sigma, mu1, mu2)

def analytical_compute_privacy_of_additive_gaussian_mechanism_over_linear_query_with_unconstrained_neighbors(epsilon, Sigma, linear_query, database_sensitivity=1):
    """
    Compute the worst-case privacy of the additive Gaussian mechanism over a linear query.

    Args:
        epsilon: Privacy parameter epsilon (non-negative).
        Sigma: Square covariance matrix of the Gaussian noise; its size must be a multiple of the
            size of ``linear_query``.
        linear_query: Square query matrix Q.
        database_sensitivity: Sensitivity of the database, namely the maximum of ||D - D'||_F over
            all adjacent datasets D and D' (positive).

    Returns:
        The delta for this epsilon at the worst-case Mahalanobis norm
        sqrt(lambda_max(A^T Sigma^{-1} A)) * database_sensitivity, where A = kron(I_d, Q);
        0.0 when that norm is zero.
    """
    assert linear_query.ndim == 2 and linear_query.shape[0] == linear_query.shape[1], "The linear query must be a square matrix"
    assert Sigma.ndim == 2 and Sigma.shape[0] == Sigma.shape[1], "The covariance matrix must be a square matrix"
    assert database_sensitivity > 0, "The sensitivity must be positive"
    assert Sigma.shape[0] % linear_query.shape[0] == 0, "The dimension of the covariance matrix must be a multiple of the dimension of the linear query"
    assert epsilon >= 0, "The privacy parameter epsilon must be non-negative"
    d = Sigma.shape[0] // linear_query.shape[0]

    # A acts on the column-stacked dataset: kron(I_d, Q) @ vec(D) = vec(Q @ D).
    A = np.kron(np.eye(d), linear_query)
    # B = A^T Sigma^{-1} A, formed with a linear solve instead of an explicit inverse.
    B = A.T @ np.linalg.solve(Sigma, A)
    # Symmetrise away round-off before the symmetric eigen-solver.
    B = 0.5 * (B + B.T)

    # Only the largest eigenvalue is needed (eigvalsh returns them in ascending order);
    # clip tiny negative round-off so the square root stays real.
    max_eigenvalue = max(float(np.linalg.eigvalsh(B)[-1]), 0.0)

    max_Mahalanobis_norm_of_difference_vector = np.sqrt(max_eigenvalue) * database_sensitivity

    if max_Mahalanobis_norm_of_difference_vector == 0:
        # The query output does not depend on the data: delta is 0 for every epsilon.
        return 0.0

    return additive_gaussian_privacy_function(max_Mahalanobis_norm_of_difference_vector, epsilon)


def analytical_compute_optimal_noise_covariance_for_linear_query_with_unconstrained_neighbors(epsilon, delta, linear_query, num_parameters, database_sensitivity=1, tolerance=1e-8, norm_max=1e3):
    """
    Generate the noise covariance of the optimal additive Gaussian mechanism over a linear query.

    Args:
        epsilon, delta: Target privacy parameters.
        linear_query: Query matrix Q.
        num_parameters: Number of diagonal blocks in Sigma (Sigma is block diagonal with
            ``num_parameters`` copies of Q Q^T / z).
        database_sensitivity: Sensitivity of the database, namely the maximum of ||D - D'||_F over
            all adjacent datasets D and D' (positive).
        tolerance, norm_max: Bisection tolerance and upper end of the search interval used to find
            the largest admissible Mahalanobis norm.

    Returns:
        (Sigma, error), where Sigma = kron(I_num_parameters, Q Q^T) / z with
        z = (M / database_sensitivity)^2, M is the largest Mahalanobis norm allowed by
        (epsilon, delta), and error = num_parameters * ||Q||_F^2 / z, which equals trace(Sigma).

    Raises:
        ValueError: If the search returns no usable Mahalanobis norm.
    """
    # Same check as the analytical privacy routine for this query model: a zero sensitivity would
    # otherwise surface as a ZeroDivisionError, and a negative one would be silently squared away.
    assert database_sensitivity > 0, "The sensitivity must be positive"

    max_Mahalanobis_norm_of_difference_vector = compute_largest_Mahalanobis_norm_using_additive_gaussian_privacy_function(epsilon, delta, tolerance, norm_max)

    if max_Mahalanobis_norm_of_difference_vector is None or max_Mahalanobis_norm_of_difference_vector == 0:
        raise ValueError("No valid Mahalanobis norm found for the given privacy parameters (epsilon, delta).")

    # For invertible Q, scaling kron(I, Q Q^T) by 1/z makes the worst-case Mahalanobis norm over
    # neighbours sqrt(z) * database_sensitivity = M.
    z = (max_Mahalanobis_norm_of_difference_vector / database_sensitivity) ** 2

    Sigma = (1.0 / z) * np.kron(np.eye(num_parameters), linear_query @ linear_query.T)
    # trace(kron(I_p, Q Q^T)) = p * ||Q||_F^2: the total noise variance of the mechanism.
    error = num_parameters * (1.0 / z) * (np.linalg.norm(linear_query, ord='fro') ** 2)

    return Sigma, error


def additive_gaussian_privacy_function(Mahalanobis_norm_of_difference_vector, epsilon):
    """
    Compute the privacy of an additive Gaussian mechanism from the Mahalanobis norm of the
    difference between the two query results.

    For Mahalanobis norm z this evaluates

        delta = Phi(z / 2 - epsilon / z) - exp(epsilon) * Phi(-z / 2 - epsilon / z),

    clipped below at 0, where Phi is the standard normal CDF.  This is algebraically the same as
    (1 - e^eps) / 2 + 0.5 * (erf(-eps/(sqrt(2) z) + z/sqrt(8)) - e^eps * erf(-eps/(sqrt(2) z) - z/sqrt(8))),
    but is evaluated without catastrophic cancellation for large epsilon.

    Args:
        Mahalanobis_norm_of_difference_vector: Mahalanobis norm z of the mean difference (positive).
        epsilon: Privacy parameter epsilon (non-negative).

    Returns:
        The delta as a float in [0, 1].
    """
    assert Mahalanobis_norm_of_difference_vector > 0, "The Mahalanobis norm of the difference between the two query results must be positive"
    assert epsilon >= 0, "The privacy parameter epsilon must be non-negative"

    z = Mahalanobis_norm_of_difference_vector
    ratio = epsilon / z
    # exp(epsilon) is combined with the log of the Gaussian tail instead of multiplying an erf value
    # that may already have rounded to -1; in the erf form e.g. (z=10, eps=40) evaluates to 0 although
    # the true delta is about 0.81.
    delta = ndtr(z / 2 - ratio) - np.exp(epsilon + log_ndtr(-z / 2 - ratio))
    return float(max(0.0, delta))

def compute_largest_Mahalanobis_norm_using_additive_gaussian_privacy_function(epsilon, delta, tolerance=1e-16, norm_max=1e3):
    """
    Compute, by bisection, the largest Mahalanobis norm z in [tolerance, norm_max] such that
    additive_gaussian_privacy_function(z, epsilon) <= delta.

    Bisection relies on the privacy function being non-decreasing in z.

    Args:
        epsilon: Privacy parameter epsilon (non-negative).
        delta: Target delta in [0, 1].
        tolerance: Lower end of the search interval and the interval width at which bisection
            stops; must be positive.
        norm_max: Upper end of the search interval.

    Returns:
        A norm z with additive_gaussian_privacy_function(z, epsilon) <= delta, or None when the
        interval does not bracket the boundary. That happens either because even ``norm_max`` gives a
        value below delta (the boundary lies above norm_max), or because even ``tolerance`` exceeds
        delta (no norm in the interval is admissible).
    """
    assert epsilon >= 0, "The privacy parameter epsilon must be non-negative"
    assert delta >= 0 and delta <= 1, "The delta must be non-negative and less than or equal to 1"
    assert tolerance > 0, "The tolerance must be positive"

    low = tolerance  # z must be > 0
    high = norm_max

    # The boundary lies beyond the search interval: every norm up to norm_max is admissible.
    if additive_gaussian_privacy_function(high, epsilon) < delta:
        return None
    # Even the smallest norm violates delta, so no admissible norm exists in the interval
    # (previously `tolerance` itself was returned as if it were admissible).
    if additive_gaussian_privacy_function(low, epsilon) > delta:
        return None

    # Invariant: f(low) <= delta and f(high) >= delta.
    while high - low > tolerance:
        mid = (low + high) / 2
        # Once low and high are adjacent doubles, mid equals one of them and no further progress is
        # possible. Without this check the loop never ends when tolerance is below the float spacing
        # (about 1.1e-16 for norms in [0.5, 1)), as with the default tolerance=1e-16.
        if mid <= low or mid >= high:
            break
        if additive_gaussian_privacy_function(mid, epsilon) <= delta:
            low = mid  # admissible: try a larger norm
        else:
            high = mid  # too large: try a smaller norm

    return low




