import numpy as np
import math
import matplotlib.pyplot as plt
import scipy
import pandas as pd

from tueplots import bundles

def count_good_binary_leaves(height):
    """Return the central binomial coefficient C(height, height // 2).

    This is the number of ``height``-bit strings containing exactly
    ``height // 2`` ones. It is used as the number of usable leaves of a
    tree of the given height (presumably the leaves whose root-to-leaf
    path has that many ones -- confirm against the mechanism definition).
    """
    ones = height // 2
    return math.comb(height, ones)

def compute_simple_binary_mechanism_space_needed(T):
    """Return the space-vs-horizon step curve of the simple binary mechanism.

    Tree heights are tried in increasing order starting at 1. A tree of
    height ``h`` has ``2**h`` leaves and is charged ``max(h - 1, 1)``
    units of space. It contributes a horizontal segment of the curve that
    ends at ``x = 2**h``. The first height whose leaf count exceeds
    ``T + 1`` closes the curve at ``x = T`` instead.

    NOTE(review): because the stop test is ``> T + 1``, a leaf count equal
    to ``T`` or ``T + 1`` still moves on to the next height. The curve can
    then end with an extra step near x = T; e.g. T = 3 yields x-values
    ``[1, 2, 2, 4, 4, 3]``. Confirm that this threshold is intended.

    Args:
        T: Time horizon at which the curve is closed.

    Returns:
        ``(t, space_needed)``: two lists of equal length, ready for
        ``plt.plot``. Every interior breakpoint of ``t`` appears twice so
        that the plotted curve rises vertically at each height change.
    """
    xs = [1]
    ys = []
    height = 1
    while True:
        cost = max(height - 1, 1)
        if ys:
            # Not the first segment: repeat the last x to draw a vertical step.
            xs.append(xs[-1])
        ys += [cost, cost]
        leaves = 2**height
        if leaves > T + 1:
            xs.append(T)
            return xs, ys
        xs.append(leaves)
        height += 1

def compute_our_binary_mechanism_space_needed(T):
    """Return the space-vs-horizon step curve of our mechanism.

    Only even tree heights ``h = 2, 4, 6, ...`` are tried. A tree of height
    ``h`` offers ``count_good_binary_leaves(h)`` usable leaves and is
    charged ``h // 2`` units of space. It contributes a horizontal segment
    of the curve ending at that leaf count. The first height whose leaf
    count exceeds ``T + 1`` closes the curve at ``x = T`` instead.

    NOTE(review): as with the simple binary mechanism, a leaf count equal
    to ``T`` or ``T + 1`` does not stop the loop. The curve can then end
    with an extra step near x = T. Confirm the intended threshold.

    Args:
        T: Time horizon at which the curve is closed.

    Returns:
        ``(t, space_needed)``: two lists of equal length, ready for
        ``plt.plot``. Every interior breakpoint of ``t`` appears twice so
        that the plotted curve rises vertically at each height change.
    """
    breakpoints = [1]
    costs = []
    height = 2
    first_segment = True
    while True:
        cost = height // 2
        if not first_segment:
            # Repeat the last x so the curve jumps vertically to the new cost.
            breakpoints.append(breakpoints[-1])
        first_segment = False
        costs.extend([cost] * 2)
        usable = count_good_binary_leaves(height)
        if usable > T + 1:
            breakpoints.append(T)
            break
        breakpoints.append(usable)
        height += 2
    return breakpoints, costs

# Wrapped in a function for the sake of consistency with the other mechanisms.
# The constant in the linear scaling is at least 1 (we need to output T values).
def compute_henzinger_space_needed(T, dt):
    """Return the linear space curve used for Henzinger et al.

    Args:
        T: Exclusive upper bound of the sampled horizons.
        dt: Spacing between consecutive sampled horizons.

    Returns:
        ``(t, space)``, where both entries are the very same array
        ``np.arange(1, T, dt)``: space is modelled as equal to the horizon.
    """
    horizons = np.arange(1, T, dt)
    space = horizons  # same object on purpose: y == x
    return horizons, space


def compare_space_needed(save_plot=True, *, T=10**6, dt=100,
                         output_path='../../figures/space_comparison.pdf'):
    """Plot the space needed by the three mechanisms against the horizon.

    Draws, on log-log axes, one curve each for Henzinger et al., the simple
    binary mechanism and our mechanism. The plot uses the tueplots ICML
    2022 style bundle, and the figure is always shown with ``plt.show()``.

    Args:
        save_plot: If true, also write the figure to ``output_path``.
        T: Largest time horizon considered.
        dt: Sampling step for the Henzinger et al. curve.
        output_path: Destination of the saved figure. A relative path is
            resolved against the current working directory.
    """
    t_our, space_our = compute_our_binary_mechanism_space_needed(T)
    t_binary, space_binary = compute_simple_binary_mechanism_space_needed(T)
    t_henzinger, space_henzinger = compute_henzinger_space_needed(T, dt=dt)

    with plt.rc_context(bundles.icml2022()):
        plt.figure()
        plt.plot(t_henzinger, space_henzinger, label='Henzinger et al.')
        plt.plot(t_binary, space_binary, label='Binary Mechanism')
        plt.plot(t_our, space_our, label='Our Mechanism')
        plt.legend()
        plt.xscale('log')
        plt.yscale('log')

        plt.xlabel('upper bound on time')
        # Raw string: the backslash must reach matplotlib unchanged
        # (presumably to escape '#' for LaTeX rendering under the style
        # bundle). '\#' in a normal literal is an invalid Python escape
        # sequence and triggers a SyntaxWarning on Python 3.12+.
        plt.ylabel(r'max \# of floats needed to be stored')

        if save_plot:
            plt.savefig(output_path, bbox_inches='tight')

        plt.show()

if __name__ == "__main__":
    import argparse

    # By default the figure is only displayed. Pass --save to also write it
    # to disk, without having to edit the source.
    parser = argparse.ArgumentParser(
        description="Plot the space needed by Henzinger et al., the binary "
                    "mechanism and our mechanism.")
    parser.add_argument(
        "--save", action="store_true",
        help="also save the figure (default: only display it)")
    compare_space_needed(save_plot=parser.parse_args().save)

