import numpy as np
import math

from helper_functions import get_dyadic_cover_levels

def get_top_level(t, B):
    ''' Computes the level of the greatest interval that 't' is included in.

    Parameters:
    t: Time step; must be strictly greater than 'B'.
    B: Delay in releasing outputs.

    Returns:
    floor(log2(t - B)).
    '''
    assert t > B
    steps_since_delay = t - B
    return math.floor(math.log2(steps_since_delay))

def get_variance_at_t(lamda, epsilon, B, t):
    ''' Computes the variance at time 't'.

    Recall that the nodes in the tree at level 'l' (l=0 being a leaf),
    have noise z~Lap(1 / (epsilon * (l+1)^(lambda - 1)) ).

    Parameters:
    lamda: Parameter for how the noise scales across levels in the tree
    epsilon: Sets the noise at the leaves.
    B: Delay in releasing outputs.
    t: Time step at which we want to compute the variance.

    Returns:
    The variance of the output at time 't' (0 for t <= B).
    '''

    # Outputs before the delay have 0 variance
    if t <= B:
        return 0

    # Compute the greatest interval level 't' is included in
    top_lvl = get_top_level(t, B)

    # Level indices l+1 for l = 0..top_lvl. They are kept as floats because
    # NumPy refuses to raise an integer array to a negative integer power,
    # which is what happens when 'lamda' is passed as an int smaller than 1.
    level_numbers = np.arange(1, top_lvl + 2, dtype=float)

    # Epsilon we use for node at each level in the tree
    epsilons_per_level = epsilon * np.power(level_numbers, lamda - 1)

    # Variance per node is 2*(1/epsilon)^2; sum one node per level
    return np.sum(2 * np.power(epsilons_per_level, -2.0))

def mean_squared_error_and_lambda_to_epsilon(target_mse, lamda, T, B):
    ''' Solve for the 'epsilon' that makes the total squared error over
    the first 'T' prefix-sum outputs equal to 'target_mse * T' for a
    given 'lambda'.

    Parameters:
    target_mse: The mean-squared error we allow for the first 'T' outputs
    lamda: The parameter 'lambda' in our algorithm.
    T: Number of outputs.
    B: Delay in releasing outputs.

    Returns:
    The privacy parameter 'epsilon' we use in the paper for our algorithm

    Raises (via assert):
    AssertionError if B > 0 and the bias term alone already reaches the
    target squared error.
    '''

    # Total squared-error budget across all 'T' outputs
    error_budget = target_mse * T

    # NOTE: For B != 0, the error is input-dependent
    bias_error = 0
    if B > 0:
        print("WARNING: Normalizing by mean-squared error when B>0. Assuming worst-case input of all 1s.")

        # During the first B steps the offset grows as 1, 2, ..., B
        warm_up_error = sum(step * step for step in range(1, B + 1))

        # Every later step is off by 'B' on expectation
        steady_error = (T - B) * (B ** 2)

        bias_error = warm_up_error + steady_error

        # The bias alone must leave room for the noise in the budget
        assert bias_error < error_budget, f"B={B} introduces an error exceeding the target MSE"

    # Noise contribution with epsilon fixed to 1; it scales as 1/epsilon^2
    unit_noise_error = 0
    for t in range(B + 1, T + 1):
        unit_noise_error += get_variance_at_t(lamda=lamda, epsilon=1, B=B, t=t)

    # Solve (1/epsilon^2) * unit_noise_error + bias_error = error_budget
    remaining_budget = error_budget - bias_error
    return math.sqrt(unit_noise_error / remaining_budget)

def privacy_loss_after_exactly_d_steps(d, epsilon, lamda, T, B):
    ''' Computes the worst-case privacy loss for an input at time 't',
    when considering all releases up to and including time 'tau=t+d',
    taken over all valid 't', _but_ only realized at 'tau'.

    The loss for a given 't' depends on which dyadic intervals are used
    to cover [t, t+d-B]; for d < B nothing has been released yet, so the
    loss is 0. The code tries every relevant start point and keeps the
    largest loss.

    Recall that the nodes in the tree at level 'l' (l=0 being a leaf),
    have noise z~Lap(1 / (epsilon * (l+1)^(lambda - 1)) ).

    Parameters:
    d: Time after which to evaluate the privacy of an input
    epsilon: Sets the noise at the leaves.
    lamda: Parameter for how the noise scales across levels in the tree
    T: Length of the stream
    B: Delay in releasing outputs.

    Returns:
    The worst-case privacy for an input, evaluated at the end of
        the time step 'd' steps into the future.
    '''

    # NOTE: Our algorithm releases with a delay of B
    if d < B:
        return 0

    # Every candidate interval [start, end] satisfies end - start = d - B
    span = d - B
    worst = 0

    # Try each 'alignment' of the interval: start = 1, ..., span + 1
    for start in range(1, span + 2):
        end = start + span
        if end > T:
            # Larger start points only move the interval further past T
            break

        # Presumably a mapping {level: number of dyadic intervals used at
        # that level} -- it is consumed via .items() below.
        cover = get_dyadic_cover_levels(start, end)

        # A node at level 'lvl' contributes epsilon * (lvl+1)^(lamda-1)
        weighted_nodes = sum(
            count * math.pow(lvl + 1, lamda - 1) for lvl, count in cover.items())
        worst = max(epsilon * weighted_nodes, worst)

    return worst

