import numpy as np
import itertools

# from generate_data import privatize_suff_stats
from model_conjugate_update import model_conjugate_update
from Gibbs import Gibbs, Gibbs_ind, Gibbs_share
import datetime as dt
from joblib import Parallel, delayed
from copy import deepcopy

def scale_k(z, k):
    """Return a deep copy of ``z`` in which every dict value is multiplied by ``k``.

    ``z`` is an iterable of dicts (e.g. a list with one dict per party). The
    input is left untouched; the returned container is a deep copy of it.
    """
    scaled = deepcopy(z)
    for party_stats in scaled:
        # Build all scaled values first, then write them back in one go.
        party_stats.update({name: k * value for name, value in party_stats.items()})
    return scaled

def run_methods(data_prior_params, model_prior_params, multiS, multiZ, sigma_DP, multiN, nSample, DP_method, methods, savekl=None, k=1, num_burnin=10000, parallel=None, chains=5, indep_repeats=None):
    """Run each requested inference method and collect its posterior samples.

    Parameters
    ----------
    data_prior_params, model_prior_params
        Prior parameters, forwarded unchanged to the samplers.
    multiS
        Non-private sufficient statistics: a list/tuple with one dict per party,
        or a single dict. Only required for ``'non-private'``; may be ``None``
        otherwise.
    multiZ
        Sufficient statistics used by the other methods (presumably the
        privatized ones): a list/tuple of per-party dicts, or a single dict.
        ``multiZ[0]['Xy'].shape[0]`` defines ``d + 1``.
    sigma_DP
        Noise scale(s) forwarded to the multi-party samplers. For
        ``'gibbs-Isserlis'`` they are combined as ``sqrt(sum(sigma_DP**2))``.
    multiN
        Sample size(s); their sum is used as ``N``.
    nSample
        Number of posterior samples requested from each method.
    DP_method
        Unused; kept for interface compatibility.
    methods
        Iterable of method names. Supported: ``'non-private'``, ``'naive'``,
        ``'gibbs-Isserlis'``, ``'gibbs-Isserlis-ind'``,
        ``'gibbs-Isserlis-ind-parallel'``, ``'gibbs-Isserlis-share'``,
        ``'gibbs-Isserlis-share-parallel'``.
    savekl
        Forwarded to ``run_non_private``.
    k
        If not 1, ``multiS``, ``multiZ``, ``sigma_DP`` and ``multiN`` are all
        multiplied by ``k`` (elementwise for lists/tuples) before use.
    num_burnin
        Burn-in length forwarded to the Gibbs samplers.
    parallel
        ``joblib.Parallel`` instance used by the ``*-parallel`` methods. When
        ``None``, a ``Parallel(n_jobs=-1)`` is created on first use.
    chains
        Number of chains run in parallel, which is also the thinning factor.
    indep_repeats
        If given, ``chains * indep_repeats`` jobs are run. The thinned samples
        are then split into ``indep_repeats`` groups.

    Returns
    -------
    dict
        Maps each method name to that method's output. Parallel methods
        instead store ``(theta, sigma_squared)`` pairs under ``'post-{i}'``.

    Raises
    ------
    ValueError
        If a method name is not supported, or ``'non-private'`` is requested
        without ``multiS``.
    """
    # Validate everything before any (potentially very long) sampling starts.
    methods = list(methods)
    supported = ('non-private', 'naive', 'gibbs-Isserlis',
                 'gibbs-Isserlis-ind', 'gibbs-Isserlis-ind-parallel',
                 'gibbs-Isserlis-share', 'gibbs-Isserlis-share-parallel')
    unknown = [m for m in methods if m not in supported]
    if unknown:
        raise ValueError('Unknown method(s) {}; expected one of {}'.format(unknown, supported))
    if 'non-private' in methods and multiS is None:
        raise ValueError("multiS is required for the 'non-private' method")

    # Normalize to per-party sequences BEFORE scaling: scale_k iterates over
    # parties, so a bare dict would be iterated key by key.
    multiZ = _as_party_list(multiZ)
    if multiS is not None:
        multiS = _as_party_list(multiS)

    if k != 1:
        if multiS is not None:
            multiS = scale_k(multiS, k)
        multiZ = scale_k(multiZ, k)
        sigma_DP = _scale_value(sigma_DP, k)
        multiN = _scale_value(multiN, k)

    d = multiZ[0]['Xy'].shape[0] - 1
    N = np.sum(multiN)

    # Aggregate the per-party statistics by key-wise summation.
    Z = _sum_party_stats(multiZ, {'XX': np.zeros([d+1, d+1]), 'Xy': np.zeros([d+1, 1]), 'yy': 0.0, 'X': np.zeros([d+1, 1])})

    posteriors = {}

    for method in methods:
        start_time = dt.datetime.now()

        if method == 'non-private':
            S = _sum_party_stats(multiS, {'XX': np.zeros([d+1, d+1]), 'Xy': np.zeros([d+1, 1]), 'yy': 0.0})
            posteriors[method] = run_non_private(model_prior_params, S, N, nSample, savekl=savekl)

        elif method == 'naive':
            posteriors[method] = run_naive(model_prior_params, Z, N, nSample)

        elif method == 'gibbs-Isserlis':
            # Single combined scale; if sigma_DP holds per-party standard
            # deviations this is the root of the summed variances.
            scale_DP_noise = np.sqrt(np.sum(np.square(sigma_DP)))
            posteriors[method] = run_gibbs_Isserlis(Z, scale_DP_noise, data_prior_params, model_prior_params, N, nSample, num_burnin=num_burnin)

        else:
            # Remaining (validated) names are the multi-party samplers,
            # optionally run as several parallel chains.
            if method in ('gibbs-Isserlis-ind', 'gibbs-Isserlis-ind-parallel'):
                gibbs_func = run_gibbs_Isserlis_ind
            else:
                gibbs_func = run_gibbs_Isserlis_share

            args = (multiZ, sigma_DP, data_prior_params, model_prior_params, multiN, nSample)
            kwargs = {'num_burnin': num_burnin}

            if method.endswith('-parallel'):
                if parallel is None:
                    parallel = Parallel(n_jobs=-1)
                posteriors.update(_run_parallel_chains(gibbs_func, args, kwargs, parallel, chains, indep_repeats))
            else:
                posteriors[method] = gibbs_func(*args, **kwargs)

        # total_seconds() keeps sub-second precision and does not drop days.
        elapsed_time = (dt.datetime.now() - start_time).total_seconds()
        print('Time {}: '.format(method), elapsed_time, 's')

    return posteriors


