
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import os
import numpy as np
from prompt_graph.utils.rdp_accountant import compute_rdp, get_privacy_spent
import ipdb
from opacus.accountants import create_accountant
from typing import Optional

def get_noise_multiplier(
    *,
    target_epsilon: float,
    target_delta: float,
    sample_rate: float,
    epochs: Optional[int] = None,
    steps: Optional[int] = None,
    accountant: str = "rdp",
    epsilon_tolerance: float = 0.01,
    **kwargs,
) -> float:
    r"""
    Find a noise multiplier sigma whose total privacy cost stays within
    (target_epsilon, target_delta) after the requested amount of training.

    The search first doubles an upper bound (starting from 20) until the
    accountant reports an epsilon at or below ``target_epsilon``. It then
    bisects between 0 and that bound until the epsilon at the upper bound is
    within ``epsilon_tolerance`` of the target.

    Args:
        target_epsilon: the privacy budget's epsilon
        target_delta: the privacy budget's delta
        sample_rate: the sampling rate (usually batch_size / n_data)
        epochs: the number of epochs to run (mutually exclusive with ``steps``)
        steps: number of steps to run (mutually exclusive with ``epochs``)
        accountant: accounting mechanism name passed to ``create_accountant``
        epsilon_tolerance: precision for the binary search
        **kwargs: forwarded unchanged to the accountant's ``get_epsilon``
    Returns:
        The smallest bracketed sigma found that keeps epsilon <= target_epsilon.
    Raises:
        ValueError: if both or neither of ``steps``/``epochs`` are given, or if
            the upper bound exceeds 1e6 before the budget is met.
    """
    MAX_SIGMA = 1e6

    steps_given = steps is not None
    epochs_given = epochs is not None
    if steps_given == epochs_given:
        raise ValueError(
            "get_noise_multiplier takes as input EITHER a number of steps or a number of epochs"
        )
    if not steps_given:
        steps = int(epochs / sample_rate)

    acct = create_accountant(mechanism=accountant)

    def _epsilon_at(noise):
        # The accountant is reused; its history is overwritten to describe a
        # single run of `steps` steps at this noise level.
        acct.history = [(noise, sample_rate, steps)]
        return acct.get_epsilon(delta=target_delta, **kwargs)

    lower, upper = 0, 10
    eps_upper = float("inf")

    # Phase 1: grow the upper bound until it satisfies the budget.
    while eps_upper > target_epsilon:
        upper *= 2
        eps_upper = _epsilon_at(upper)
        if upper > MAX_SIGMA:
            raise ValueError("The privacy budget is too low.")

    # Phase 2: bisect, keeping `upper` as the best sigma that meets the budget.
    while target_epsilon - eps_upper > epsilon_tolerance:
        mid = (lower + upper) / 2
        eps_mid = _epsilon_at(mid)
        if eps_mid < target_epsilon:
            upper, eps_upper = mid, eps_mid
        else:
            lower = mid

    return upper

def clip_and_accumulate(grad_samples, clipping, device):
    """Clip each per-sample gradient to an L2 norm of at most ``clipping`` and sum them.

    Args:
        grad_samples: non-empty iterable of per-sample gradient tensors that all
            share the shape of ``grad_samples[0]`` (a tensor iterated along its
            first dimension also works).
        clipping: maximum L2 norm allowed for each individual gradient.
        device: device on which the accumulator tensor is allocated.

    Returns:
        A tensor shaped like ``grad_samples[0]`` (default floating dtype) holding
        the sum of the clipped gradients.

    Raises:
        IndexError: if ``grad_samples`` is empty (``grad_samples[0]`` is read).
    """
    # Allocate directly on the target device rather than building on CPU and copying.
    summed_grad = torch.zeros(grad_samples[0].shape, device=device)
    for grad_sample in grad_samples:
        # L2 norm over *all* elements of this sample's gradient. A norm along
        # dim=-1 only gives a scalar for 1-D gradients; for higher-rank ones the
        # `if` below would raise on an ambiguous multi-element boolean. For 1-D
        # gradients both forms are identical.
        per_sample_norm = grad_sample.norm(2)
        if per_sample_norm > clipping:
            # Rescale onto the clipping ball; the 1e-6 keeps the factor finite and
            # the resulting norm marginally below `clipping`.
            clipped_grad_sample = grad_sample * (clipping / (per_sample_norm + 1e-6))
        else:
            clipped_grad_sample = grad_sample
        summed_grad += clipped_grad_sample
    return summed_grad
        
def add_noise(summed_grad, noise_multiplier, clipping, device, num_samples):
    """Return ``summed_grad`` plus zero-mean Gaussian noise.

    The noise standard deviation is ``noise_multiplier * clipping``. The noise
    tensor has the same shape as ``summed_grad`` and is sampled on ``device``
    from torch's global RNG.

    Args:
        summed_grad: accumulated gradient tensor to perturb.
        noise_multiplier: noise scale relative to the clipping bound.
        clipping: the per-sample clipping bound.
        device: device on which the noise is sampled.
        num_samples: accepted for interface compatibility but not used.

    Returns:
        A new tensor ``summed_grad + noise``; no averaging is applied.
    """
    # NOTE(review): `num_samples` is unused and the result is not divided by it;
    # confirm callers perform any averaging they need.
    noise_std = noise_multiplier * clipping
    gaussian_noise = torch.normal(
        mean=0,
        std=noise_std,
        size=summed_grad.shape,
        device=device,
    )
    return summed_grad + gaussian_noise

