import argparse
import copy
import dill
import numpy as np

import os
import sys
sys.path.append(os.getcwd() + '/linreg/')

from NIG import NIG_rvs
from operator import itemgetter
from run_methods import run_methods
from evaluation import plot_posteriors, compute_mnlp_from_dataset
from generate_data import privatize_suff_stats_Gauss
from kl_estimation import compute_kl_gibbs

def _results_path(source):
    """Return the results pickle path derived from the settings path ``source``.

    Only the final extension of the file name is stripped, so dots in
    directory components (e.g. ``./data/x.pkl`` or ``../a.b/x.pkl``) are kept.
    """
    return os.path.splitext(source)[0] + '-reward-control-results.pkl'


def _evaluate_repeats(posteriors, prior, test_dataset, indep_repeats):
    """Score each independent repeat stored in ``posteriors``.

    ``posteriors`` must hold the keys ``'post-0'`` ... ``'post-{R-1}'``
    (R = ``indep_repeats``) and ``'grand'``. Its ``'prior'`` entry is
    overwritten on every iteration because ``compute_kl_gibbs`` reads the
    reference distribution from that key (presumably — confirm in
    kl_estimation).

    Returns:
        (kls, post_kls, mnlps), each a 1-D array of length ``indep_repeats``:
        the KL estimate of repeat r against ``prior``, the KL estimate of the
        grand posterior against repeat r, and the test-set MNLP of repeat r.
    """
    kls = np.zeros(indep_repeats)
    post_kls = np.zeros(indep_repeats)
    mnlps = np.zeros(indep_repeats)
    for r in range(indep_repeats):
        key = 'post-{}'.format(r)
        posteriors['prior'] = prior

        mnlps[r] = compute_mnlp_from_dataset(test_dataset, posteriors[key])
        kls[r] = compute_kl_gibbs(posteriors, key, mode='nearestc')

        # Swap the reference so the grand posterior is compared against this repeat.
        posteriors['prior'] = posteriors[key]
        post_kls[r] = compute_kl_gibbs(posteriors, 'grand', mode='nearestc')
    return kls, post_kls, mnlps