def _as_party_list(stats):
    """Return ``stats`` unchanged if it is a list/tuple, else wrap it in a list."""
    if isinstance(stats, (list, tuple)):
        return stats
    return [stats]


def _scale_value(value, k):
    """Multiply ``value`` by ``k``; elementwise for lists/tuples (no repetition)."""
    if isinstance(value, (list, tuple)):
        scaled = [k * v for v in value]
        return tuple(scaled) if isinstance(value, tuple) else scaled
    return k * value


def _sum_party_stats(parties, template):
    """Sum per-party statistic dicts key-wise, starting from ``template``."""
    total = dict(template)
    for party in parties:
        total = {key: total[key] + val for key, val in party.items()}
    return total


def _run_parallel_chains(gibbs_func, args, kwargs, parallel, chains, indep_repeats):
    """Run ``gibbs_func`` as parallel jobs and return the ``'post-{i}'`` posteriors."""
    n_jobs = chains if indep_repeats is None else chains * indep_repeats
    results = parallel(delayed(gibbs_func)(*args, **kwargs) for _ in range(n_jobs))

    thetas = np.vstack([res[0] for res in results])
    sigmas = np.hstack([res[1] for res in results])

    out = {}
    if indep_repeats is None:
        # Interleaved slices: each key gets every chains-th row of the stack.
        for ind in range(chains):
            out['post-{}'.format(ind)] = thetas[ind::chains], sigmas[ind::chains]
    else:
        # Thin by `chains`, then split into indep_repeats contiguous groups.
        theta_split = np.array_split(thetas[::chains], indep_repeats)
        sigma_split = np.array_split(sigmas[::chains], indep_repeats)
        for ind in range(indep_repeats):
            out['post-{}'.format(ind)] = theta_split[ind], sigma_split[ind]
    return out


