import os
import itertools
from time import time
from datetime import timedelta

import torch
import numpy as np
import numpy.matlib
import matplotlib.pyplot as plt

from plot_functions import plot_uncertainty_1d  
from utils import kl_mvn, Wasserstein_GP, entropy_utils
from uncertainty_estimator import EpistemicUncertaintyEstimator

def find_best_points(samp, numb_points, model, input_preproc, 
        ensemble_size, device, nflows=False, 
        acquisition_criteria = 'mutual_info', numb_points_2_add = 10):
    """Select the candidate points with the highest uncertainty score.

    Args:
        samp: Sequence of parallel, row-indexable arrays. ``samp[0]`` and
            ``samp[1]`` are stacked column-wise (``np.hstack``) to form the
            model input; every array in ``samp`` is indexed per selected row
            to build the returned tuples.
        numb_points: Forwarded to ``estimate_uncertainty`` as ``numb_samps``.
            NOTE(review): that is the per-member Monte Carlo sample count,
            not the number of candidates — confirm this is intended.
        model: Ensemble model; must expose ``stats_inputs`` for
            ``input_preproc`` and the sampling API ``estimate_uncertainty``
            uses.
        input_preproc: Callable ``(tensor, stats) -> tensor`` applied to the
            model input.
        ensemble_size: Number of ensemble members.
        device: Torch device the model input is moved to.
        nflows: Forwarded to ``estimate_uncertainty``.
        acquisition_criteria: Name of the uncertainty score used for ranking.
        numb_points_2_add: Number of points to return. It is clamped to the
            number of scored candidates; values <= 0 select nothing.

    Returns:
        ``(points_2_add, time_taken)``. ``points_2_add`` is a list of tuples,
        one per selected row, holding ``s[i]`` for each ``s`` in ``samp``, in
        no particular order. ``time_taken`` is the seconds reported by
        ``estimate_uncertainty``.
    """
    state, action = samp[0], samp[1]
    model_inp = torch.tensor(np.hstack([state, action])).type(torch.float32).to(device)
    model_inp = input_preproc(model_inp, model.stats_inputs)
    eue = EpistemicUncertaintyEstimator(acquisition_criteria)
    uncertainty, time_taken = estimate_uncertainty(model_inp, acquisition_criteria, model, 
        ensemble_size, numb_samps = numb_points, epi_estimator = eue, 
        estimator_types = [], nflows=nflows)
    # The total-entropy score is requested as 'tot_unc' but stored as
    # 'total_unc'.
    key = 'total_unc' if acquisition_criteria == 'tot_unc' else acquisition_criteria
    scores = np.asarray(uncertainty[key])
    # argpartition raises for k > len(scores), and for k == 0 the slice
    # [-0:] would select every candidate. Clamp k and handle 0 explicitly.
    k = max(0, min(numb_points_2_add, len(scores)))
    ind = np.argpartition(scores, -k)[-k:] if k else np.empty(0, dtype=np.intp)
    points_2_add = [tuple(s[i] for s in samp) for i in ind]
    return points_2_add, time_taken

def find_best_points_1d(samp, numb_points, model, 
        input_preproc, ensemble_size, device, nflows = False,
        acquisition_criteria = 'mutual_info', numb_points_2_add = 10):
    """Select the 1-D candidate points with the highest uncertainty score.

    Args:
        samp: Pair ``(x, y)`` of numpy arrays. ``x`` is reshaped to a column
            and used as model input; both arrays are indexed by the selected
            rows.
        numb_points: Forwarded to ``estimate_uncertainty`` as ``numb_samps``.
            NOTE(review): that is the per-member Monte Carlo sample count —
            confirm this is intended.
        model: Ensemble model; must expose ``stats_inputs``.
        input_preproc: Callable ``(tensor, stats) -> tensor``.
        ensemble_size: Number of ensemble members.
        device: Torch device the input is moved to.
        nflows: Forwarded to ``estimate_uncertainty``.
        acquisition_criteria: Name of the uncertainty score used for ranking.
        numb_points_2_add: Number of points to return. It is clamped to the
            number of scored candidates; values <= 0 select nothing.

    Returns:
        ``((x_sel, y_sel), time_taken)``, with the selected rows of
        ``samp[0]`` and ``samp[1]`` in no particular order.
    """
    X = torch.tensor(samp[0].reshape(-1, 1), dtype = torch.float32).to(device)
    X = input_preproc(X, model.stats_inputs)
    eue = EpistemicUncertaintyEstimator(acquisition_criteria)
    uncertainty, time_taken = estimate_uncertainty(X, acquisition_criteria, model, 
        ensemble_size, numb_samps = numb_points, epi_estimator = eue, 
        estimator_types = [], nflows=nflows)
    # The total-entropy score is requested as 'tot_unc' but stored as
    # 'total_unc'.
    key = 'total_unc' if acquisition_criteria == 'tot_unc' else acquisition_criteria
    scores = np.asarray(uncertainty[key])
    # argpartition raises for k > len(scores), and for k == 0 the slice
    # [-0:] would select every candidate. Clamp k and handle 0 explicitly.
    k = max(0, min(numb_points_2_add, len(scores)))
    ind = np.argpartition(scores, -k)[-k:] if k else np.empty(0, dtype=np.intp)
    points_2_add = (samp[0][ind], samp[1][ind])
    return points_2_add, time_taken