def empirical_privacy_loss(d_vec, T, target_mse, lamda, B):
    ''' Computes the worst-case privacy loss for our algorithm
    for every 'd' in 'd_vec'.

    Parameters:
    d_vec: Array of values 'd' for which to compute the privacy expiration.
    T: Length of the stream.
    target_mse: The mean-squared error target, passed on to
        'mean_squared_error_and_lambda_to_epsilon' to calibrate epsilon.
    lamda: The parameter 'lambda' in our algorithm.
    B: Delay in releasing outputs.

    Returns:
    An array of values 'y', such that 'y[i]' is the worst-case privacy loss at 'd_vec[i]'.
    '''

    # The MSE is normalized over the first 'T' steps, so the privacy
    # loss is likewise only considered up to 'd = T-1'.
    assert max(d_vec) < T, "Need d < T"

    # Calibrate epsilon so that the target MSE is met
    epsilon = mean_squared_error_and_lambda_to_epsilon(target_mse, lamda, T, B)
    print(f'We use epsilon={epsilon} for lambda={lamda}, T={T} to achieve MSE={target_mse}')

    # Worst-case loss realized exactly 'd' steps after an input, per 'd'
    shared_args = dict(epsilon=epsilon, lamda=lamda, T=T, B=B)
    loss_at_each_d = np.array(
        [privacy_loss_after_exactly_d_steps(d, **shared_args) for d in d_vec])

    # The running maximum turns the per-'d' losses into the loss curve
    return np.maximum.accumulate(loss_at_each_d)

def privacy_loss_from_g(d_vec, T, target_mse, lamda, B):
    ''' Computes the worst-case privacy loss for our algorithm
    for every 'd' in 'd_vec' using the 'g' from Theorem 1.

    Parameters:
    d_vec: Array of values 'd' for which to compute the privacy expiration.
    T: Length of the stream.
    target_mse: The mean-squared error target, passed on to
        'mean_squared_error_and_lambda_to_epsilon' to calibrate epsilon.
    lamda: The parameter 'lambda' in our algorithm.
    B: Delay in releasing outputs.

    Returns:
    An array of values 'y', such that 'y[i]' is the worst-case privacy loss  at 'd_vec[i]'.
    '''

    # The MSE is normalized over the first 'T' steps, so the privacy
    # loss is likewise only considered up to 'd = T-1'.
    assert max(d_vec) < T, "Need d < T"

    # Calibrate epsilon so that the target MSE is met
    epsilon = mean_squared_error_and_lambda_to_epsilon(target_mse, lamda, T, B)

    # Evaluate g at every requested 'd', then scale by epsilon
    g_values = np.array([our_g(d, lamda=lamda, B=B) for d in d_vec])
    return epsilon * g_values

def our_g(d, lamda, B):
    ''' Compute the privacy expiration function 'g' we support in the paper.

    'g' is derived from the fact that an interval of size
    'd' can be covered by at most 2 dyadic intervals per level in the tree,
    where each level is at most log(d-B+1).

    Parameters:
    d: Point at which to evaluate the function.
    lamda: Parameter for how the noise scales across levels in the tree
    B: Delay in releasing outputs.

    Returns:
    The function g evaluated at d, g(d); 0 whenever d < B.
    '''

    # Before the delay has passed, g is 0
    if d < B:
        return 0

    # Highest tree level counted: floor(log2(d - B + 1))
    deepest = math.floor(math.log2(d - B + 1))

    # Weight (l+1)^(lamda-1) for l = 0..deepest; floats so that negative
    # exponents are allowed
    level_weights = np.power(np.arange(1, deepest + 2, dtype=float), lamda - 1)

    # Two intervals per level
    return 2 * np.sum(level_weights)