def main():
    """Run the reward-control experiment and pickle the results.

    Loads a dataset/setting pickle and computes a reference ("grand")
    posterior from the original DP sufficient statistics. It then runs two sweeps:

    1. SS scaling factor ``k`` (kappa): inference is re-run with ``k=k``.
    2. Injected noise (tau): each party's ``Z`` is further privatized with
       ``eps = 1/scaled``. Inference is re-run with the combined noise scale
       ``sqrt(add_sigma**2 + sigma_DP**2)``.

    For every setting and independent repeat it records the test MNLP, a KL
    estimate against the prior, and a KL estimate against the grand posterior.
    The results file holds TWO consecutive ``dill.dump`` payloads (kappa
    sweep, then tau sweep), so readers must call ``dill.load`` twice.
    """
    parser = argparse.ArgumentParser(description="Run reward control experiment.")
    parser.add_argument("-s", "--source", required=True, help="path to dataset/setting pkl file")
    parser.add_argument("-d", "--dataprior", required=True, choices=['share', 'ind'], help="share/ind data prior")
    parser.add_argument("-c", "--chains", type=int, default=16, help="number of chains/keep every c-th sample")
    parser.add_argument("-r", "--repeats", type=int, default=5, help="number of independent repeats per setting")
    # Party j's extra noise uses seed = seed_offset + seed_stride * j.
    # Defaults match the cali dataset (1 + 4*j); the syn dataset used 9 + 14*j.
    parser.add_argument("--seed-offset", type=int, default=1, help="seed offset for injected noise")
    parser.add_argument("--seed-stride", type=int, default=4, help="per-party seed stride for injected noise")
    args = vars(parser.parse_args())
    if args['repeats'] < 1:
        parser.error("--repeats must be at least 1")

    source = args['source']
    dest = _results_path(source)
    method = 'gibbs-Isserlis-{}-parallel'.format(args['dataprior'])
    chains = args['chains']
    indep_repeats = args['repeats']
    seed_offset = args['seed_offset']
    seed_stride = args['seed_stride']

    # NOTE(review): dill.load can execute arbitrary code; only load trusted files.
    with open(source, "rb") as f:
        nP, N, epsilon, DP_method, d, dataset, test_dataset, sensitivity, S, Z, sigma_DP, data_prior_params, model_prior_params, nSample = dill.load(f)

    prior = NIG_rvs(*model_prior_params, size=nSample)

    grand_posterior = run_methods(data_prior_params, model_prior_params, tuple(S), tuple(Z),
                                  sigma_DP, N, nSample, DP_method, [method], chains=chains)
    grand = grand_posterior['post-0']

    print("Original MNLP", compute_mnlp_from_dataset(test_dataset, grand))

    ## Scaling factor/kappa experiments
    ks = [1, 0.8, 0.6, 0.4, 0.2, 0.1, 0.05, 0.025, 0.01, 0.005, 0.0025, 0.001]
    k_kls = np.zeros((len(ks), indep_repeats))
    k_post_kls = np.zeros((len(ks), indep_repeats))
    k_mnlps = np.zeros((len(ks), indep_repeats))

    for e, k in enumerate(ks):
        posteriors = run_methods(data_prior_params, model_prior_params, tuple(S), tuple(Z),
                                 sigma_DP, N, nSample, DP_method, [method], k=k, chains=chains,
                                 indep_repeats=indep_repeats)
        posteriors['grand'] = grand
        k_kls[e], k_post_kls[e], k_mnlps[e] = _evaluate_repeats(
            posteriors, prior, test_dataset, indep_repeats)

        print("SS scaling factor", k)
        print("   KL with prior, original posterior, MNLP", k_kls[e].mean(), k_post_kls[e].mean(), k_mnlps[e].mean())

    ### noise addition/tau experiments
    add_scaled_noise = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
    h_kls = np.zeros((len(add_scaled_noise), indep_repeats))
    h_post_kls = np.zeros((len(add_scaled_noise), indep_repeats))
    h_mnlps = np.zeros((len(add_scaled_noise), indep_repeats))

    for e, scaled in enumerate(add_scaled_noise):
        eps = 1. / scaled
        new_Z = copy.deepcopy(Z)
        add_sigmas = np.zeros(nP)
        for j in range(nP):
            # Same additional eps across all parties, further privatizing Z.
            new_Z[j], add_sigmas[j] = privatize_suff_stats_Gauss(
                copy.copy(Z[j]), sensitivity, eps, seed=seed_offset + seed_stride * j)
        # Independent Gaussian noises combine in quadrature.
        new_sigma_DP = np.sqrt(add_sigmas**2 + sigma_DP**2)
        # Reseed the global RNG from OS entropy. Presumably privatize_suff_stats_Gauss
        # seeds it deterministically and MCMC should not reuse that stream — confirm.
        np.random.seed()

        posteriors = run_methods(data_prior_params, model_prior_params, tuple(S), tuple(new_Z),
                                 new_sigma_DP, N, nSample, DP_method, [method], chains=chains,
                                 indep_repeats=indep_repeats)
        posteriors['grand'] = grand
        h_kls[e], h_post_kls[e], h_mnlps[e] = _evaluate_repeats(
            posteriors, prior, test_dataset, indep_repeats)

        print("Injected noise scaled variance", scaled)
        print("   KL with prior, original posterior, MNLP", h_kls[e].mean(), h_post_kls[e].mean(), h_mnlps[e].mean())

    # Two separate payloads: consumers must dill.load twice, in this order.
    with open(dest, "wb") as f:
        dill.dump((ks, k_kls, k_post_kls, k_mnlps), f)
        dill.dump((add_scaled_noise, h_kls, h_post_kls, h_mnlps), f)


if __name__ == '__main__':
    main()
