import copy
import numpy as np
import os
import matplotlib.pyplot as plt

from sgd.tf_simulation import simulate_linear_query, simulate_clipped_gradient_matrix
from OAGN_analysis.ndis_estimator_mp import privacy_estimator_for_Gaussians_mp
from OAGN_analysis.additive_gaussian_privacy import compute_gaussians_privacy_of_linear_query_over_adjacent_datasets

def figure_generation_estimation_vs_computation_by_dim(dim_list=(1, 5, 15, 20, 25), epsilon=1, num_samples=5000000, workers=10, fig_dir=None, file_name="estimation_vs_computation_by_dim.png"):
    """Plot the gap between estimated and computed privacy as dimension grows.

    For every dimension in ``dim_list`` a single-interaction linear query and
    a clipped gradient matrix are simulated, an adjacent dataset is formed by
    zeroing the first gradient row, and a random covariance matrix is drawn.
    The privacy value is then obtained twice: once from
    ``privacy_estimator_for_Gaussians_mp`` and once from
    ``compute_gaussians_privacy_of_linear_query_over_adjacent_datasets``.
    The absolute difference between the two is plotted against the dimension
    and saved as an image.

    NOTE(review): the exact quantity returned by the two helpers (presumably
    a delta for the given ``epsilon``) is not visible here -- confirm against
    their definitions.

    The covariance is drawn from NumPy's global random state, so results are
    only reproducible if the caller seeds ``np.random`` beforehand.

    Args:
        dim_list: Iterable of dimensions (number of parameters) to evaluate.
            It is materialized once, so one-shot iterables are accepted.
        epsilon: Privacy parameter forwarded to both helpers.
        num_samples: Sample count forwarded to the estimator.
        workers: Worker count forwarded to the estimator.
        fig_dir: Directory in which the figure is saved. Defaults to the
            current working directory; a non-existent directory is created.
        file_name: File name of the saved figure inside ``fig_dir``.

    Returns:
        A tuple ``(estimated_privacy_list, theoretical_privacy_list)`` holding
        one entry per dimension, in the order of ``dim_list``.
    """
    # Materialize once: the dimensions are iterated below and reused as the
    # x-axis of the plot, which would exhaust a generator.
    dims = list(dim_list)

    # Resolve (and, if needed, create) the output directory up front so a bad
    # path fails before the expensive estimation loop rather than after it.
    if fig_dir is None:
        fig_dir = os.getcwd()
    elif fig_dir:
        # An empty string is left alone: os.path.join("", name) already
        # yields a path relative to the current directory.
        os.makedirs(fig_dir, exist_ok=True)
    fig_path = os.path.join(fig_dir, file_name)

    estimated_privacy_list = []
    theoretical_privacy_list = []
    # The experiment always uses a single interaction; only the number of
    # parameters varies with the dimension.
    num_interactions = 1
    for dim in dims:
        num_parameters = dim

        M = simulate_linear_query(num_interactions)
        G = simulate_clipped_gradient_matrix(num_interactions, num_parameters, clip_value=1)

        # Adjacent dataset: identical gradients except that the first row is
        # zeroed out.
        G_prime = copy.deepcopy(G)
        G_prime[0, :] = 0

        mu1 = (M @ G).T.flatten()
        mu2 = (M @ G_prime).T.flatten()

        # A @ A.T is positive semi-definite; the small multiple of the
        # identity makes the covariance strictly positive definite.
        A = np.random.randn(dim, dim)
        Sigma = A @ A.T + 0.01 * np.eye(dim)

        # Both Gaussians share the same covariance and differ only in mean.
        estimated_privacy = privacy_estimator_for_Gaussians_mp(mu1, Sigma, mu2, Sigma, epsilon, num_samples=num_samples, workers=workers, precision=128)

        theoretical_privacy = compute_gaussians_privacy_of_linear_query_over_adjacent_datasets(epsilon, Sigma, M, G, G_prime)

        estimated_privacy_list.append(estimated_privacy)
        theoretical_privacy_list.append(theoretical_privacy)

    abs_error = np.abs(np.array(estimated_privacy_list) - np.array(theoretical_privacy_list))

    fig = plt.figure(figsize=(10, 6))
    try:
        plt.title("Multi-dimensional Additive Gaussian Mechanism Estimation Error vs Dimension")
        plt.xlabel("Dimension")
        plt.ylabel("L1 Error")
        plt.plot(dims, abs_error, label="Estimation Error")
        plt.legend()
        plt.savefig(fig_path)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

    return estimated_privacy_list, theoretical_privacy_list