from ABC import pdf
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm
from joblib import load

def get_priors(dist_name, epsilon=0.25,
               n_gen=100,
               n_iter=10000,
               n_runs=10,
               d=1,
               save_dir='save'):
    """Load saved ABC accepted-prior samples for one distance and run configuration.

    The file read is ``<save_dir>/accepted_priors_<dist_name>_eps_<epsilon>_ngen_<n_gen>
    _niter_<n_iter>_nruns_<n_runs>`` (no extension). It is loaded with ``joblib.load``.

    Parameters
    ----------
    dist_name : str
        Distance identifier embedded in the file name (e.g. ``'mmd'``, ``'dkt'``).
    epsilon, n_gen, n_iter, n_runs
        Settings of the ABC run. They are only used to rebuild the file name and are
        formatted with their default string conversion (e.g. ``0.25`` -> ``'0.25'``).
    d : int
        Number of columns of the returned array.
    save_dir : str
        Directory holding the saved file. Defaults to ``'save'``, the historical
        hard-coded location.

    Returns
    -------
    numpy.ndarray
        All samples from the loaded nested iterable, flattened one level and
        reshaped to ``(-1, d)``.
    """
    run_tag = f"_eps_{epsilon}_ngen_{n_gen}_niter_{n_iter}_nruns_{n_runs}"
    name = f"accepted_priors_{dist_name}{run_tag}"
    # NOTE(review): the stored object is presumably a list with one entry per
    # run (n_runs), each an iterable of accepted samples -- confirm with the writer.
    accepted_priors_list = load(f"{save_dir}/{name}")
    # Flatten one nesting level so that the samples of all runs are pooled together.
    accepted_priors = np.array(
        [sample for run_samples in accepted_priors_list for sample in run_samples]
    )
    return accepted_priors.reshape(-1, d)

if __name__ == '__main__':

    # Larger fonts for the saved figure; fonttype 42 / 'none' keep text as
    # editable text (TrueType) in the PDF/SVG output rather than outlines.
    params = {"legend.fontsize": 18,
              "axes.titlesize": 16,
              "axes.labelsize": 16,
              "xtick.labelsize": 13,
              "ytick.labelsize": 13,
              "pdf.fonttype": 42,
              "svg.fonttype": 'none'}
    plt.rcParams.update(params)

    # Run settings identifying which saved result files are loaded.
    n_gen = 100
    n_iter = 10000
    n_runs = 10
    d = 1
    # Forward them explicitly; previously they were unused, so editing them had
    # no effect and get_priors silently fell back to its own defaults.
    run_settings = dict(n_gen=n_gen, n_iter=n_iter, n_runs=n_runs, d=d)

    # Only the posteriors plotted below are loaded. The plain-MMD results were
    # previously loaded but never used, and they required their files to exist.
    dkt_accepted_priors25 = get_priors('dkt', epsilon=0.25, **run_settings)
    dkt_accepted_priors5 = get_priors('dkt', epsilon=0.5, **run_settings)
    nmmd_accepted_priors25 = get_priors('normalised_gaussian_mmd', epsilon=0.25, **run_settings)
    emmd_accepted_priors5 = get_priors('energy_mmd', epsilon=0.5, **run_settings)

    plt.figure(figsize=(12, 8))
    x = np.linspace(-7, 7, 100)
    # NOTE(review): `pdf(x, samples)` comes from the ABC module; presumably it
    # returns a density estimate of `samples` evaluated at `x` -- confirm there.
    plt.plot(x, pdf(x, dkt_accepted_priors5),
             'b-', alpha=0.6, label=r'$d_{KT} \quad \varepsilon=0.5$', lw=2)
    plt.plot(x, pdf(x, dkt_accepted_priors25),
             'c-', alpha=0.6, label=r'$d_{KT} \quad \varepsilon=0.25$', lw=2)
    # Reference curve: normal pdf with loc=1 and scipy's default scale=1.
    plt.plot(x, norm.pdf(x, loc=1), 'k', alpha=0.6, label='target pdf', lw=2, ls="--")
    plt.plot(x, pdf(x, nmmd_accepted_priors25),
             'r-', alpha=0.6, label=r'Normalised MMD  $\varepsilon=0.25$', lw=2)
    plt.plot(x, pdf(x, emmd_accepted_priors5),
             'k-', alpha=0.6, label=r'Energy MMD  $\varepsilon=0.5$', lw=2)
    plt.legend()
    plt.grid()
    plt.savefig("save/ABC_posteriors_others.pdf")
    # plt.show()