def run_non_private(model_prior_params, S, N, nSample, savekl=None, k=1):
    """Draw posterior samples from the non-private statistics ``S``.

    Calls ``model_conjugate_update`` with ``project=False`` and
    ``return_kl=True``, and returns its three outputs as
    ``(theta, sigma_squared, kl)``. ``savekl`` and ``k`` are accepted for
    interface compatibility but are not used here.
    """
    draws = model_conjugate_update(model_prior_params, S, N,
                                   project=False, size=nSample, return_kl=True)
    # Unpacking enforces exactly three outputs and returns a plain tuple.
    theta_draws, sigma2_draws, kl_value = draws
    return theta_draws, sigma2_draws, kl_value

def run_naive(model_prior_params, Z, N, nSample):
    """Apply ``model_conjugate_update`` directly to the aggregate statistics ``Z``.

    Uses ``project=True``. Returns ``(theta, sigma_squared)``. No noise model
    is applied here; presumably ``Z`` holds the privatized statistics
    (confirm at the call site).
    """
    draws = model_conjugate_update(model_prior_params, Z, N,
                                   project=True, size=nSample)
    theta_draws, sigma2_draws = draws
    return theta_draws, sigma2_draws

def run_gibbs_Isserlis(Z, sigma_DP_noise, data_prior_params, model_prior_params, N, nSample=2000, k=1, num_burnin=2000):
    """Run the single-aggregate ``Gibbs`` sampler with the ``'gibbs-Isserlis'`` label.

    Returns ``(theta, sigma_squared)`` as produced by ``Gibbs``. ``k`` is
    accepted for interface compatibility but is not used.
    """
    # Positional order expected by Gibbs: priors, N, noise, stats, burn-in,
    # sample count, method label.
    gibbs_args = (data_prior_params, model_prior_params, N, sigma_DP_noise,
                  Z, num_burnin, nSample, 'gibbs-Isserlis')
    theta_draws, sigma2_draws = Gibbs(*gibbs_args)
    return theta_draws, sigma2_draws

def run_gibbs_Isserlis_share(multiZ, sigma_DP_noise, data_prior_params, model_prior_params, multiN, nSample=2000, k=1, num_burnin=2000):
    """Run the multi-party ``Gibbs_share`` sampler on per-party statistics.

    Returns ``(theta, sigma_squared)`` as produced by ``Gibbs_share``. ``k`` is
    accepted for interface compatibility but is not used.
    """
    # NOTE(review): the label passed here is 'gibbs-Isserlis-ind', identical
    # to run_gibbs_Isserlis_ind. This may be a copy-paste slip; confirm
    # against how Gibbs_share uses it before changing.
    gibbs_args = (data_prior_params, model_prior_params, multiN, sigma_DP_noise,
                  multiZ, num_burnin, nSample, 'gibbs-Isserlis-ind')
    theta_draws, sigma2_draws = Gibbs_share(*gibbs_args)
    return theta_draws, sigma2_draws

def run_gibbs_Isserlis_ind(multiZ, sigma_DP_noise, data_prior_params, model_prior_params, multiN, nSample=2000, k=1, num_burnin=2000):
    """Run the multi-party ``Gibbs_ind`` sampler on per-party statistics.

    Returns ``(theta, sigma_squared)`` as produced by ``Gibbs_ind``. ``k`` is
    accepted for interface compatibility but is not used.
    """
    gibbs_args = (data_prior_params, model_prior_params, multiN, sigma_DP_noise,
                  multiZ, num_burnin, nSample, 'gibbs-Isserlis-ind')
    theta_draws, sigma2_draws = Gibbs_ind(*gibbs_args)
    return theta_draws, sigma2_draws
