import ot
import torch
from scipy.stats import norm

def pdf(x,accepted_priors):
    """Evaluate an equal-weight mixture of unit-variance normals at ``x``.

    Each accepted parameter ``theta`` contributes ``norm.pdf(x, loc=theta)``
    (scale 1); the contributions are averaged, i.e. a Gaussian kernel density
    estimate with unit bandwidth centred on the accepted parameters.

    Args:
        x: point(s) at which to evaluate the density (anything ``norm.pdf`` accepts).
        accepted_priors: sized iterable of kernel centres.

    Returns:
        The averaged density value(s).

    Raises:
        ZeroDivisionError: if ``accepted_priors`` is empty.
    """
    # sum() starts from 0, exactly like the original running total.
    total = sum(norm.pdf(x, loc=theta) for theta in accepted_priors)
    return total / len(accepted_priors)

def emd(x,y):
    """Return the 1-D optimal-transport cost between the samples in ``x`` and ``y``.

    Both inputs are flattened to a single column, so for multi-column data
    every coordinate is pooled into one 1-D empirical distribution (the
    columns are NOT treated as a joint multivariate sample). Uniform weights
    are used on both sides.

    NOTE(review): ``ot.emd2_1d`` is called with its default ground metric,
    which the POT docs list as ``'sqeuclidean'`` -- i.e. this is a squared-W2
    cost rather than W1; confirm against the installed POT version before
    comparing ``epsilon`` thresholds with other code.

    Args:
        x: torch tensor of samples (any shape; flattened).
        y: torch tensor of samples (any shape; flattened).

    Returns:
        The transport cost returned by ``ot.emd2_1d``.
    """
    # detach()/cpu() make tensors that track gradients or live on an
    # accelerator convertible; for plain CPU tensors both are no-ops, so the
    # original inputs produce exactly the same arrays as before.
    x_col = x.detach().cpu().numpy().reshape(-1, 1)
    y_col = y.detach().cpu().numpy().reshape(-1, 1)
    return ot.emd2_1d(x_col, y_col)

def gaussian_generation(parameter,n_gen,d):
    """Simulate ``n_gen`` draws of dimension ``d`` from N(parameter, I).

    Args:
        parameter: mean added to the noise (scalar or tensor broadcastable to ``(n_gen, d)``).
        n_gen: number of rows to simulate.
        d: number of columns.

    Returns:
        ``parameter + torch.randn(n_gen, d)``.
    """
    noise = torch.randn(n_gen, d)
    return parameter + noise

def normal_prior_sampling(d,sigma=5,mean=0):
    """Draw one parameter vector from N(mean, sigma**2 * I) as a ``(1, d)`` tensor.

    Args:
        d: dimension of the parameter.
        sigma: scale multiplying the standard-normal draw (default 5).
        mean: offset added after scaling (default 0).

    Returns:
        ``mean + sigma * torch.randn(1, d)``.
    """
    standard = torch.randn(1, d)
    return mean + sigma * standard

def abc_posterior(dist,data,prior_sample_func,generating_data_func,n_gen,epsilon,n_iter):
    """Rejection ABC with a fixed budget of ``n_iter`` proposals.

    Every proposal draws ``theta = prior_sample_func(d)``, simulates
    ``generating_data_func(theta, n_gen, d)`` and passes both to
    ``abc_inner``, which keeps the pair when ``dist(simulated, data) < epsilon``.

    Args:
        dist: callable ``dist(generated_data, data)`` compared against ``epsilon``.
        data: observed data; ``data.shape[1]`` is used as the dimension ``d``.
        prior_sample_func: callable ``(d) -> parameter``.
        generating_data_func: callable ``(parameter, n_gen, d) -> generated data``.
        n_gen: number of simulated points per proposal.
        epsilon: acceptance tolerance.
        n_iter: total number of proposals (not accepted samples).

    Returns:
        ``(accepted_priors, accepted_generated)`` lists, possibly empty.
    """
    n_features = data.shape[1]
    kept_params, kept_samples = [], []
    for _ in range(n_iter):
        theta = prior_sample_func(n_features)
        simulated = generating_data_func(theta, n_gen, n_features)
        abc_inner(dist, data, n_gen, epsilon, kept_params, kept_samples,
                  theta, simulated)
    return kept_params, kept_samples


