import numpy as np
from experiments.target_lnpdfs.LogisticRegression import LNPDF
import tensorflow_probability as tfp
import tensorflow as tf
tfd = tfp.distributions
class GMM_LNPDF(LNPDF):
    """Log-density target defined by a Gaussian mixture model (GMM).

    The mixture is a TensorFlow Probability ``MixtureSameFamily`` over
    ``MultivariateNormalTriL`` components. All parameters handed to TFP are
    float32.

    Args:
        target_weights: Mixture weights, one per component. They are turned
            into categorical logits with ``np.log``. Stored unchanged on
            ``self.target_weights``.
        target_means: Component means, indexed per component as
            ``target_means[k]`` (presumably shape (K, D) -- TODO confirm with
            callers). Stored as a float32 ndarray.
        target_covars: Component covariance matrices. A Cholesky factor is
            computed for each, so each must be symmetric positive definite.
            Stored as a float32 ndarray.
    """

    def __init__(self, target_weights, target_means, target_covars):
        # np.asarray lets callers pass lists as well as ndarrays; .astype
        # always returns a fresh float32 copy, so the caller's arrays are not
        # aliased.
        self.target_weights = target_weights
        self.target_means = np.asarray(target_means).astype(np.float32)
        self.target_covars = np.asarray(target_covars).astype(np.float32)
        # Factorize in the input's own precision and only then reduce to
        # float32, rather than factorizing the already-rounded float32 copy.
        scale_tril = np.linalg.cholesky(np.asarray(target_covars)).astype(np.float32)
        logits = np.log(np.asarray(target_weights)).astype(np.float32)
        self.gmm = tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(logits=logits),
            components_distribution=tfd.MultivariateNormalTriL(
                loc=self.target_means, scale_tril=scale_tril))

    def log_density(self, x):
        """Return the GMM log-density evaluated at ``x``.

        ``x`` is cast to ``tf.float32`` first so that it matches the dtype of
        the mixture parameters.
        """
        x = tf.cast(x, dtype=tf.float32)
        return self.gmm.log_prob(x)

    def get_num_dimensions(self):
        """Return the dimensionality of the mixture components (length of one mean)."""
        return len(self.target_means[0])

    def can_sample(self):
        """This target supports exact sampling, so always return True."""
        return True

    def sample(self, n):
        """Draw ``n`` samples from the mixture via TFP's ``sample``."""
        return self.gmm.sample(n)

def make_target(num_dimensions, num_components=10, rng=None):
    """Build a random GMM target with uniform mixture weights.

    For every component, in order:

    * the mean is drawn uniformly from ``[-50, 50)`` in each dimension;
    * the covariance is ``A^T A + I``. ``A`` is a D x D matrix whose entries
      are ``0.1 * normal(0, num_dimensions)``, i.e. the scale argument passed
      to ``normal`` is ``num_dimensions``. The added identity guarantees the
      covariance is positive definite.

    Args:
        num_dimensions: Dimensionality D of the target.
        num_components: Number of mixture components. The default of 10 is
            the previously hard-coded value.
        rng: Optional random source exposing ``random(size)`` and
            ``normal(loc, scale, size)``, e.g. ``np.random.default_rng(seed)``,
            for reproducible targets. When ``None``, the global ``np.random``
            state is used with exactly the same draw sequence as before.

    Returns:
        GMM_LNPDF: the constructed mixture target.
    """
    if rng is None:
        rng = np.random
    weights = np.ones(num_components) / num_components
    means = np.empty((num_components, num_dimensions))
    covs = np.empty((num_components, num_dimensions, num_dimensions))
    # Keep the mean/covariance draws interleaved per component so that seeded
    # runs reproduce the same targets as the original implementation.
    for i in range(num_components):
        means[i] = 100 * (rng.random(num_dimensions) - 0.5)
        factor = 0.1 * rng.normal(0, num_dimensions, num_dimensions * num_dimensions).reshape(
            (num_dimensions, num_dimensions))
        covs[i] = factor.T.dot(factor) + np.eye(num_dimensions)
    return GMM_LNPDF(weights, means, covs)



