import numpy as np 
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import norm
import torch

from mixture_functions import sample_mixture, mixture_pdf
from utils import seed_everything, kl_mvn, bhatt_dist
from uncertainty_estimator import MixtureEntropyEstimator, EpistemicUncertaintyEstimator

# Fix every RNG seed up front so the randomly drawn means below are reproducible.
seed_everything(42)

# Problem size: a mixture of `mixture_components` Gaussians in `dimensions` dims.
dimensions, mixture_components = 10, 100

# Component means drawn i.i.d. from a normal with mean 0 and std exp(0) == 1.
means = torch.normal(0, np.exp(0), size=(mixture_components, dimensions))

# Unit scale for every component along every dimension.
# (A previous experiment used random scales: torch.normal(0.7, 0.4, ...) ** 2.)
stds = torch.ones(mixture_components, dimensions)

# Equal (uniform) mixing weight for every component.
mixture_weights = [1 / mixture_components for _ in range(mixture_components)]

# Visual sanity check, 1-D case only: overlay the density of every mixture component.
if dimensions == 1:
    x_axis = np.arange(-10, 10, 0.001)
    for mean_row, std_row in zip(means, stds):
        # Each row holds a single value when dimensions == 1; `.item()` hands scipy
        # a plain float instead of a one-element tensor. The scale comes from
        # `stds` (the original referenced an undefined `covs` -> NameError).
        # NOTE(review): `stds` is treated as a standard deviation per its name —
        # confirm this matches how sample_mixture interprets it.
        plt.plot(x_axis, norm.pdf(x_axis, loc=mean_row.item(), scale=std_row.item()))
    plt.show()

# Draw a large Monte-Carlo sample from the mixture.
# NOTE(review): `comps` presumably holds the component index of each sample —
# confirm against sample_mixture.
numb = 1_000_000
samp, comps = sample_mixture(means, stds, mixture_weights, numb)
if dimensions == 1:
    # Empirical density of the samples, with the KDE curve highlighted.
    ax = sns.histplot(samp, bins=60, kde=True, stat='density')
    ax.lines[0].set_color('crimson')
    # Without this the histogram figure is built but never displayed.
    plt.show()

# Estimate mixture entropy and epistemic uncertainty from the drawn samples.
# (A leftover `pdb.set_trace()` breakpoint that halted every run was removed here.)
mee = MixtureEntropyEstimator('elk')
# NOTE(review): the return value of this call is discarded; it looks like a
# debugging call — confirm whether it is still needed.
mee.lower_bound(samp, comps, means.unsqueeze(0), stds.unsqueeze(0), mixture_weights, '')
eue = EpistemicUncertaintyEstimator('kl_mean')
ent_estimate = mee.estimate_entropy(samp, comps, means, stds, mixture_weights)
eue_estimate = eue.estimate_epi_uncertainty(samp, comps, means, stds, mixture_weights, method='kl_mean')

print(f'Entropy: {ent_estimate}')
print(f'Epistemic Uncertainty: {eue_estimate}')
