import numpy as np
from scipy.optimize import bisect
from scipy.special import binom

from utils import set_seeds
from skellam_utils import SkellamMechanismPyTorch as SkellamMechanism

def amplified_RDP(eps_rdp_list:list, alpha:int, q:float):
    """
    Subsampling-amplified RDP epsilon of integer order ``alpha`` for Poisson sampling under the
    add/remove neighbourhood relation, following https://arxiv.org/abs/2210.00597
    Args:
        eps_rdp_list: RDP epsilons of the unamplified mechanism; entry k is read as the epsilon
            of order k+2, i.e. the list corresponds to alpha=2,3,...,max_alpha
        alpha: integer Renyi order, must be >= 2
        q: Poisson sampling probability
    Returns:
        The amplified RDP epsilon at order ``alpha``.
    """
    assert alpha >= 2, f'alpha must be >= 2, got {alpha}'
    assert len(eps_rdp_list) >= alpha-1, f'eps_rdp_list must have length at least alpha-1, got {len(eps_rdp_list)} with alpha={alpha}'

    # j=0 and j=1 terms of the binomial expansion, collected into one closed-form factor
    low_order = (1-q)**(alpha-1) * (1+(alpha-1)*q)

    # remaining terms j=2..alpha; each one uses the order-j epsilon of the base mechanism
    high_order = 0
    for j in range(2, alpha+1):
        weight = binom(alpha, j) * (1-q)**(alpha-j) * q**j
        high_order += weight * np.exp((j-1)*eps_rdp_list[j-2])

    return 1/(alpha-1)*np.log(low_order+high_order)

def from_RDP_to_DP(eps_rdp, alpha, target_delta):
    """
    Turn an RDP epsilon of order ``alpha`` into an approximate-DP epsilon at ``target_delta``.
    Uses the original Mironov 2017 conversion: Prop3 in https://arxiv.org/abs/1702.07476
    Args:
        eps_rdp: RDP epsilon at order ``alpha``
        alpha: Renyi order (alpha=1 would divide by zero)
        target_delta: ADP delta
    Returns:
        eps_rdp - log(target_delta)/(alpha-1)
    """
    log_delta_term = np.log(target_delta)/(alpha-1)
    return eps_rdp - log_delta_term

def get_skellam_noise_multiplier(quantization:int, num_params:int, num_clients:int, alphas:list, sampling_frac:float, n_comps:int, target_eps:float, target_delta:float):
    """
    Find, by bisection, the Skellam noise scale at which the accounted ADP epsilon equals ``target_eps``.

    The search runs on the fixed bracket [0.5, 50] over ``skellam_search_fun``; the mechanism is
    built with norm bound C=1 and mu=(C*scale)**2.
    Args:
        quantization: forwarded as the first argument of SkellamMechanism
        num_params: forwarded to SkellamMechanism as ``d``
        num_clients: forwarded to SkellamMechanism as ``num_clients``
        alphas: Renyi orders used for the accounting
        sampling_frac: Poisson sampling probability used for the RDP amplification
        n_comps: number of compositions
        target_eps: targeted ADP epsilon
        target_delta: targeted ADP delta
    Returns:
        The noise scale found by the bisection.
    Raises:
        ValueError: from scipy's bisect when the search function has the same sign at both
            ends of the bracket, i.e. the target cannot be met with a scale in [0.5, 50].
    """
    C = 1.0 # assume norm bound = 1 for accounting
    # do binary search to find scale
    print('Running binary search for Skellam noise scale...')
    prop_scale = bisect(f=skellam_search_fun, a=.5, b=50., args=(alphas, quantization, num_params, num_clients, sampling_frac, n_comps, target_delta, target_eps, C))

    # re-run the accounting at the found scale to report what was actually achieved
    mu = (C*prop_scale)**2
    skellam = SkellamMechanism(quantization, d=num_params, norm_bound=C, mu=mu, device='cpu', num_clients=num_clients)
    found_eps = get_skellam_adp(skellam, alphas, sampling_frac, n_comps, target_delta)
    # get_skellam_adp returns an epsilon (at the fixed target_delta), so label it as such
    print(f'Found noise scale={prop_scale}, which gives eps={found_eps}')
    return prop_scale
    
def skellam_search_fun(scale, alphas, quantization, num_params, num_clients, sampling_frac, n_comps, target_delta, target_eps, C=1.0):
    """
    Objective for the noise-scale root search: accounted ADP epsilon at ``scale`` minus ``target_eps``.

    Builds a SkellamMechanism with mu=(C*scale)**2 and norm bound ``C``, evaluates its ADP epsilon
    at ``target_delta`` via ``get_skellam_adp``, and returns the signed gap to the target, so a
    root of this function is a scale that exactly meets ``target_eps``.
    """
    mechanism = SkellamMechanism(
        quantization, d=num_params, norm_bound=C, mu=(C*scale)**2, device='cpu', num_clients=num_clients
    )
    achieved_eps = get_skellam_adp(mechanism, alphas, sampling_frac, n_comps, target_delta)
    return achieved_eps - target_eps

