import numpy as np 
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import norm
import torch

from mixture_functions import sample_mixture, mixture_pdf
from utils import seed_everything, kl_mvn, bhatt_dist
from uncertainty_estimator import MixtureEntropyEstimator


# --- Experiment setup -------------------------------------------------------
# Seed first so that every random draw below is reproducible
# (seed_everything comes from utils; presumably seeds torch/numpy -- confirm).
seed_everything(42)
dimensions = 10
mixture_components = 1000
numb = 1_000_000  # sample count handed to sample_mixture below

# --- Mixture parameters -----------------------------------------------------
# One mean vector per component, drawn with mean 0 and std exp(0) == 1.
component_shape = (mixture_components, dimensions)
means = torch.normal(0, np.exp(0), size=component_shape)
# Every component uses unit scale in every dimension.
# Disabled alternative with per-component scales:
#stds = torch.normal(0.7, 0.4, size=(mixture_components, dimensions))**2
stds = torch.ones_like(means)
# Uniform weights: each component gets 1 / mixture_components.
mixture_weights = [1 / mixture_components for _ in range(mixture_components)]

# --- Sampling ---------------------------------------------------------------
# NOTE(review): sample_mixture presumably returns the drawn points and the
# component assignment of each point -- confirm in mixture_functions.
samp, comps = sample_mixture(means, stds, mixture_weights, numb)

# --- Estimation -------------------------------------------------------------
# NOTE(review): the constructor argument is an opaque placeholder string here;
# check uncertainty_estimator for what it selects.
mee = MixtureEntropyEstimator('wtf')
# Prepend a size-1 axis: (components, dims) -> (1, components, dims).
# Presumably the estimator expects a leading batch-like axis -- confirm.
means = torch.unsqueeze(means, 0)
stds = torch.unsqueeze(stds, 0)
# Both estimator variants receive exactly the same positional arguments.
# NOTE(review): the meaning of the trailing 0 is not visible from this file.
estimator_args = (samp, comps, means, stds, mixture_weights, 0)
bhatt_exp = mee.bhatt_exp(*estimator_args)
bhatt_exp_faster = mee.bhatt_exp_faster(*estimator_args)

# Print both results side by side so the two variants can be compared by eye.
print(f'bhatt exp: {bhatt_exp}')
print(f'bhatt exp_faster: {bhatt_exp_faster}')