def estimate_uncertainty(model_inp, uncertainty_type, model, 
        ensemble_size, numb_samps=10000, nflows=False, epi_estimator=None,
        estimator_types=None, chunk_size=1000):
    """Estimate per-input uncertainty of an ensemble of conditional models.

    Two families of estimators are available:

    * Sample-based (``'mutual_info'``, ``'alea_unc'``, ``'tot_unc'``). Monte
      Carlo samples are drawn from every ensemble member. The entropy of the
      equally weighted mixture ('total_unc') is compared with the mean
      diagonal-Gaussian entropy of the members' base parameters
      ('alea_unc'), and their difference is stored as 'mutual_info'.
    * Any other name is handled by ``epi_estimator`` from the members' mean
      and std parameters.

    ``'all'`` runs both families.

    Args:
        model_inp: Preprocessed conditioning inputs, one row per query point.
        uncertainty_type: ``'all'``, one of the sample-based names, or an
            estimator name understood by ``epi_estimator``.
        model: Ensemble model. The sample-based path calls
            ``model.sample_and_log_prob``; the other path calls
            ``model.model.sample``. Members are selected via ``mask_index``.
        ensemble_size: Number of ensemble members.
        numb_samps: Monte Carlo samples drawn per member (sample-based path).
        nflows: If True, total entropy uses the log-probabilities returned by
            the model. Otherwise it evaluates the Gaussian mixture built from
            the members' base means and stds.
        epi_estimator: Object exposing ``estimator`` and
            ``estimate_epi_uncertainty``. Required unless only sample-based
            estimators are requested.
        estimator_types: Extra ``method=`` values evaluated with
            ``epi_estimator``.
        chunk_size: Maximum number of samples requested per model call.

    Returns:
        ``(uncertainty, seconds_taken)``. ``uncertainty`` maps estimator
        names ('total_unc', with alias 'tot_unc', plus 'mutual_info',
        'alea_unc', and/or the ``epi_estimator`` names) to per-input values.
        ``seconds_taken`` is the wall-clock time of all branches that ran.

    Raises:
        ValueError: if ``numb_samps`` or ``chunk_size`` is < 1 on the
            sample-based path, or if ``epi_estimator`` is missing when needed.
    """
    if estimator_types is None:
        estimator_types = []
    sample_based_types = ('mutual_info', 'alea_unc', 'tot_unc')
    uncertainty = {}
    seconds_taken = 0.0
    if uncertainty_type == 'all' or uncertainty_type in sample_based_types:
        results, secs = _sample_based_uncertainty(model_inp, model,
            ensemble_size, numb_samps, nflows, chunk_size)
        uncertainty.update(results)
        seconds_taken += secs
    if uncertainty_type == 'all' or uncertainty_type not in sample_based_types:
        results, secs = _pairwise_uncertainty(model_inp, model,
            ensemble_size, epi_estimator, estimator_types)
        uncertainty.update(results)
        seconds_taken += secs
    return uncertainty, seconds_taken


def _sample_based_uncertainty(model_inp, model, ensemble_size, numb_samps,
        nflows, chunk_size):
    """Monte Carlo total/aleatoric/epistemic entropies; returns (dict, seconds)."""
    if numb_samps < 1:
        raise ValueError(f'numb_samps must be >= 1, got {numb_samps}')
    if chunk_size < 1:
        raise ValueError(f'chunk_size must be >= 1, got {chunk_size}')
    t0 = time()
    means, stds, component_ent = [], [], []
    outputs, log_probs_base = [], []
    for i in range(ensemble_size):
        kwargs = {'rand_mask': False, 'mask_index': i}
        member_out, member_lp = [], []
        remaining = numb_samps
        # Draw exactly numb_samps samples, at most chunk_size per call.
        while remaining > 0:
            n_chunk = min(chunk_size, remaining)
            out_ch, lp_ch, base_mean, base_std = model.sample_and_log_prob(
                n_chunk, context = model_inp, kwargs=kwargs,
                ensemble = True, ensemble_size = ensemble_size)
            member_out.append(out_ch.detach().cpu().numpy())
            member_lp.append(lp_ch.detach().cpu().numpy())
            remaining -= n_chunk
        # NOTE(review): assumes base_mean/base_std repeat each input's
        # parameters once per drawn sample (grouped by input), so striding by
        # the last call's sample count keeps one row per input. Confirm.
        mu = base_mean[::n_chunk].detach().cpu().numpy()
        sig = base_std[::n_chunk].detach().cpu().numpy()
        outputs.append(np.hstack(member_out))
        log_probs_base.append(np.hstack(member_lp))
        # Differential entropy of a Gaussian with diagonal covariance sig**2.
        ent = mu.shape[1]/2*np.log(2*np.pi*np.e)+1/2*np.log((sig**2).prod(1))
        component_ent.append(ent)
        means.append(mu)
        stds.append(sig)
    # Axis 0 indexes inputs and axis 1 indexes samples (all members
    # concatenated), matching the per-input .mean(1) reduction below.
    output = np.hstack(outputs)
    log_probs_base = np.hstack(log_probs_base)
    if nflows:
        total_ent = -np.ma.masked_invalid(log_probs_base).mean(1).data
    else:
        # Swap the input and sample axes so that each input's samples are
        # scored under that input's component parameters. A reshape would
        # interleave samples from different inputs.
        samples = torch.tensor(np.ascontiguousarray(output.transpose(1, 0, 2)))
        comp_log_probs = torch.stack([
            torch.distributions.normal.Normal(torch.tensor(m), torch.tensor(s))
                .log_prob(samples).sum(2)
            for m, s in zip(means, stds)])
        # log((1/E) * sum_i p_i(x)), computed in log space to avoid the
        # exp() underflow of the direct form.
        mixture_log_prob = (torch.logsumexp(comp_log_probs, dim=0)
                            - float(np.log(ensemble_size))).numpy()
        total_ent = -np.ma.masked_invalid(mixture_log_prob).mean(0).data
    alea_ent = np.stack(component_ent).mean(0)
    results = {
        'total_unc': total_ent,
        'tot_unc': total_ent,  # alias matching the requested type name
        'mutual_info': total_ent - alea_ent,
        'alea_unc': alea_ent,
    }
    t1 = time()
    print(f'Mutual Info time: {str(timedelta(seconds=(t1-t0)))}')
    return results, t1 - t0


