import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import optim
import seaborn as sns

from utils import normalize, un_normalize, seed_everything
from estimate_uncertainty import estimate_uncertainty_ensemble_nflows
from envs_1d import hetero_samp, bimodal_samp
from nflows_ensemble_model import nflows_ensemble
from uncertainty_estimator import MixtureEntropyEstimator, EpistemicUncertaintyEstimator

# Call the project's seeding helper with a fixed seed before any data is
# drawn or any model is built, so the statements below run in a fixed order
# after seeding.
seed_everything(43)

# Evenly spaced grid of 100 points on [-4.5, 4.5].
x_hetero = np.linspace(-4.5, 4.5, num=100)

# Training set produced by bimodal_samp(200).
# Alternative setup, currently disabled:
#   train_data = hetero_samp(200)
#   no_noise_y_hetero = 7*np.sin(x_hetero)+3*np.abs(np.cos(x_hetero/2))
train_data = bimodal_samp(200)

# The same normalize / un_normalize pair is used for both inputs and outputs.
input_preproc = output_preproc = normalize
input_postproc = output_postproc = un_normalize

# 100-point grid spanning the min..max of the first element of train_data
# (that element is what the script later feeds to the model as inputs).
datarange = np.linspace(train_data[0].min(), train_data[0].max(), num=100)

# Values passed positionally to the nflows_ensemble constructor below.
num_layers = 1
hids = 200
output_dim = 1
context_dim = 1
bins = 5
domain = 1.2
lr = 5e-4
device = 'cpu'

# Values passed by keyword to the nflows_ensemble constructor below.
base = True
dropout_masks = True  # forwarded as fixed_masks=
ensemble_size = 5     # also reused when sampling / estimating uncertainty

# Number of epochs handed to train_1d.
# Alternative setting, currently disabled: epochs = 6000
epochs = 60

# Build the nflows ensemble model from the settings defined earlier in this
# script. Positional order is significant and is kept exactly as before.
nflows_base = nflows_ensemble(
    num_layers,
    hids,
    output_dim,
    context_dim,
    bins,
    domain,
    lr,
    device,
    input_preproc,
    output_preproc,
    base=base,
    fixed_masks=dropout_masks,
    ensemble_size=ensemble_size,
)

# Train on train_data for `epochs` epochs; the return value is kept as
# train_loss (presumably the loss history -- TODO confirm against train_1d).
train_loss = nflows_base.train_1d(epochs, train_data, output_postproc)


# numb_samps value handed to the uncertainty estimator at the end of this
# section.
sample_size = 2000

# Epistemic-uncertainty estimator in 'wasserstein_exp' mode.
# Alternative, currently disabled:
#   eue = EpistemicUncertaintyEstimator('bhatt_exp')
eue = EpistemicUncertaintyEstimator('wasserstein_exp')

# --- Model samples at the training inputs --------------------------------
# Preprocess the first element of train_data with the model's own
# input_preproc and input statistics.
inps = nflows_base.input_preproc(
    torch.tensor(train_data[0], dtype=torch.float32),
    nflows_base.stats_inputs,
)
# First argument of sample() is 1 (presumably one draw per input -- TODO
# confirm); inputs are passed as a single-column tensor.
fit_data_base = nflows_base.sample(
    1,
    inps.reshape(-1, 1),
    ensemble_size=ensemble_size,
)
# Keep element 0 of the result, detach it from autograd, move it to CPU and
# map it back through output_postproc with the model's output statistics.
fit_data_base = output_postproc(
    fit_data_base[0].detach().cpu().reshape(-1, 1),
    nflows_base.stats_outputs,
)

# --- Uncertainty over the evaluation grid --------------------------------
# datarange as a single-column float32 tensor, then preprocessed with the
# model's input statistics.
inps = torch.tensor(datarange, dtype=torch.float32).reshape(-1, 1)
inps_normed = input_preproc(inps, nflows_base.stats_inputs)
uncertainty_estimates_nflows_base = estimate_uncertainty_ensemble_nflows(
    inps_normed,
    'mutual_info',
    nflows_base,
    eue,
    ensemble_size,
    numb_samps=sample_size,
)
