import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import optim
import seaborn as sns

from utils import normalize, un_normalize, seed_everything
from estimate_uncertainty import estimate_uncertainty_ensemble_nflows
from envs_1d import hetero_samp, bimodal_samp
from nflows_ensemble_model import nflows_ensemble


# Seed the random generators (via utils.seed_everything) before any data is
# sampled, so the training set below is reproducible.
seed_everything(43)

# Noise-free reference curve 7*sin(x) + 3*|cos(x/2)| on a fixed 100-point grid.
# NOTE(review): neither x_hetero nor no_noise_y_hetero is used later in this
# script (hetero_samp is imported but not called) - possibly kept for switching
# to the heteroscedastic environment; confirm before removing.
x_hetero = np.linspace(-4.5, 4.5, 100)
no_noise_y_hetero = 7 * np.sin(x_hetero) + 3 * np.abs(np.cos(x_hetero / 2))

# Training set from the bimodal environment; train_data[0] is used as the inputs.
train_data = bimodal_samp(200)

# The same normalize / un_normalize pair is used for both inputs and outputs.
input_preproc = output_preproc = normalize
input_postproc = output_postproc = un_normalize

# 100-point evaluation grid spanning the observed range of the training inputs.
datarange = np.linspace(train_data[0].min(), train_data[0].max(), 100)


# Architecture / context sizes passed positionally to nflows_ensemble.
num_layers, hids = 1, 200
output_dim, context_dim = 1, 1

# Ensemble configuration (dropout_masks is forwarded as `fixed_masks`).
ensemble_size = 5
dropout_masks = True
base = True

# Flow-transform settings forwarded as `bins` and `domain`.
bins, domain = 5, 1.2

# Optimisation settings. NOTE(review): gamma is not passed to the constructor
# or to train_1d in this script - confirm whether it is meant to be used.
lr, gamma, epochs = 0.0005, 0.999, 6000
device = 'cpu'

nflows_base = nflows_ensemble(
    num_layers,
    hids,
    output_dim,
    context_dim,
    bins,
    domain,
    lr,
    device,
    input_preproc,
    output_preproc,
    base=base,
    fixed_masks=dropout_masks,
    ensemble_size=ensemble_size,
)


# Train the ensemble for `epochs` and show the loss values returned by train_1d.
# No hard-coded breakpoint here, so the script can run unattended (e.g. in CI or
# in a background job). For interactive inspection run it under `python -m pdb`
# or add a local `breakpoint()`.
train_loss = nflows_base.train_1d(epochs, train_data, output_postproc)
plt.plot(train_loss)
plt.title('Training loss')
plt.show()

# Sample the model at the training inputs (train_data[0]): normalise them with
# the model's stored input statistics, call sample(1, ...) as an (N, 1) context,
# and map the result back through output_postproc with the stored output stats.
# NOTE(review): fit_data_base is not used anywhere below. The sample() call may
# still consume torch RNG state before the uncertainty estimate, so deleting
# these lines could change the plotted curve - confirm before removing.
inps = torch.tensor(train_data[0], dtype = torch.float32)
inps = nflows_base.input_preproc(inps, nflows_base.stats_inputs)
fit_data_base = nflows_base.sample(1, inps.reshape(-1,1), ensemble_size=ensemble_size)
# presumably index 0 selects the sample tensor from sample()'s return value - TODO confirm
fit_data_base = fit_data_base[0].detach().cpu()
fit_data_base = output_postproc(fit_data_base.reshape(-1,1), nflows_base.stats_outputs)
# Passed to the estimator as `numb_samps`.
sample_size = 2000
# Evaluation grid: `datarange` as an (N, 1) float32 tensor, normalised with the
# model's stored input statistics. Note that `inps` is rebound here; the
# training-input tensor above is no longer referenced.
inps = torch.tensor(datarange, dtype=torch.float32).reshape(-1,1)
inps_normed = input_preproc(inps, nflows_base.stats_inputs)
# Request the 'mutual_info' uncertainty estimate for every grid point from the
# ensemble (ensemble_size members, sample_size samples).
uncertainty_estimates_nflows_base = estimate_uncertainty_ensemble_nflows(inps_normed, 'mutual_info',
        nflows_base, ensemble_size, numb_samps= sample_size)

# Two side-by-side panels sharing the x-axis: the left shows the distribution of
# the training inputs, the right shows the 'mutual_info_base' estimate over
# `datarange`.
fig, ax = plt.subplots(1, 2, figsize=(12, 4), sharex=True)
density_ax, mi_ax = ax

sns.histplot(
    data=train_data[0],
    ax=density_ax,
    bins=100,
    stat='probability',
    color='darkblue',
    kde=True,
)
density_ax.title.set_text('Training Density')

# Flatten the estimate so it lines up with the 1-D grid on the x-axis.
mi_curve = uncertainty_estimates_nflows_base['mutual_info_base'].reshape(-1)
mi_ax.plot(datarange, mi_curve)
mi_ax.title.set_text('Nflows Ensemble Base')

fig.tight_layout()
plt.show()