def _pairwise_uncertainty(model_inp, model, ensemble_size, epi_estimator,
        estimator_types):
    """Estimators computed from members' mean/std parameters; returns (dict, seconds)."""
    if epi_estimator is None:
        raise ValueError('epi_estimator is required for non-sample-based '
                         'uncertainty estimators')
    t0 = time()
    means, stds = [], []
    for i in range(ensemble_size):
        kwargs = {'rand_mask': False, 'mask_index': i}
        _, _, mu, sig = model.model.sample(1, context = model_inp, kwargs=kwargs)
        means.append(mu.detach().cpu().numpy())
        stds.append(sig.detach().cpu().numpy())
    # Members are stacked along axis 1.
    mus = torch.tensor(np.stack(means, axis=1))
    sigs = torch.tensor(np.stack(stds, axis=1))
    weights = [1/ensemble_size] * ensemble_size
    results = {epi_estimator.estimator:
        epi_estimator.estimate_epi_uncertainty('', '', mus, sigs, weights)}
    for et in estimator_types:
        results[et] = epi_estimator.estimate_epi_uncertainty('', '', mus, sigs, 
                weights, method = et)
    t1 = time()
    print(f'Pairwise time: {str(timedelta(seconds=(t1-t0)))}')
    return results, t1 - t0

def check_uncertainty_1d(model, store_dir, input_preproc, 
        ensemble_size, device, suffix, numb_samples, 
        env='bimodal', nflows=False):
    """Evaluate uncertainty on a fixed 1-D input grid and plot it.

    Runs ``estimate_uncertainty`` with ``'all'``, so both the sample-based
    mutual information and the ``'kl_exp'`` / ``'bhatt_exp'`` estimators are
    computed on 300 evenly spaced inputs. The results are handed to
    ``plot_uncertainty_1d``, which writes under ``store_dir`` with the tag
    ``'base_' + suffix``.

    Args:
        model: Ensemble model; must expose ``stats_inputs``.
        store_dir: Output location passed to ``plot_uncertainty_1d``.
        input_preproc: Callable ``(tensor, stats) -> tensor``.
        ensemble_size: Number of ensemble members.
        device: Torch device the grid is moved to.
        suffix: Name suffix for the plot.
        numb_samples: Monte Carlo samples per member.
        env: ``'bimodal'`` (grid on [0, 4]) or ``'hetero'`` (grid on
            [-5.5, 5.5]).
        nflows: Forwarded to ``estimate_uncertainty``.

    Raises:
        ValueError: if ``env`` is not a known environment.
    """
    # An earlier, narrower grid (-4.5, 4.5) was used for 'hetero'.
    grid_ranges = {'bimodal': (0, 4), 'hetero': (-5.5, 5.5)}
    if env not in grid_ranges:
        raise ValueError(f"Unknown env {env!r}; expected one of {sorted(grid_ranges)}")
    low, high = grid_ranges[env]
    x_cord = np.linspace(low, high, 300)
    model_inp = torch.tensor(x_cord).type(torch.float32).to(device)
    model_inp = model_inp.reshape(-1,1)
    model_inp = input_preproc(model_inp, model.stats_inputs)
    eue = EpistemicUncertaintyEstimator('kl_exp')
    estimator_types = ['bhatt_exp']
    uncertainty, time_taken = estimate_uncertainty(model_inp, 'all', model,
        ensemble_size, numb_samps = numb_samples, nflows = nflows, 
        epi_estimator = eue, estimator_types=estimator_types)
    plot_uncertainty_1d(uncertainty['mutual_info'], uncertainty['kl_exp'], uncertainty['bhatt_exp'],
        x_cord, store_dir, 'base_'+suffix)