def get_skellam_adp(skellam, alphas, sampling_frac, n_comps, target_delta):
    """
    Smallest ADP epsilon at ``target_delta`` over the Renyi orders evaluated by ``get_skellam_rdp``.

    Args:
        skellam: mechanism object handed through to ``get_skellam_rdp``
        alphas: Renyi orders used for the accounting
        sampling_frac: Poisson sampling probability
        n_comps: number of compositions
        target_delta: ADP delta used in the RDP -> ADP conversion
    Returns:
        The minimum converted epsilon; np.inf if no order gives a smaller value.
    """
    rdp_curve, last_idx = get_skellam_rdp(skellam, alphas, sampling_frac, n_comps)
    # only orders up to (and including) the last evaluated index carry a computed RDP value
    converted = [
        from_RDP_to_DP(rdp_curve[k], order, target_delta)
        for k, order in enumerate(alphas[:last_idx+1])
    ]
    # seeding with inf keeps the running-minimum semantics (inf/nan candidates never win)
    return min([np.inf, *converted])

def get_skellam_rdp(skellam, alphas, sampling_frac, n_comps):
    """
    Composed, subsampling-amplified RDP epsilons of the mechanism for each order in ``alphas``.

    Orders are processed in sequence; evaluation stops at the first order whose amplified
    epsilon is not finite.
    Args:
        skellam: object exposing ``renyi_div(alphas)``, which gives the unamplified RDP epsilons
        alphas: Renyi orders; NOTE(review): ``amplified_RDP`` indexes the epsilons as if the
            orders were 2,3,4,... consecutively, as every caller in this file passes - confirm
            before using other grids
        sampling_frac: Poisson sampling probability
        n_comps: number of compositions (multiplies each amplified epsilon)
    Returns:
        (amplified_RDP_eps, i_alpha): array with one epsilon per order and the index of the last
        order that was evaluated. Entries after ``i_alpha`` were never computed and are np.inf.
    Raises:
        ValueError: if ``alphas`` is empty.
    """
    if len(alphas) == 0:
        # previously this fell through to an UnboundLocalError on the return statement
        raise ValueError('alphas must contain at least one Renyi order')
    RDP_eps = skellam.renyi_div(alphas)
    # initialise with inf rather than 0: an order that was never evaluated must not look like
    # a perfect (eps=0) privacy guarantee to anyone reading past i_alpha
    amplified_RDP_eps = np.full(len(alphas), np.inf)
    for i_alpha, max_alpha_ in enumerate(alphas):
        tmp = amplified_RDP(RDP_eps, max_alpha_, sampling_frac)
        amplified_RDP_eps[i_alpha] = n_comps * tmp
        if not np.isfinite(tmp):
            # higher orders are skipped once the bound is no longer finite
            break
    return amplified_RDP_eps, i_alpha

def run_skellam_search(quantization, num_params, num_clients, sampling_frac, n_skellam_comps, target_eps, target_delta):
    """
    Search the Skellam noise scale for the given setting and sanity-check the result.

    The search is run twice, with Renyi orders 2..127 and 2..1023; a note is printed when the two
    scales disagree. The epsilon obtained with the final scale is then recomputed and printed.
    Returns:
        The noise scale found with the larger set of orders.
    """
    # first pass with the small set of orders
    coarse_alphas = np.arange(2, 128)
    coarse_sigma = get_skellam_noise_multiplier(
        quantization, num_params, num_clients, coarse_alphas, sampling_frac,
        n_comps=n_skellam_comps, target_eps=target_eps, target_delta=target_delta,
    )
    print(f'\ncheck per-client value:\n{coarse_sigma/np.sqrt(num_clients)}')

    # second pass with the large set of orders; this is the value that gets returned
    fine_alphas = np.arange(2, 1024)
    sigma = get_skellam_noise_multiplier(
        quantization, num_params, num_clients, fine_alphas, sampling_frac,
        n_skellam_comps, target_eps, target_delta,
    )
    print(f'For ({target_eps},{target_delta})-ADP, {n_skellam_comps} compositions with sample subsampling ratio {sampling_frac}, got per client scale:\n{sigma/np.sqrt(num_clients)} for Skellam noise assuming secsum with {num_clients} clients')
    if not np.allclose(coarse_sigma, sigma):
        print('NOTE: CHECK ACCOUNTING ALPHAS!')

    # recompute the epsilon that the found noise level actually gives
    mechanism = SkellamMechanism(
        quantization, d=num_params, norm_bound=1.0, mu=sigma**2, device='cpu', num_clients=num_clients
    )
    prop_eps = get_skellam_adp(mechanism, fine_alphas, sampling_frac, n_skellam_comps, target_delta)
    print('\nChecking proposed noise level, got eps:', prop_eps, ' with delta:', target_delta, ' using per-client scale:', sigma/np.sqrt(num_clients))
    return sigma

