import os
import time
from functools import partial

import numpy as np
import ot
import torch
from geomloss import SamplesLoss
from joblib import dump

from ABC import abc_posterior, normal_prior_sampling, gaussian_generation
from distances import gauss_dkt, normalised_gaussian_MMD, energy_mmd, gauss_KBW, gauss_KFDA, gauss_normalised_KFDA, gauss_kernel_wasserstein

# n: total sample size (not referenced elsewhere in this script).
# d: data dimension; also reused below as the kernel scale parameter.
n, d = 100, 1

def generate(seed, n=90, d=1, n_conta=10, mean_cont=20, std_conta=1):
    """Draw a reproducible, contaminated Gaussian sample.

    ``n`` inliers are drawn from N(1, I_d), and ``n_conta`` outliers are drawn
    from N(mean_cont * 1_d, std_conta**2 * I_d). The outlier rows are appended
    after the inlier rows.

    Args:
        seed: Value passed to ``torch.manual_seed``. This reseeds torch's
            global RNG, so the output is deterministic per seed.
        n: Number of inlier points.
        d: Dimension of each point.
        n_conta: Number of contaminating (outlier) points. This parameter was
            previously overwritten with 10 inside the body; it is now honoured.
        mean_cont: Value of every coordinate of the outlier mean.
        std_conta: Standard deviation of the outlier noise.

    Returns:
        A tensor of shape ``(n + n_conta, d)``.
    """
    torch.manual_seed(seed)
    # Inliers are drawn first and outliers second. This RNG draw order matches
    # the original, so default calls produce identical data.
    inliers = 1 + torch.randn(n, d)
    outlier_mean = torch.ones(d) * mean_cont
    outliers = outlier_mean + std_conta * torch.randn(n_conta, d)
    return torch.cat((inliers, outliers), 0)


# KFDA-based distances: ``gamma`` is bound as the first positional argument.
gamma = 0.1
kfda = partial(gauss_KFDA, gamma)
norm_kfda = partial(gauss_normalised_KFDA, gamma)

# Gaussian-kernel distances: the kernel scale (``sigma`` / geomloss ``blur``)
# is set to the data dimension ``d``.
dkt = partial(gauss_dkt, sigma=d)
normalised_gaussian_mmd = partial(normalised_gaussian_MMD, sigma=d)
mmd = SamplesLoss("gaussian", blur=d)


# Prior sampler and data simulator handed to abc_posterior on every run.
prior_sample_func, generating_data_func = normal_prior_sampling, gaussian_generation
# ABC budget. n_gen and n_iter are forwarded to abc_posterior (presumably the
# simulated sample size and the number of prior proposals; confirm in ABC).
# n_runs is the number of independent observed data sets (seeds 0..n_runs-1).
n_gen, n_iter, n_runs = 100, 10000, 10

# Acceptance thresholds swept for each distance.
epsilon_list = [0.05, 0.25, 0.5, 1.]

# (name used in saved-result filenames, distance callable, epsilon grid).
# Each distance carries its own grid so thresholds can be tuned per distance.
# All of them currently use the shared ``epsilon_list``. Listing triples
# explicitly prevents zip() from silently truncating the sweep to the shortest
# list.
distance_configs = [
    ('dkt', dkt, epsilon_list),
    ('mmd', mmd, epsilon_list),
    ('normalised_gaussian_mmd', normalised_gaussian_mmd, epsilon_list),
    ('energy_mmd', energy_mmd, epsilon_list),
    ('wasserstein', ot.emd2_1d, epsilon_list),
    ('gauss_kernel_wasserstein', gauss_kernel_wasserstein, epsilon_list),
    ('gauss_KBW', gauss_KBW, epsilon_list),
    ('gauss_KFDA', kfda, epsilon_list),
    ('gauss_normalised_KFDA', norm_kfda, epsilon_list),
]

# Create the output directory up front so a long sweep cannot crash at its
# first dump() because "save/" is missing.
save_dir = "save"
os.makedirs(save_dir, exist_ok=True)

for dist_name, dist, dist_epsilons in distance_configs:
    print(f"Doing distance {dist_name}")
    for epsilon in dist_epsilons:
        print(f"Doing epsilon {epsilon}")
        MSE_list = []
        accepted_priors_list = []
        time_list = []
        for r in range(n_runs):
            # The run index doubles as the seed, so each observed data set is reproducible.
            data = generate(r)
            t0 = time.perf_counter()
            accepted_priors, _ = abc_posterior(
                dist, data, prior_sample_func, generating_data_func,
                n_gen, epsilon, n_iter,
            )
            total_time = time.perf_counter() - t0
            time_list.append(total_time)

            # NOTE(review): assumes each accepted prior is a torch tensor
            # (``.numpy()``); confirm against abc_posterior.
            accepted_priors = np.array([x.numpy() for x in accepted_priors])
            accepted_priors = accepted_priors.reshape(-1, d)
            accepted_priors_list.append(accepted_priors)

            # Squared error against the true location 1 of the inliers in generate().
            if accepted_priors.size == 0:
                # No acceptances: record NaN explicitly rather than averaging
                # an empty array, which only warns.
                mse = np.nan
            else:
                mse = np.square(accepted_priors - np.ones(d)).mean()
            print(r, "MSE", mse)
            MSE_list.append(mse)

        length_list = [len(priors) for priors in accepted_priors_list]
        print("# of acceptances", length_list)

        print("average # of acceptances", np.array(length_list).mean())
        print("average MSE", np.array(MSE_list).mean())

        print("time_list", time_list)
        print("average time", np.array(time_list).mean())

        main_name = f"{dist_name}_eps_{epsilon}_ngen_{n_gen}_niter_{n_iter}_nruns_{n_runs}"

        for prefix, payload in (
            ('MSE_', MSE_list),
            ('accepted_priors_', accepted_priors_list),
            ('time_', time_list),
        ):
            dump(payload, os.path.join(save_dir, prefix + main_name))



