import argparse
import dill
import numpy as np
from copy import copy, deepcopy

import os
import sys
sys.path.append(os.getcwd() + '/linreg/')

from NIG import NIG_rvs
from operator import itemgetter
from run_methods import run_methods
from evaluation import plot_posteriors, compute_mnlp_from_dataset
from generate_data import privatize_suff_stats_Gauss
from kl_estimation import compute_kl_gibbs

from shapley import Valuation
from rewardk import RewardProtocol


def main():
    """Run the privacy-monotonicity experiment for one party and save the results.

    The setting tuple is loaded from ``--source``.  For each of a fixed number
    of noise repetitions, a ``Valuation`` is built (repetitions after the first
    re-privatize the chosen party's statistics with a new seed), and for every
    entry of ``eps_scale_factor`` the following quantities are recorded for the
    party selected by ``--varyparty``:

    * ``each_kls`` entry of the party and ``grand_kl`` of the valuation,
    * the party's entry of ``shapleys`` and of ``get_reward(rho)``,
    * the first value returned by ``compute_mnlp_from_dataset`` on the test
      dataset, for the posterior keyed by all party indices and for the
      posterior keyed by the party's own index.

    Results are written to ``<source base>-private-mono-results.pkl`` after
    every repetition.  When ``--solverewards`` is given, a second pass builds a
    ``RewardProtocol`` for every stored valuation, evaluates the reward
    posterior of the party, and writes
    ``<source base>-private-mono-reward-results.pkl``.
    """
    parser = argparse.ArgumentParser(description="Run privacy monotonicity experiment.")
    # The source path is needed to derive the output paths, so it is mandatory.
    parser.add_argument("-s", "--source", required=True, help="path to dataset/setting pkl file")
    parser.add_argument("-d", "--dataprior", choices=['share', 'ind'], help="share/ind data prior")
    parser.add_argument("-c", "--chains", type=int, default=16, help="number of chains/keep every c-th sample")
    parser.add_argument("-v", "--varyparty", type=int, default=1, help="vary eps of party ?")
    parser.add_argument("-p", "--rho", type=float, default=0.2, help="rho in rho-Shapley value")
    parser.add_argument("-r", "--solverewards", action='store_true', default=False,
                        help="flag to solve reward for specific party")
    args = vars(parser.parse_args())

    source = args['source']
    # splitext removes only the final extension, so dots in directory names
    # (e.g. "./data/x.pkl") do not truncate the output path.
    dest_base = os.path.splitext(source)[0]
    dest = dest_base + '-private-mono-results.pkl'
    method = 'gibbs-Isserlis-{}-parallel'.format(args['dataprior'])
    chains = args['chains']
    vary_party_index = args['varyparty']
    solve = args['solverewards']
    rho = args['rho']

    # NOTE: dill.load can execute arbitrary code stored in the file; only load
    # setting files from a trusted origin.
    with open(source, "rb") as f:
        (nP, N, epsilon, DP_method, d, dataset, test_dataset, sensitivity, S, Z,
         sigma_DP, data_prior_params, model_prior_params, nSample) = dill.load(f)

    eps_scale_factor = [1./125, 1./25, 1./5, 1, 5, 25]
    # Epsilon of the varied party divided by each scale factor; saved with the results.
    eps_vary_party = epsilon[vary_party_index] / np.array(eps_scale_factor)
    noise_repeats = 5

    # One row per scale factor, one column per noise repetition.
    result_shape = (len(eps_scale_factor), noise_repeats)
    v_is = np.zeros(result_shape)
    v_ns = np.zeros(result_shape)
    phi_is = np.zeros(result_shape)
    ri_is = np.zeros(result_shape)
    mnlp_ns = np.zeros(result_shape)
    mnlp_is = np.zeros(result_shape)
    r_valuations = []

    # Keys into all_posteriors: all party indices concatenated, and the varied party alone.
    grand_key = ''.join(str(x) for x in range(nP))
    party_key = str(vary_party_index)

    for r in range(noise_repeats):
        print("-----",r,"-----")
        # consider different realization of noise to privatize Z[vary_party_index]
        new_Z = Z.copy()
        if r != 0:
            new_Z[vary_party_index], _ = privatize_suff_stats_Gauss(S[vary_party_index], sensitivity,
                                                                epsilon[vary_party_index], seed=3*r)

        valuation = Valuation(data_prior_params, model_prior_params, nP, S, new_Z, sigma_DP, N,
                              nSample=nSample, inf_method=method, mode='nearestc', chains=chains)

        r_valuation = []

        for e, scale_factor in enumerate(eps_scale_factor):
            print("---scale factor: ",scale_factor,"---")
            if scale_factor != 1:
                # NOTE(review): this is a shallow copy, so mutable attributes are
                # shared with `valuation` unless update_i rebinds them -- confirm
                # update_i does not modify shared state in place.
                v = copy(valuation)
                v.update_i(vary_party_index, scale_factor)
            else:
                v = valuation
            r_valuation.append(v)

            v_is[e, r] = v.each_kls[vary_party_index]
            v_ns[e, r] = v.grand_kl
            phi_is[e, r] = v.shapleys[vary_party_index]
            ri_is[e, r] = v.get_reward(rho)[vary_party_index]
            # Use the function name that the file actually imports from `evaluation`.
            mnlp_ns[e, r], _ = compute_mnlp_from_dataset(test_dataset, v.all_posteriors[grand_key])
            mnlp_is[e, r], _ = compute_mnlp_from_dataset(test_dataset, v.all_posteriors[party_key])

        r_valuations.append(r_valuation)

        # Rewrite the results file after every repetition so partial results survive a crash.
        with open(dest, "wb") as f:
            dill.dump((v_is, v_ns, phi_is, ri_is, mnlp_ns, mnlp_is, noise_repeats, vary_party_index, eps_vary_party), f)

    if not solve:
        return

    dest = dest_base + '-private-mono-reward-results.pkl'
    mnlp_rs = np.zeros(result_shape)

    for r, r_valuation in enumerate(r_valuations):
        print("-----",r,"-----")
        for e, v in enumerate(r_valuation):
            print("...", eps_scale_factor[e], "...")
            re = RewardProtocol(v, rho, sensitivity)
            try:
                r_posterior = re._generate_reward_for_i(vary_party_index)
                mnlp_rs[e, r], _ = compute_mnlp_from_dataset(test_dataset, r_posterior)
            except ValueError:
                # Best effort: the entry of mnlp_rs stays at its initial 0 for this run.
                print("Warning: Unsuccessful for run, eps scale factor", r, e)

        with open(dest, "wb") as f:
            dill.dump((v_is, v_ns, phi_is, ri_is, mnlp_ns, mnlp_rs, mnlp_is, noise_repeats, vary_party_index, eps_vary_party), f)


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