if __name__ == '__main__':

    set_seeds(2303)
    n_gaussian_comps = 0

    # Toggle: noise-scale search for one experiment configuration (Experiments 1-4 below)
    if True:

        # for the experiments in the paper, uncomment the corresponding configuration
        ##################################################
        ##### Experiment 1 with skellam noise
        ##################################################
        # Fashion MNIST, 10 clients, .8 train-.2 test data leaves 5600 train samples per client
        num_clients = 10
        target_delta = 1e-5
        target_eps = 1. # this is total targeted ADP eps
        num_params = 26010 # for fashion MNIST CNN
        quantization = 32

        # batch size : 512
        sampling_frac = 512/5600
        n_skellam_comps = 20 # 1 local step
        #n_skellam_comps = 20*11 # 1 local epoch

        # batch size : 256
        #sampling_frac = 256/5600
        #n_skellam_comps = 20 # 1 local step
        #n_skellam_comps = 20*22 # 1 local epoch

        # batch size : 128
        #sampling_frac = 128/5600
        #n_skellam_comps = 20 # 1 local step
        #n_skellam_comps = 20*44 # 1 local epoch

        # batch size : 64
        #sampling_frac = 64/5600
        #n_skellam_comps = 20 # 1 local step
        #n_skellam_comps = 20*88 # 1 local epoch

        ##################################################
        ##### Experiment 2 with skellam noise
        ##################################################
        # CIFAR10 pre-trained features, 10 clients, .8 train-.2 test data leaves 4800 train samples per client
        # baseline with 20 FL rounds
        #num_clients = 10
        #target_delta = 1e-5
        #target_eps = 1. # this is total targeted ADP eps
        #num_params = 10250 # linear model for CIFAR10 with pretrained resnext29 features
        #quantization = 32

        #sampling_frac = 512/4800
        #n_skellam_comps = 20 # 1 local step, all batch sizes
        #n_skellam_comps = 20*10 # 1 local epoch, 512/4800

        #sampling_frac = 256/4800
        #n_skellam_comps = 20 # 1 local step, all batch sizes
        #n_skellam_comps = 20*10 # 10 local steps, 256/4800
        #n_skellam_comps = 20*19 # 1 local epoch, 256/4800

        #sampling_frac = 128/4800
        #n_skellam_comps = 20 # 1 local step, all batch sizes
        #n_skellam_comps = 20*38 # 1 local epoch

        #sampling_frac = 64/4800
        #n_skellam_comps = 20 # 1 local step
        #n_skellam_comps = 20*75 # 1 local epoch

        ##################################################
        ##### Experiment 3: additional testing with FL rounds
        ##################################################
        # CIFAR10 pre-trained features, 10 clients, .8 train-.2 test data leaves 4800 train samples per client
        # 10 and 40 FL rounds
        #num_clients = 10
        #target_delta = 1e-5
        #target_eps = 1. # this is total targeted ADP eps
        #num_params = 10250 # linear model for CIFAR10 with pretrained resnext29 features
        #quantization = 32

        # keep batch size fixed to optimal value found under 20 FL rounds
        # 1 step:
        #sampling_frac = 128/4800
        #n_skellam_comps = 10 # 1 local step
        #n_skellam_comps = 40 # 1 local step
        #n_skellam_comps = 80 # 1 local step
        #n_skellam_comps = 160 # 1 local step

        # 1 epoch:
        #sampling_frac = 256/4800
        #n_skellam_comps = 10*19 # 1 local epoch
        #n_skellam_comps = 40*19 # 1 local epoch
        #n_skellam_comps = 80*19 # 1 local epoch
        #n_skellam_comps = 160*19 # 1 local epoch

        ##################################################
        ##### Experiment 4: Income data classification
        ##################################################
        # inherent data split with 51 clients, .8 train-.2 test data
        # 10 FL rounds
        #num_clients = 51
        #target_delta = 1e-5
        #target_eps = 1. # this is total targeted ADP eps
        #num_params = 53 # classifier network for income data
        #quantization = 32

        #sampling_frac = .4
        #sampling_frac = .2
        #sampling_frac = .1
        #sampling_frac = .05

        #n_skellam_comps = 10 # 10 FL rounds, 1 step
        #n_skellam_comps = int(10*1/sampling_frac) # 10 FL rounds, 1 epoch

        # end of setup
        ##################################################
        print('Using', n_skellam_comps, 'local steps with client subsampling ratio', sampling_frac)
        ##################################################
        # only Skellam noise is accounted for here, so the whole budget goes to it
        assert n_gaussian_comps == 0
        gaussian_ADP_share = 0.
        tmp_skellam_eps = target_eps
        tmp_delta = target_delta

        # first pass with a small set of Renyi orders
        # increase alpha if the code complains
        alphas = np.arange(2, 128)
        #alphas = np.arange(2, 512)
        #alphas = np.arange(2, 1024)
        skellam_noise_sigma_ = get_skellam_noise_multiplier(
            quantization, num_params, num_clients, alphas, sampling_frac,
            n_comps=n_skellam_comps, target_eps=tmp_skellam_eps, target_delta=tmp_delta,
        )
        print(f'\ncheck per-client value:\n{skellam_noise_sigma_/np.sqrt(num_clients)}')

        # second pass with a larger set of orders; the two results should agree
        alphas = np.arange(2, 1024)
        skellam_noise_sigma = get_skellam_noise_multiplier(
            quantization, num_params, num_clients, alphas, sampling_frac,
            n_skellam_comps, tmp_skellam_eps, tmp_delta,
        )
        print(f'For ({tmp_skellam_eps},{tmp_delta})-ADP, {n_skellam_comps} compositions with local subsampling ratio {sampling_frac}, got per client scale:\n{skellam_noise_sigma/np.sqrt(num_clients)} for Skellam noise assuming secsum with {num_clients} clients')
        if not np.allclose(skellam_noise_sigma_, skellam_noise_sigma):
            print('NOTE: CHECK ACCOUNTING ALPHAS!')

        # for checking resulting eps with found noise level
        scale = skellam_noise_sigma
        mu = scale**2
        skellam = SkellamMechanism(
            quantization, d=num_params, norm_bound=1.0, mu=mu, device='cpu', num_clients=num_clients
        )
        prop_eps = get_skellam_adp(skellam, alphas, sampling_frac, n_skellam_comps, target_delta)
        print('\nChecking proposed noise level, got eps:', prop_eps, ' with delta:', target_delta, ' using per-client scale:', skellam_noise_sigma/np.sqrt(num_clients))

    # Toggle: Experiment 5 (disabled by default)
    if False:
        ##################################################
        ##### Experiment 5: 1 FL round model avg
        ##################################################
        # 1 FL rounds, use Income data model
        all_num_clients = [1,2,5,10]
        target_delta = 1e-5
        target_eps = 5. # this is total targeted ADP eps
        num_params = 53 # classifier network for income data
        quantization = 32
        sampling_frac = .1

        steps_per_epoch = int(1/sampling_frac)
        all_skellam_comps = [1, steps_per_epoch, 5*steps_per_epoch] # 1 step, 1 epoch, 5 epochs
        # all_res[0] holds the noise scales, all_res[1] the resulting eps values
        all_res = np.zeros((2,len(all_skellam_comps), len(all_num_clients)))
        ##################################################
        for i_comp, n_skellam_comps in enumerate(all_skellam_comps):
            # calibrate the noise for a single client, then scale it up for more clients
            single_scale = run_skellam_search(quantization=quantization, num_params=num_params, num_clients=1, sampling_frac=sampling_frac, n_skellam_comps=n_skellam_comps, target_eps=target_eps, target_delta=target_delta)
            for i_client, num_clients in enumerate(all_num_clients):
                if num_clients == 1:
                    all_res[0,i_comp,i_client] = single_scale
                    all_res[1,i_comp,i_client] = target_eps
                    continue

                scale = np.sqrt(num_clients)*single_scale
                skellam = SkellamMechanism(
                    quantization, d=num_params, norm_bound=1.0, mu=scale**2, device='cpu', num_clients=num_clients
                )
                # evaluate eps with both sets of orders; they have to agree
                prop_eps_ = get_skellam_adp(skellam, np.arange(2, 128), sampling_frac, n_skellam_comps, target_delta)
                alphas = np.arange(2, 1024)
                prop_eps = get_skellam_adp(skellam, alphas, sampling_frac, n_skellam_comps, target_delta)
                print(num_clients, ' clients, found eps:', prop_eps, ' with delta:', target_delta, ' using scale:', scale)
                if not np.allclose(prop_eps, prop_eps_):
                    raise ValueError('ERROR: DIFFERENT EPSILON VALUES, check alphas!')
                all_res[0,i_comp,i_client] = scale
                all_res[1,i_comp,i_client] = prop_eps

        # NOTE(review): the values stored in all_res[0] are sqrt(num_clients)*single_scale,
        # while the label below says "per client" - confirm which quantity is intended
        for i_comp, n_skellam_comps in enumerate(all_skellam_comps):
            print(f'\nFor {n_skellam_comps} local steps:')
            print('Noise level per client:', np.round(all_res[0,i_comp,:],2) )
            print('Resulting eps:', np.round(all_res[1,i_comp,:],2) )