from joblib import load
import numpy as np

def get_mse(dist_name, epsilon=0.25, n_gen=100, n_iter=10000, n_runs=10):
    """Load the saved per-run MSE values for one distance and run setting.

    The object is read with ``load`` from the relative path
    ``save/MSE_<dist_name>_eps_<epsilon>_ngen_<n_gen>_niter_<n_iter>_nruns_<n_runs>``,
    where every setting is rendered with ``str``.

    Parameters
    ----------
    dist_name : str
        Distance identifier embedded in the file name.
    epsilon, n_gen, n_iter, n_runs
        Run settings that select which saved file is read.

    Returns
    -------
    numpy.ndarray
        The loaded object converted with ``np.array``.
    """
    settings = (("eps", epsilon), ("ngen", n_gen), ("niter", n_iter), ("nruns", n_runs))
    # `!s` applies str(), matching the original explicit str() conversions.
    run_tag = "".join(f"_{key}_{value!s}" for key, value in settings)
    # join (like `+`) rejects a non-str dist_name with TypeError.
    file_name = "".join(("MSE_", dist_name, run_tag))
    return np.array(load("save/" + file_name))

def get_acceptances(dist_name, epsilon=0.25, n_gen=100, n_iter=10000, n_runs=10, d=1):
    """Return the number of accepted priors in each saved run.

    Reads the relative path
    ``save/accepted_priors_<dist_name>_eps_<epsilon>_ngen_<n_gen>_niter_<n_iter>_nruns_<n_runs>``
    with ``load`` and takes ``len`` of every element of the loaded sequence.

    Parameters
    ----------
    dist_name : str
        Distance identifier embedded in the file name.
    epsilon, n_gen, n_iter, n_runs
        Run settings that select which saved file is read.
    d
        Accepted for call compatibility; it is not used by this function.

    Returns
    -------
    numpy.ndarray
        One length per element of the loaded sequence.
    """
    run_tag = "_eps_{}_ngen_{}_niter_{}_nruns_{}".format(
        str(epsilon), str(n_gen), str(n_iter), str(n_runs)
    )
    # join (like `+`) rejects a non-str dist_name with TypeError.
    file_name = "".join(("accepted_priors_", dist_name, run_tag))
    per_run_priors = load("save/" + file_name)
    return np.array(list(map(len, per_run_priors)))

def print_summary(dist_name, epsilon, *, n_gen=100, n_iter=10000, n_runs=10):
    """Print MSE and acceptance-count statistics for one saved experiment.

    Loads the per-run MSE values (``get_mse``) and the per-run number of
    accepted priors (``get_acceptances``) for the given distance and run
    settings. It then prints four lines:

    - the distance name;
    - epsilon;
    - the MSE mean/std, rounded to 2 decimals;
    - the acceptance-count mean/std, rounded to 0 decimals.

    Standard deviations use numpy's default ``ddof=0``.

    Parameters
    ----------
    dist_name : str
        Distance identifier used in the saved file names.
    epsilon : float
        Epsilon value identifying the saved run (part of the file name).
    n_gen, n_iter, n_runs : int, keyword-only
        Remaining run settings forwarded to both loaders. The defaults equal
        the loaders' own defaults, so two-argument calls behave as before.
    """
    # Pass every setting by keyword to both loaders so the two file names
    # always describe the same run configuration.
    run_settings = {"epsilon": epsilon, "n_gen": n_gen, "n_iter": n_iter, "n_runs": n_runs}
    MSE_list = get_mse(dist_name, **run_settings)
    lengths = get_acceptances(dist_name, **run_settings)

    MSE = MSE_list.mean()
    MSEstd = MSE_list.std()
    accept_mean = lengths.mean()
    accept_std = lengths.std()

    print("distance", dist_name)
    print("epsilon", epsilon)
    print("MSE", round(MSE, 2), "std", round(MSEstd, 2))
    print("acceptances number", round(accept_mean, 0), "std", round(accept_std, 0))

# (distance name, epsilon values) pairs, in the order the summaries are printed.
# Note that 'dkt' is swept without the 0.05 value.
_SUMMARY_SWEEP = (
    ('wasserstein', (0.05, 0.25, 0.5, 1.0)),
    ('mmd', (0.05, 0.25, 0.5, 1.0)),
    ('dkt', (0.25, 0.5, 1.0)),
    ('normalised_gaussian_mmd', (0.05, 0.25, 0.5, 1.0)),
    ('energy_mmd', (0.05, 0.25, 0.5, 1.0)),
    ('gauss_KFDA', (0.05, 0.25, 0.5, 1.0)),
    ('gauss_normalised_KFDA', (0.05, 0.25, 0.5, 1.0)),
    ('gauss_kernel_wasserstein', (0.05, 0.25, 0.5, 1.0)),
    ('gauss_KBW', (0.05, 0.25, 0.5, 1.0)),
)

for _distance, _epsilon_grid in _SUMMARY_SWEEP:
    for epsilon in _epsilon_grid:
        print_summary(_distance, epsilon)







