import math
import numpy as np
import sys

import matplotlib.pyplot as plt

# Can be installed via "pip install tueplots"
from tueplots import bundles

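# Compute the lower-triangular Toeplitz matrix M from Henzinger et al.'s paper,
# which is used as both factors L and R (M @ M is the all-ones lower-triangular
# prefix-sum matrix). This is the most efficient construction I could come up with.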

def compute_henzinger_matrix(T):
    """Return the T x T lower-triangular Toeplitz matrix of Henzinger et al.

    Entry (i, j) with i >= j equals c[i - j], where c[0] = 1 and
    c[k] = c[k - 1] * (1 - 0.5 / k); entries above the diagonal are zero.
    """
    # Coefficient of the k-th subdiagonal, accumulated by a running product
    # (same multiplication order as a scalar loop, so values are identical).
    coeffs = np.ones(T, dtype=float)
    for k in range(1, T):
        coeffs[k] = coeffs[k - 1] * (1.0 - 0.5 / k)

    # Scatter c[i - j] into every position on or below the main diagonal.
    mat = np.zeros((T, T), dtype=float)
    rows, cols = np.tril_indices(T)
    mat[rows, cols] = coeffs[rows - cols]
    return mat

# Computes the variance for Henzinger's matrix mechanism by explicitly
# computing the matrix

def compute_henzinger_matrix_mechanism_variance(T, rho, do_max_variance=False):
    """Exact per-step variance of Henzinger et al.'s mechanism via matrix norms.

    Builds M explicitly (M is used as both L and R) and returns, for
    t = 1..T, the squared L2 norm of row t of L times the squared sensitivity
    of R, scaled by 0.5 / rho for rho-zCDP.

    If do_max_variance is False, the sensitivity is that of the full T x T
    matrix, i.e. the norm of its first column. If True, T is viewed as a
    varying upper bound: the sensitivity for bound t is the norm of the first
    column of the t x t leading block, which equals the norm of row t.
    """
    M = compute_henzinger_matrix(T)

    # Variance from adding together the noisy p-sums: squared row norms of L.
    row_norms_sq = np.square(np.linalg.norm(M, ord=2, axis=1))

    if do_max_variance:
        # Each upper bound t gets its own sensitivity, equal to row t's norm.
        sensitivity_sq = row_norms_sq
    else:
        # Fixed upper bound T: the largest column norm of R is column 0.
        sensitivity_sq = np.square(np.linalg.norm(M[:, 0], ord=2))

    # 0.5 / rho is to achieve rho-zCDP
    return row_norms_sq * sensitivity_sq * 0.5 / rho

# Given the structure of the matrix, we can pretty easily compute the variance
# per time step without actually constructing the matrix

def compute_henzinger_matrix_mechanism_variance_fast(T,
                                                     rho,
                                                     do_max_variance=False):
    """Per-step variance of Henzinger et al.'s mechanism without building M.

    Row t of M (t = 1..T) holds the coefficients c[t-1], ..., c[0], where
    c[0] = 1 and c[k] = c[k - 1] * (1 - 0.5 / k). Its squared L2 norm is
    therefore the running sum c[0]**2 + ... + c[t-1]**2.

    Args:
        T: number of time steps (upper bound on time). T == 0 yields an empty
            array.
        rho: zCDP privacy parameter; the result is scaled by 0.5 / rho.
        do_max_variance: if False, every step uses the sensitivity of the full
            T x T matrix. If True, T is viewed as a varying upper bound and
            step t uses the sensitivity of the t x t leading block.

    Returns:
        Float array of length T with the variance at t = 1..T.
    """
    # Diagonal coefficients c[0..T-1]: seed with the ratios c[k] / c[k-1]
    # (and c[0] = 1), then multiply them up in order.
    coeffs = np.ones(T, dtype=float)
    coeffs[1:] = 1.0 - 0.5 / np.arange(1, T, dtype=float)
    coeffs = np.cumprod(coeffs)

    # Variance contribution from adding together the noisy p-sums: the
    # squared row norms of L, i.e. running sums of squared coefficients.
    variance = np.cumsum(coeffs * coeffs)
    if variance.size == 0:
        return variance

    # The sensitivity is given by the max squared L2 column norm of R,
    # which is its first column, equal in norm to the last row of L.
    if do_max_variance:
        # Upper bound varies with t: bound t uses row t's own norm.
        variance = variance * variance
    else:
        # Fixed upper bound T: every step uses the final row's norm.
        variance = variance * variance[-1]

    # 0.5 / rho is to achieve rho-zCDP
    return variance * 0.5 / rho

def compute_binary_mechanism_variance(T, rho, do_max_variance=False):
    """Per-step variance of the classic binary (tree) mechanism.

    At time t the prefix sum is assembled from popcount(t) noisy tree nodes.
    Every input appears in one node per tree level, and a tree for an upper
    bound t has bit_length(t) = ceil(log2(t + 1)) levels.

    Args:
        T: number of time steps (upper bound on time).
        rho: zCDP privacy parameter; the result is scaled by 0.5 / rho.
        do_max_variance: if False, return the variance at each t = 1..T for
            the fixed bound T. If True, entry t is the maximum variance
            encountered when the upper bound on time is t.

    Returns:
        Float array of length T.
    """
    t = np.arange(1, T + 1, dtype=np.int64)

    # Vectorised popcount(t) and bit_length(t): one pass per bit (about
    # log2(T) passes) instead of a Python-level bin() call per element.
    num_terms_added = np.zeros(t.shape, dtype=float)
    tree_height = np.zeros(t.shape, dtype=float)
    remaining = t.copy()
    while remaining.any():
        num_terms_added += remaining & 1
        tree_height += remaining > 0
        remaining >>= 1

    if do_max_variance:
        # The maximum variance for a given upper bound on time is the height
        # needed to accommodate that bound, times the greatest number of
        # terms added together up to and including that bound.
        variance = tree_height * np.maximum.accumulate(num_terms_added)
    else:
        # Multiply by the final sensitivity, decided by the final tree height,
        # computed exactly with integer arithmetic.
        variance = int(T).bit_length() * num_terms_added

    # 0.5 / rho is to achieve rho-zCDP
    return variance * 0.5 / rho

def compute_our_binary_mechanism_variance(T, rho, do_max_variance=False):
    """Per-step variance of our balanced-leaf binary mechanism.

    A tree of even height h provides comb(h, h // 2) - 1 usable balanced
    leaves. Each prefix sum adds h // 2 noisy terms, and h // 2 is also the
    sensitivity of every element.

    Args:
        T: number of time steps (upper bound on time).
        rho: zCDP privacy parameter; the result is scaled by 0.5 / rho.
        do_max_variance: if False, every step uses the smallest even height
            whose leaves accommodate all T steps. If True, entry t uses the
            smallest height that accommodates t steps.

    Returns:
        Float array of length T.
    """
    # half_height[t-1] = h // 2 for the smallest even h whose usable leaves
    # cover time step t (this is also the number of terms added at step t).
    half_height = np.zeros((T,), dtype=float)
    start = 0
    h = 2
    while True:
        num_useful_leaves = math.comb(h, h // 2) - 1
        # Slicing clamps at T, so the last pass fills the remaining tail.
        half_height[start:num_useful_leaves] = h // 2
        # '>=' rather than '>': height h already covers t = num_useful_leaves
        # (the assignment above gives that step h // 2), so stop here.
        if num_useful_leaves >= T:
            break
        start = num_useful_leaves
        h += 2

    # do_max_variance implies we are viewing our upper bound T as varying,
    # i.e., we are computing how the maximum variance grows with T
    if do_max_variance:
        # h // 2 terms added, each with sensitivity h // 2.
        variance = half_height * half_height
    else:
        # Fixed upper bound T: every step uses the final height h.
        variance = np.full((T,), (h / 2.0) * (h / 2.0))

    # 0.5 / rho is to achieve rho-zCDP
    return variance * 0.5 / rho


# Produces a plot where we compare the running variance for different
# mechanisms when the upper bound on time is set to T, i.e., the upper bound
# on the sensitivity is fixed but the number of terms in the prefix sum
# estimates varies with t.

def running_variance_plot(save_plot=True, show_plot=True, T=250):
    """Plot the per-step variance of each mechanism for a fixed bound T.

    Args:
        save_plot: write the figure to
            '../../figures/running_variance_comparison.pdf' (relative to the
            current working directory).
        show_plot: display the figure; otherwise the figure is closed.
        T: upper bound on time, i.e. the number of plotted time steps.
    """
    T = int(T)
    # rho = 1, so the plotted values match the y-axis label Var x rho.
    rho = 1
    do_max_variance = False
    with plt.rc_context(bundles.icml2022()):
        henzinger_variance = compute_henzinger_matrix_mechanism_variance_fast(
            T, rho, do_max_variance)
        binary_variance = compute_binary_mechanism_variance(
            T, rho, do_max_variance)
        our_variance = compute_our_binary_mechanism_variance(
            T, rho, do_max_variance)

        t = np.arange(1, T+1)

        fig = plt.figure()
        plt.plot(t, henzinger_variance, 'o', ms=2, label='Henzinger et al.')
        plt.plot(t, binary_variance, 'o', ms=2, label='Binary Mechanism')
        plt.plot(t, our_variance, 'o', ms=2, label='Our Mechanism')

        plt.xlabel('time')
        plt.ylabel(r'$\mathrm{Var}[\mathcal{M}(t)]\times\rho$')
        plt.legend()

        if save_plot:
            fig.savefig('../../figures/running_variance_comparison.pdf',
                        bbox_inches='tight')
        if show_plot:
            plt.show()
        else:
            # Release the figure; plt.clf() would only clear it and leave it
            # registered (and holding memory) in pyplot's figure manager.
            plt.close(fig)

# Produces a plot where we compare the maximum variance for different
# mechanisms as a function of the upper bound on time

def max_variance_plot(save_plot=True, show_plot=True, T=10_000_000):
    """Plot each mechanism's maximum variance against the upper bound on time.

    Also prints the ratio (binary mechanism / our mechanism) of the maximum
    variance at the largest bound T.

    Args:
        save_plot: write the figure to
            '../../figures/max_variance_comparison.pdf' (relative to the
            current working directory).
        show_plot: display the figure; otherwise the figure is closed.
        T: largest upper bound on time; bounds 1..T are plotted on a log
            x-axis.
    """
    T = int(T)
    # rho = 1, so the plotted values match the y-axis label Var x rho.
    rho = 1
    do_max_variance = True
    with plt.rc_context(bundles.icml2022()):
        henzinger_variance = compute_henzinger_matrix_mechanism_variance_fast(
            T, rho, do_max_variance)
        binary_variance = compute_binary_mechanism_variance(
            T, rho, do_max_variance)
        our_variance = compute_our_binary_mechanism_variance(
            T, rho, do_max_variance)

        t = np.arange(1, T+1)

        fig = plt.figure()
        plt.plot(t, henzinger_variance, label='Henzinger et al.')
        plt.plot(t, binary_variance, label='Binary Mechanism')
        plt.plot(t, our_variance, label='Our Mechanism')

        print(f"Our improvement over binary mechanism at last T: {binary_variance[-1] / our_variance[-1]}")

        plt.xscale('log')
        plt.xlabel('upper bound on time')
        plt.ylabel(r'maximum $\mathrm{Var}[\mathcal{M}(t)]\times\rho$')
        plt.legend()

        if save_plot:
            fig.savefig('../../figures/max_variance_comparison.pdf',
                        bbox_inches='tight')
        if show_plot:
            plt.show()
        else:
            # Release the figure; plt.clf() would only clear it and leave it
            # registered (and holding memory) in pyplot's figure manager.
            plt.close(fig)


# Sanity check that our fast, formula-based variance computation matches the
# exact one obtained by taking norms of the explicitly constructed matrix.

def henzinger_matrix_sanity_check():
    """Print a small Henzinger matrix and compare exact vs fast variances.

    For T = 1000 and rho = 1, computes the L2 distance between the
    matrix-based and formula-based variance vectors, first for the running
    variance (fixed bound) and then for the max variance (varying bound),
    and prints both distances.
    """
    print(f"Henzinger's matrix looks like:\n{compute_henzinger_matrix(5)}")

    T, rho = 1000, 1
    diff_by_mode = {}
    for max_mode in (False, True):
        exact = compute_henzinger_matrix_mechanism_variance(T, rho, max_mode)
        fast = compute_henzinger_matrix_mechanism_variance_fast(
            T, rho, max_mode)
        diff_by_mode[max_mode] = np.linalg.norm(exact - fast)

    print("Comparison between variance derived from matrix and formula")
    print(f"The running variance: ||matrix_based-formula_based||_2 = {diff_by_mode[False]}")
    print(f"The max variance: ||matrix_based-formula_based||_2= {diff_by_mode[True]}")

if __name__ == "__main__":
    # Command-line flags (any order):
    #   save        -> write the figures under ../../figures/ (off by default)
    #   nosave      -> do not write the figures (takes precedence over 'save')
    #   noshow      -> do not display the figures
    #   sanitycheck -> first compare the exact and fast Henzinger variances
    flags = set(sys.argv[1:])

    save_plot = 'save' in flags and 'nosave' not in flags
    show_plot = 'noshow' not in flags

    if 'sanitycheck' in flags:
        henzinger_matrix_sanity_check()

    running_variance_plot(save_plot, show_plot)
    max_variance_plot(save_plot, show_plot)

