import numpy as np


def gmm_likelihood_diagonal_Sigma(samples, mu, sigma):
    """Evaluate an equally weighted isotropic Gaussian mixture at each sample.

    Each row of ``mu`` is the mean of one mixture component; every component
    uses the same standard deviation ``sigma`` in every dimension. The
    constant factor ``(2 * pi) ** (-N / 2)`` is not applied, so the returned
    values are likelihoods only up to that constant.

    Args:
        samples: Array of shape ``(S, N)``, one sample per row.
        mu: Array of shape ``(T, N)``, one component mean per row.
        sigma: Standard deviation of the components. ``sigma ** N`` is used
            as the square root of the covariance determinant, so a scalar is
            expected here.

    Returns:
        Array of shape ``(S,)`` with the mixture value for each sample, i.e.
        the mean over the ``T`` components of
        ``exp(-0.5 * ||x - mu_t||^2 / sigma^2) / sigma^N``.
    """
    T, N = mu.shape

    sqrt_det_Sigma = sigma ** N
    inv_Sigma = 1 / (sigma ** 2)

    # Broadcast (S, 1, N) against (1, T, N) to get every sample/mean pair at
    # once instead of looping over them in Python.
    diff = samples[:, np.newaxis, :] - mu[np.newaxis, :, :]
    # Accumulate in float64, the dtype of the buffer the loop version filled.
    E = np.sum(diff ** 2 * inv_Sigma, axis=-1, dtype=np.float64)
    L = np.exp(-0.5 * E) / sqrt_det_Sigma
    return np.sum(L, axis=1) / T


def eval_likelihood_gmm_for_diagonal_cov_np(z, mu, std):
    """Evaluate an equally weighted diagonal-covariance Gaussian mixture.

    Every component ``t`` has mean ``mu[t]`` and per-dimension standard
    deviation ``std[t]``. The constant factor ``(2 * pi) ** (-D / 2)`` is not
    applied, so the values are likelihoods only up to that constant.

    Args:
        z: Points to evaluate. Must broadcast against an array of shape
            ``(1, T, D)``; the caller in this file passes ``(A, 1, D)``.
        mu: Array whose first axis indexes the ``T`` components; it is
            reshaped to ``(1, T, D)``.
        std: Array of shape ``(T, D)`` with the standard deviations of each
            component along each dimension.

    Returns:
        Array with one mixture value per evaluated point: the mean over the
        ``T`` components of
        ``exp(-0.5 * sum_d ((z_d - mu_td) / std_td) ** 2) / prod_d std_td``.
    """
    T = mu.shape[0]
    mu = mu.reshape((1, T, -1))

    vec = z - mu
    # Use each component's own precision. Only the first component's row used
    # to be applied to all components, which disagreed with the per-component
    # normaliser below whenever the components had different stds.
    precision = 1 / (std ** 2)
    # For a diagonal covariance the quadratic form is a weighted sum of
    # squares, so no dense (D, D) precision matrices are needed.
    exponent = np.sum(vec ** 2 * precision, axis=-1)
    sqrt_det_of_cov = np.prod(std, axis=1)
    likelihood = np.exp(-0.5 * exponent) / sqrt_det_of_cov
    return likelihood.sum(axis=1) / T


def clean_from_outliers_np(prior, posterior):
    """Drop the entries whose prior value is exactly zero.

    Args:
        prior: NumPy array of prior values.
        posterior: NumPy array of posterior values, aligned element-wise
            with ``prior``.

    Returns:
        A tuple ``(prior, posterior, outlier_ratio)``. If ``prior`` contains
        zeros, both arrays are filtered with the same boolean mask (which
        yields 1-D arrays); otherwise the inputs are returned unchanged.
        ``outlier_ratio`` is the fraction of entries of the original
        ``prior`` that were zero.
    """
    nonzeros = (prior != 0)
    # Reduce with the ndarray method: the builtin any() walks the array in
    # Python and raises for arrays with more than one dimension.
    if not nonzeros.all():
        prior = prior[nonzeros]
        posterior = posterior[nonzeros]
    outlier_ratio = (~nonzeros).mean()
    return prior, posterior, outlier_ratio


def calc_kl_mc_np(mu_inf, cov_inf, mu_gen, cov_gen, mc_n=1000):
    """Monte Carlo estimate of the KL divergence between two Gaussian mixtures.

    Samples are drawn from the "inf" mixture (a uniformly chosen component,
    then a Gaussian draw around its mean) and the log-ratio of the "inf"
    mixture value to the "gen" mixture value is averaged over those samples.
    Samples whose "gen" value is exactly zero are dropped before averaging.

    Args:
        mu_inf: Component means of the mixture that is sampled from; the
            first axis indexes components.
        cov_inf: Diagonal covariance entries for ``mu_inf`` (the square root
            is taken element-wise to get standard deviations).
        mu_gen: Component means of the reference mixture.
        cov_gen: Diagonal covariance entries for ``mu_gen``.
        mc_n: Number of Monte Carlo samples to draw.

    Returns:
        A tuple ``(kl_mc, outlier_ratio)``: the averaged log-ratio and the
        fraction of samples that were dropped.
    """
    # randint excludes its upper bound, so this covers every component index;
    # it replaces the deprecated inclusive-bound random_integers.
    t = np.random.randint(0, mu_inf.shape[0], size=(mc_n,))

    std_inf = np.sqrt(cov_inf)
    std_gen = np.sqrt(cov_gen)

    centers = mu_inf[t]
    noise = np.random.normal(size=centers.shape)
    # The singleton middle axis lets each sample broadcast over components.
    z_sample = (centers + std_inf[t] * noise).reshape((mc_n, 1, -1))

    prior = eval_likelihood_gmm_for_diagonal_cov_np(z_sample, mu_gen, std_gen)
    posterior = eval_likelihood_gmm_for_diagonal_cov_np(z_sample, mu_inf, std_inf)
    prior, posterior, outlier_ratio = clean_from_outliers_np(prior, posterior)
    kl_mc = np.mean(np.log(posterior) - np.log(prior), axis=0)
    return kl_mc, outlier_ratio


def empirical_KL(X_pred, X_true, cov_pred=1, cov_true=1):
    """Estimate the KL divergence between two point sets via ``calc_kl_mc_np``.

    The rows of ``X_pred`` and ``X_true`` are handed to ``calc_kl_mc_np`` as
    the component means of the sampled-from and the reference mixture
    respectively, each with a constant diagonal covariance.

    Args:
        X_pred: Array used as the means of the mixture that is sampled from.
        X_true: Array used as the means of the reference mixture.
        cov_pred: Value that fills the covariance array built for ``X_pred``.
        cov_true: Value that fills the covariance array built for ``X_true``.

    Returns:
        The first element returned by ``calc_kl_mc_np`` (the Monte Carlo KL
        estimate); the outlier ratio is discarded.
    """
    # Covariance arrays have the same shape as the corresponding means.
    pred_cov = cov_pred * np.ones_like(X_pred)
    true_cov = cov_true * np.ones_like(X_true)

    estimate, _unused_outlier_ratio = calc_kl_mc_np(X_pred, pred_cov, X_true, true_cov)
    return estimate
