import dp_accounting
from dp_accounting import rdp, dp_event
import argparse
import numpy as np

def compute_epsilon(T, q, delta, z):
    """Return a closed-form epsilon estimate for ``T`` subsampled Gaussian steps.

    The per-step value uses the classic Gaussian-mechanism calibration
    ``sqrt(2 * ln(1.25 / delta)) / z``. That value is then scaled by the
    sampling rate ``q`` and by ``sqrt(T)``. This is a heuristic formula and is
    independent of the ``dp_accounting`` RDP accountant used by the script.

    Args:
        T: number of composed steps.
        q: per-step sampling rate.
        delta: target delta.
        z: noise multiplier (presumably noise std / sensitivity; confirm with callers).

    Returns:
        ``sqrt(T) * q * sqrt(2 * ln(1.25 / delta)) / z``.
    """
    log_term = np.log(1.25 / delta)
    per_step = np.sqrt(2 * log_term) / z
    amplification = np.sqrt(T) * q
    return amplification * per_step

def get_args(argv=None):
    """Parse the command-line options for the privacy-accounting script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, matching the
            previous behavior. Passing a list lets code and tests call this
            without touching ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``noise_multiplier``,
        ``sampling_method``, ``batch_size``, ``dataset_size``, ``n_rounds``,
        ``sampling_ratio`` and ``delta``.

    Raises:
        SystemExit: (via argparse) when ``--noise_multiplier`` is missing or
            ``--sampling_method`` is not one of the supported methods.
    """
    parser = argparse.ArgumentParser(
        description='Compute epsilon for a repeatedly composed, subsampled '
                    'Gaussian mechanism using RDP accounting.')
    # Required: without it the Gaussian event would be built with None and
    # fail later with an unrelated-looking error.
    parser.add_argument('--noise_multiplier', type=float, required=True,
                        help='Gaussian noise multiplier (noise std relative '
                             'to sensitivity).')
    # Restricting choices makes a typo fail at parse time instead of leaving
    # the sampling event undefined downstream.
    parser.add_argument('--sampling_method', type=str, default='fixed_batch',
                        choices=('fixed_batch', 'poisson'),
                        help='fixed_batch or poisson')
    parser.add_argument('--batch_size', type=int, default=1000,
                        help='Sample size per round (fixed_batch only).')
    parser.add_argument('--dataset_size', type=int, default=int(1e6),
                        help='Source dataset size (fixed_batch only).')
    parser.add_argument('--n_rounds', type=int, default=1,
                        help='Number of times the sampled mechanism is composed.')
    parser.add_argument('--sampling_ratio', type=float, default=0.1,
                        help='Per-example inclusion probability (poisson only).')
    parser.add_argument('--delta', type=float, default=1e-6,
                        help='Target delta for the reported epsilon.')
    return parser.parse_args(argv)

def main():
    """Compute and print epsilon for the configured subsampled Gaussian mechanism.

    Builds a Gaussian mechanism event with the requested noise multiplier and
    wraps it in the requested sampling scheme. The wrapped event is then
    self-composed ``n_rounds`` times, and the epsilon an RDP accountant
    reports for ``delta`` is printed.

    Raises:
        ValueError: if ``sampling_method`` is neither ``'fixed_batch'`` nor
            ``'poisson'``.
    """
    args = get_args()

    # RDP orders the accountant tracks: a 0.1-spaced grid on [1.1, 10.9]
    # plus the integers 12..63. The accountant optimizes over these when
    # converting to (epsilon, delta).
    orders = [1 + x / 10. for x in range(1, 100)] + list(range(12, 64))

    mechanism_event = dp_event.GaussianDpEvent(args.noise_multiplier)
    if args.sampling_method == 'fixed_batch':
        # Fixed-size batches drawn without replacement. This path uses
        # dataset_size and batch_size; sampling_ratio is ignored. The
        # replace-one neighboring relation is the one this event is accounted
        # under (previously spelled as the magic value NeighboringRelation(2)).
        sampling_event = dp_event.SampledWithoutReplacementDpEvent(
            source_dataset_size=args.dataset_size,
            sample_size=args.batch_size,
            event=mechanism_event,
        )
        accountant = rdp.RdpAccountant(
            orders, dp_accounting.NeighboringRelation.REPLACE_ONE)
    elif args.sampling_method == 'poisson':
        # Each example is included independently with probability
        # sampling_ratio. Uses the accountant's default neighboring relation.
        sampling_event = dp_event.PoissonSampledDpEvent(
            sampling_probability=args.sampling_ratio, event=mechanism_event)
        accountant = rdp.RdpAccountant(orders)
    else:
        # Without this branch an unknown method surfaced as a NameError on
        # sampling_event / accountant below.
        raise ValueError(
            f"Unknown sampling_method {args.sampling_method!r}; "
            "expected 'fixed_batch' or 'poisson'.")

    composition_event = dp_event.SelfComposedDpEvent(
        event=sampling_event, count=args.n_rounds)
    accountant.compose(composition_event)

    epsilon = accountant.get_epsilon(args.delta)
    print(f"Epsilon: {epsilon} for Delta={args.delta}")


if __name__ == '__main__':
    main()