import argparse
import copy
import dill
import numpy as np

import os
import sys
sys.path.append(os.getcwd() + '/logreg/')

from log_generate_data import generate_data_config, privatize_config
from kl_estimation import compute_kl_log
from cmdstanpy import CmdStanModel

from scipy.special import expit
from sklearn.metrics import log_loss
import arviz as az

def _results_path(source):
    """Return the results pickle path derived from the input path *source*.

    Only the final file extension is removed, so dots in directory names or
    a leading ``./`` are preserved.
    """
    return os.path.splitext(source)[0] + '-log-reward-control-results.pkl'


def main():
    """Run the reward-control experiment for the logistic regression model.

    Loads the dataset/setting pickle given by ``--source``, fits the Stan
    model once per scaling factor in ``ks`` and once per noise level in
    ``hs``, then computes, for ``indep_repeats`` sample splits of every fit,
    the value returned by ``compute_kl_log``, the test log loss, and
    ``compute_kl_log`` against the samples of the first (k = 1.0) fit.
    The results are printed and dumped with dill next to the source file.
    """
    parser = argparse.ArgumentParser(description="Run reward control experiment.")
    # Required: the output path is derived from it, so there is no usable default.
    parser.add_argument("-s", "--source", required=True, help="path to dataset/setting pkl file")
    parser.add_argument("-c", "--chains", type=int, default=25, help="number of chains shared across indep repeats")
    args = parser.parse_args()

    source = args.source
    dest = _results_path(source)
    chains = args.chains
    indep_repeats = 5
    bfmi_threshold = 0.3  # chains whose BFMI is not above this are discarded

    # NOTE: dill (like pickle) can execute arbitrary code while loading;
    # only point --source at files you trust.
    with open(source, "rb") as f:
        nP, N, epsilons, DP_method, d, multiX, multiY, R, GS, l2_norm, pass_object, approx_ss, perturbed_ss, variances = dill.load(f)
        X_train, X_test, y_train, y_test = dill.load(f)

    model = CmdStanModel(stan_file='logreg/logistic_regression_models/LDP_logistic_regression_ss_based.stan')

    def sample(data):
        # Single place for the sampler settings shared by every fit.
        return model.sample(data=data, iter_warmup=400, iter_sampling=2000, chains=chains, adapt_delta=0.86)

    ## Scaling factor/kappa experiments
    ks = [1.0, 0.8, 0.6, 0.4, 0.2, 0.1, 0.05, 0.025, 0.01, 0.005, 0.001]
    fits = [
        sample(generate_data_config(N, d, nP, l2_norm, pass_object, perturbed_ss, variances, k=k))
        for k in ks
    ]

    ### Noise addition/tau experiments
    # private_fits[i] is the fit for hs[i], so the saved `hs` labels line up
    # with the rows of h_kls / h_loss / h_post_kls. (The un-noised baseline
    # is already available as the k = 1.0 row of the kappa results.)
    config = generate_data_config(N, d, nP, l2_norm, pass_object, perturbed_ss, variances)
    hs = [1, 2, 4, 8, 12, 16, 32, 64, 128, 256]
    private_fits = [
        sample(privatize_config(config, h, GS, pass_object, base_seed=4848))
        for h in hs
    ]

    # The reference distribution/grand coalition posterior: samples of the
    # first kappa fit (k = 1.0), restricted to chains passing the BFMI filter.
    all_az_data = az.from_cmdstanpy(posterior=fits[0])
    all_samples = all_az_data.sel(chain=az.bfmi(all_az_data) > bfmi_threshold).stack(sample=["draw", "chain"]).posterior.theta_DP_scaled.values.T

    def compute_kl_and_log_loss(fit, split_num=indep_repeats, verbose=False):
        """Evaluate one fit on ``split_num`` splits of its posterior samples.

        Returns a tuple ``(kls, log_losses, post_kls)`` of arrays with one
        entry per split: ``compute_kl_log(samples)``, the test log loss of
        the sample-averaged predicted probabilities, and
        ``compute_kl_log(samples, against_posterior=all_samples)``.
        """
        az_data = az.from_cmdstanpy(posterior=fit)
        bfmi = az.bfmi(az_data)
        fil_az_data = az_data.sel(chain=bfmi > bfmi_threshold)
        if verbose:
            with np.printoptions(precision=4, suppress=True):
                print("---BFMI", bfmi)
                ess = az.ess(fil_az_data)['theta_DP_scaled'].values
                print("---ESS", ess.min(), ess.mean())
                print("---Diverging", fil_az_data.sample_stats.diverging.sum(axis=1).values.mean())

        # Rows are individual posterior samples after the transpose.
        flattened_theta = fil_az_data.stack(sample=["draw", "chain"]).posterior.theta_DP_scaled.values.T
        split = np.array_split(flattened_theta, split_num)
        kls = np.array([compute_kl_log(samples) for samples in split])

        # Per split: average the predicted probabilities over the samples,
        # then score the averaged prediction. Handling each split separately
        # also works when the splits have unequal sizes.
        log_losses = np.array([
            log_loss(y_test, expit(samples @ X_test.T).mean(0))
            for samples in split
        ])

        post_kls = np.array([compute_kl_log(samples, against_posterior=all_samples) for samples in split])

        return kls, log_losses, post_kls

    k_kls = np.zeros((len(ks), indep_repeats))
    k_loss = np.zeros((len(ks), indep_repeats))
    k_post_kls = np.zeros((len(ks), indep_repeats))
    for i, fit in enumerate(fits):
        k_kls[i], k_loss[i], k_post_kls[i] = compute_kl_and_log_loss(fit)

    h_kls = np.zeros((len(hs), indep_repeats))
    h_loss = np.zeros((len(hs), indep_repeats))
    h_post_kls = np.zeros((len(hs), indep_repeats))
    for i, fit in enumerate(private_fits):
        h_kls[i], h_loss[i], h_post_kls[i] = compute_kl_and_log_loss(fit)

    with np.printoptions(precision=4, suppress=True):
        print("KL, vary kappa\n", k_kls)
        print("Log loss, vary kappa\n", k_loss)

        print("\n\nKL, vary tau\n", h_kls)
        print("Log loss, vary tau\n", h_loss)

    with open(dest, "wb") as f:
        dill.dump((ks, k_kls, k_post_kls, k_loss), f)
        dill.dump((hs, h_kls, h_post_kls, h_loss), f)

# Run the experiment only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
