import argparse
import dill
import numpy as np
from copy import copy, deepcopy

import os
import sys
sys.path.append(os.getcwd() + '/logreg/')

from log_generate_data import  generate_approx_ss_noisy, generate_data_config, privatize_config
from kl_estimation import compute_kl_log
from cmdstanpy import CmdStanModel

from scipy.special import expit
from sklearn.metrics import log_loss
import arviz as az

from shapley import LogValuation
from rewardk import LogRewardProtocol

def main():
    """Run the reward-control experiment for one party under varying noise.

    Command-line arguments:
        -s/--source        path to the dill file holding the dataset/setting
                           (required).
        -c/--chains        number of chains passed to ``LogValuation``.
        -v/--varyparty     index of the party whose epsilon is varied.
        -p/--rho           rho passed to ``get_reward`` / ``LogRewardProtocol``.
        -r/--solverewards  also run the reward-solving stage.

    Stage 1 (always): for each of ``noise_repeats`` noise realizations and each
    entry of ``eps_scale_factor``, record the varied party's KL, the grand KL,
    the party's Shapley value, its reward, and two test log-losses. The result
    arrays are dumped to ``<source stem>-log-private-mono-results.pkl`` after
    every repeat, so partial results survive an interrupted run.

    Stage 2 (only with ``--solverewards``): for every valuation kept from
    stage 1, build a ``LogRewardProtocol``, generate the reward posterior for
    the varied party and record its test log-loss. Results are dumped to
    ``<source stem>-log-private-mono-reward-results.pkl`` after every repeat.

    Raises:
        SystemExit: from argparse when ``--source`` is missing or an argument
            is malformed.
    """
    parser = argparse.ArgumentParser(description="Run reward control experiment.")
    # --source is required: every later step derives from it, and without it
    # the script used to fail with an AttributeError on None.
    parser.add_argument("-s", "--source", required=True, help="path to dataset/setting pkl file")
    parser.add_argument("-c", "--chains", type=int, default=25, help="number of chains shared across indep repeats")
    parser.add_argument("-v", "--varyparty", type=int, default=1, help="vary eps of party ?")
    parser.add_argument("-p", "--rho", type=float, default=0.2, help="rho in rho-Shapley value")
    parser.add_argument("-r", "--solverewards", action='store_true', default=False,
                        help="flag to solve reward for specific party")
    args = vars(parser.parse_args())

    source = args['source']
    chains = args['chains']
    vary_party_index = args['varyparty']
    solve = args['solverewards']
    rho = args['rho']

    # splitext drops only the final extension; splitting on the first '.' would
    # truncate paths such as './data/x.pkl' to an empty stem.
    dest_stem = os.path.splitext(source)[0]
    dest = dest_stem + '-log-private-mono-results.pkl'

    # NOTE(review): dill.load executes arbitrary code embedded in the file;
    # only point --source at trusted files.
    with open(source, "rb") as f:
        # Two consecutive pickled tuples; several fields are unused here but
        # must still be unpacked to match the stored layout.
        nP, N, epsilons, DP_method, d, multiX, multiY, R, GS, l2_norm, pass_object, approx_ss, perturbed_ss, variances = dill.load(f)
        X_train, X_test, y_train, y_test = dill.load(f)

    eps_scale_factor = [1./125, 1./25, 1./5, 1, 5, 25]
    eps_vary_party = epsilons[vary_party_index] / np.array(eps_scale_factor)
    noise_repeats = 5

    model_file = 'logreg/logistic_regression_models/LDP_logistic_regression_ss_based.stan'

    def compute_log_loss_from_posterior(split):
        """Return the mean test log-loss of the averaged predictions.

        ``split`` is multiplied with ``X_test.T``, squashed with ``expit`` and
        averaged over axis 1; each resulting row is scored against ``y_test``
        and the scores are averaged.
        NOTE(review): presumably ``split`` is (n_splits, n_draws, d) — confirm
        against ``LogValuation.all_posteriors``.
        """
        test_pred = expit(split @ X_test.T)
        # Averaging the Monte Carlo predictions (rather than the losses) is
        # what leads to non-linear contours.
        test_pred = test_pred.mean(1)
        log_losses = np.array([log_loss(y_test, c) for c in test_pred])
        return log_losses.mean()

    result_shape = (len(eps_scale_factor), noise_repeats)
    v_is = np.zeros(result_shape)
    v_ns = np.zeros(result_shape)
    phi_is = np.zeros(result_shape)
    ri_is = np.zeros(result_shape)
    loss_ns = np.zeros(result_shape)
    loss_is = np.zeros(result_shape)
    r_valuations = []

    # Posterior keys are the concatenated party indices; these two do not
    # change across iterations.
    grand_key = ''.join(str(x) for x in range(nP))
    party_key = str(vary_party_index)

    for r in range(noise_repeats):
        print("-----",r,"-----")
        # consider different realization of noise to privatize Z[vary_party_index]
        # (repeat 0 keeps the perturbed statistics loaded from the source file)
        Z = perturbed_ss.copy()
        if r != 0:
            _, _, Z[vary_party_index], _ = generate_approx_ss_noisy(multiX[vary_party_index], multiY[vary_party_index],
                                                                    epsilons[vary_party_index],
                                                                    GS, pass_object, seed=3*r)

        valuation = LogValuation(model_file, pass_object, l2_norm, d, nP, N, approx_ss, Z, variances, burn_in=400, chains=chains, split_num=noise_repeats)
        r_valuation = []

        for e, scale_factor in enumerate(eps_scale_factor):
            print("---scale factor: ",scale_factor,"---")
            if scale_factor != 1:
                # Work on a copy so the base valuation stays untouched for the
                # remaining scale factors.
                v = deepcopy(valuation)
                v.update_i(vary_party_index, scale_factor)
            else:
                v = valuation
            r_valuation.append(v)

            v_is[e, r] = v.each_kls[vary_party_index]
            v_ns[e, r] = v.grand_kl
            phi_is[e, r] = v.shapleys[vary_party_index]
            loss_ns[e, r] = compute_log_loss_from_posterior(v.all_posteriors[grand_key])
            loss_is[e, r] = compute_log_loss_from_posterior(v.all_posteriors[party_key])

            ri_is[e, r] = v.get_reward(rho)[vary_party_index]

        r_valuations.append(r_valuation)

        # Checkpoint after every repeat; columns of later repeats are still 0.
        with open(dest, "wb") as f:
            dill.dump((v_is, v_ns, phi_is, ri_is, loss_is, loss_ns, noise_repeats, vary_party_index, eps_vary_party), f)

    if not solve:
        return

    dest = dest_stem + '-log-private-mono-reward-results.pkl'
    loss_rs = np.zeros(result_shape)

    for r, r_valuation in enumerate(r_valuations):
        print("-----",r,"-----")
        for e, v in enumerate(r_valuation):
            print("...", eps_scale_factor[e], "...")
            re = LogRewardProtocol(v, rho)
            try:
                r_posterior = re._generate_reward_for_i(vary_party_index)
                loss_rs[e, r] = compute_log_loss_from_posterior(r_posterior)
            except ValueError:
                # Best effort: the entry stays 0 and the run continues. The
                # values printed are the repeat index and the scale-factor
                # index (not the factor itself).
                print("Warning: Unsuccessful for run, eps scale factor", r, e)

        # Checkpoint after every repeat, as in stage 1.
        with open(dest, "wb") as f:
            dill.dump((v_is, v_ns, phi_is, ri_is, loss_is, loss_rs, loss_ns, noise_repeats, vary_party_index , eps_vary_party), f)

# Script entry point: run the experiment only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