def abc_inner(dist,data,n_gen,epsilon,accepted_priors,accepted_generated,
                   prior_parameter_sampled,generated_data):
    """Accept or reject one ABC proposal, mutating the output lists in place.

    The proposal is kept (parameter appended to ``accepted_priors``, simulated
    data to ``accepted_generated``) when ``dist(generated_data, data) < epsilon``.
    ``n_gen`` is accepted for signature symmetry but not used here.

    Returns:
        None.
    """
    # Written as `not (... < epsilon)` rather than `>=` so a NaN distance is
    # rejected exactly as in the plain `if diff < epsilon` form.
    if not (dist(generated_data, data) < epsilon):
        return
    accepted_priors.append(prior_parameter_sampled)
    accepted_generated.append(generated_data)


def abc_posterior_while(dist,data,prior_sample_func,generating_data_func,n_gen,epsilon,n_iter,
                        *, max_trials=None):
    """Run rejection ABC until ``n_iter`` proposals have been accepted.

    Each trial draws ``prior_sample_func(d)``, simulates
    ``generating_data_func(param, n_gen, d)`` and hands both to
    ``abc_inner_while``, which appends them to the output lists and returns
    the incremented acceptance count when ``dist(generated, data) < epsilon``.

    Args:
        dist: callable ``dist(generated_data, data)`` compared against ``epsilon``.
        data: observed data; ``data.shape[1]`` is used as the dimension ``d``.
        prior_sample_func: callable ``(d) -> parameter``.
        generating_data_func: callable ``(parameter, n_gen, d) -> generated data``.
        n_gen: number of simulated points per proposal.
        epsilon: acceptance tolerance.
        n_iter: number of accepted samples to collect.
        max_trials: optional cap on the number of proposals. ``None`` (the
            default) keeps the original unbounded loop, which never terminates
            if no proposal can ever satisfy ``epsilon``.

    Returns:
        ``(accepted_priors, accepted_generated)`` lists. When ``max_trials``
        is set and exhausted, they may hold fewer than ``n_iter`` entries.
    """
    d = data.shape[1]
    accepted_priors = []
    accepted_generated = []
    count = 0
    trials = 0
    while count < n_iter:
        if max_trials is not None and trials >= max_trials:
            # Proposal budget exhausted: return the partial sample instead of
            # spinning forever on a tolerance that may be unreachable.
            break
        trials += 1
        prior_parameter_sampled = prior_sample_func(d)
        generated_data = generating_data_func(prior_parameter_sampled, n_gen, d)
        count = abc_inner_while(dist, data, n_gen, epsilon, accepted_priors, accepted_generated,
                                prior_parameter_sampled, generated_data, count)
    return accepted_priors, accepted_generated

def abc_inner_while(dist,data,n_gen,epsilon,accepted_priors,accepted_generated,
                   prior_parameter_sampled,generated_data,count):
    """Accept or reject one ABC proposal and return the updated acceptance count.

    On acceptance (``dist(generated_data, data) < epsilon``) the parameter is
    appended to ``accepted_priors``, the simulated data to
    ``accepted_generated``, and ``count + 1`` is returned; otherwise the lists
    are untouched and ``count`` is returned unchanged. ``n_gen`` is unused.
    """
    # `not (... < epsilon)` keeps NaN distances rejected, same as the plain form.
    if not (dist(generated_data, data) < epsilon):
        return count
    accepted_priors.append(prior_parameter_sampled)
    accepted_generated.append(generated_data)
    return count + 1
