from tqdm import tqdm as tqdm
import numpy as np

import torch
import torch.nn as nn
from GPyOpt.methods import BayesianOptimization
from bed import methods
from copy import deepcopy

import pandas as pd
import matplotlib.pyplot as plt









class BED:
    """Base class for Bayesian experimental design with a learned MI estimator.

    Subclasses implement ``train`` to search for the design that maximises the
    mutual-information (MI) lower bound produced by ``self.model``.

    Parameters
    ----------
    model:
        MI estimator. Must provide ``to(device)`` (returning the model itself),
        ``learn(X, Y)`` and ``MI(X, Y)``.
    simulator:
        Data simulator; stored for use by subclasses.
    prior:
        Prior; stored for use by subclasses when simulating data.
    device:
        Device the model is moved to before training (default 'cpu').
    """

    def __init__(
            self, model, simulator, prior, device='cpu'):
        self.model = model
        self.simulator = simulator
        self.prior = prior
        self.device = device
        # History of evaluated designs and their utilities (MI lower bounds),
        # appended to by subclasses.
        self.design_array = []
        self.utility_array = []

        # Best utility seen so far and the design that achieved it. Start at
        # -inf rather than 0: lower-bound estimates of MI can be negative
        # (e.g. with a poorly trained estimator), and a 0 threshold would then
        # never record any best design.
        self.best_lb = float('-inf')
        self.best_d = None

    def train(self):
        """Search for the optimal design. No-op here; overridden by subclasses."""
        pass

    def _train_MI_estimator(self, X, Y):
        """Fit ``self.model`` to (X, Y) and return its MI estimate.

        Returns
        -------
        Whatever ``self.model.MI(X, Y)`` returns (the MI lower bound).
        """
        # Move the estimator to the configured device and fit it; this relies
        # on ``to`` returning the model, as ``torch.nn.Module.to`` does.
        self.model.to(self.device).learn(X, Y)

        # Evaluate the MI lower bound of the trained estimator.
        mi = self.model.MI(X, Y)
        return mi







class GradientFreeBED(BED):
    """Bayesian experimental design optimised with GPyOpt Bayesian optimisation.

    Each objective evaluation simulates data at a design ``d``, trains the MI
    estimator ``self.model`` on (prior, simulated data) and returns the MI
    lower bound, which Bayesian optimisation then maximises over ``domain``.

    Parameters
    ----------
    model:
        MI estimator; must provide ``to``, ``learn``, ``MI``, ``state_dict``
        and ``load_state_dict``.
    simulator:
        Object exposing ``sample_data(d=..., p=...)``.
    prior:
        Prior samples (array-like with ``astype``); passed to the simulator
        and used directly as X for the MI estimate.
    domain:
        GPyOpt domain specification of the design variables.
    constraints:
        Optional GPyOpt constraints (default None).
    device:
        Torch device string (default 'cpu').
    """

    def __init__(
            self, model, simulator, prior, domain, constraints=None, device='cpu'):
        super(GradientFreeBED, self).__init__(model, simulator, prior, device)
        self.domain = domain
        self.constraints = constraints

    def _compute_optimal_design(self, obj):
        """Computes the optimal design after training a GP model."""
        self.d_opt = methods.get_GP_optimum(obj)

    def _preprocess_data(self, X, Y):
        """Convert X, Y to float tensors of equal width on ``self.device``.

        Both inputs are flattened to 2-D ``(n, -1)``; the narrower one is
        right-padded with standard-normal noise columns so both have the same
        number of columns.
        """
        # numpy -> torch
        device = self.device
        X, Y = torch.Tensor(X.astype(float)).float().to(device),  torch.Tensor(Y.astype(float)).float().to(device)

        # at least 2d
        X, Y = X.view(len(X), -1), Y.view(len(Y), -1)

        # pad x or y with random noise so that they have the same dimensionality
        n, dim_x, dim_y = len(X), X.size()[1], Y.size()[1]
        if dim_x > dim_y: Y = torch.cat([Y, torch.randn(n, dim_x - dim_y).to(device)], dim=1)
        if dim_x < dim_y: X = torch.cat([X, torch.randn(n, dim_y - dim_x).to(device)], dim=1)
        return X, Y

    def _objective(self, d):
        """Objective function to be maximised during Bayesian Optimisation.

        Simulates data at design ``d``, trains the MI estimator on it and
        returns the MI lower bound. Also tracks the best bound / design /
        model weights seen so far and records every evaluation.
        """

        # simulate data
        X, Y = self.prior, self.simulator.sample_data(d=d.flatten(), p=self.prior)
        X, Y = self._preprocess_data(X, Y)

        # evaluate MI
        lb = self._train_MI_estimator(X, Y)
        if lb > self.best_lb:
            # state_dict() returns references to the live parameter tensors;
            # without a deep copy, later training would overwrite the "best"
            # checkpoint in place.
            self.best_model_state_dict = deepcopy(self.model.state_dict())
            self.best_lb = lb
            self.best_d = d

        # record utility and design
        self.design_array.append(d)
        self.utility_array.append(lb)
        print('lb=', lb, 'd=', d, '\n')
        return lb

    def train(
            self, bo_model=None, bo_space=None, bo_acquisition=None,
            X_init=None, Y_init=None, BO_init_num=5, BO_max_num=20,
            verbosity=False):
        """
        Uses Bayesian optimisation to find the optimal design. The objective
        function is the mutual information lower bound at a particular design,
        obtained by training a MINE model.

        Parameters
        ----------
        bo_model, bo_space, bo_acquisition:
            Reserved for a custom BO setup. Passing all three raises
            NotImplementedError; passing only some raises ValueError. Leave
            all as None to use the default GPyOpt GP/EI configuration.
        X_init, Y_init:
            Optional previous evaluations (designs and utilities) passed to
            GPyOpt as ``X`` and ``Y``.
            (default is None)
        BO_init_num: int
            The number of initial BO evaluations used to initialise the GP.
            (default is 5)
        BO_max_num: int
            The maximum number of BO evaluations after the initialisation.
            (default is 20)
        verbosity: boolean
            Turn off/on output to the command line.
            (default is False)

        Side effects: sets ``self.bo_obj`` and ``self.d_opt`` and evaluates
        the objective once more at ``self.d_opt``.
        """

        if verbosity:
            print('Initialize Probabilistic Model')

        if bo_model and bo_space and bo_acquisition:
            raise NotImplementedError('Custom BO model not yet implemented.')
        elif all(v is None for v in [bo_model, bo_space, bo_acquisition]):
            pass
        else:
            raise ValueError(
                'Either all BO arguments or none need to be specified.')

        # Define GPyOpt Bayesian Optimization object
        self.bo_obj = BayesianOptimization(
            f=self._objective, domain=self.domain,
            constraints=self.constraints, model_type='GP',
            acquisition_type='EI', normalize_Y=False,
            initial_design_numdata=BO_init_num, acquisition_jitter=0.01,
            maximize=True, X=X_init, Y=Y_init)

        if verbosity:
            print('Start Bayesian Optimisation')

        # run the bayesian optimisation
        self.bo_obj.run_optimization(
            max_iter=BO_max_num, verbosity=verbosity, eps=1e-5)

        # find optimal design from posterior GP model; stored as d_opt
        self._compute_optimal_design(self.bo_obj)

        # eval the utility of the optimal design
        self._objective(self.d_opt)

    def train_with_optimal_design(self):
        """Re-train and evaluate at the historically best and the GP-optimal designs.

        Restores the best recorded model weights, then runs the objective at
        ``self.best_d`` and at ``self.d_opt`` and prints both bounds.

        Raises
        ------
        AttributeError
            If no best model state has been recorded yet (``train`` not run,
            or no evaluation improved on ``best_lb``).
        """
        if not hasattr(self, 'best_model_state_dict'):
            raise AttributeError(
                'No best model state recorded; run train() first.')

        # A. historically best
        self.model.load_state_dict(self.best_model_state_dict)
        lb1 = self._objective(self.best_d)
        print('lb historical d-opt', lb1)

        # B. GP's predict best (might be inaccurate if n_eval too small)
        lb2 = self._objective(self.d_opt)
        print('lb gp d-opt', lb2)

    # ---------------------------- plotting facilities ---------------------------- #

    def plot_gp_mean_function(self, filename=None, label_x=r'$d_1$', label_y=r'$d_2$', negative=True):
        """Plot the posterior mean of the BO GP surrogate over a 2-D design domain.

        Only 2-D design spaces are drawn; for other dimensions nothing is
        plotted. Returns the ``matplotlib.pyplot`` module.

        Parameters
        ----------
        filename:
            If given, the figure is saved to this path.
        label_x, label_y:
            Axis labels.
        negative:
            Negate the GP mean before plotting (presumably because GPyOpt
            models -f when ``maximize=True``; confirm against GPyOpt).
        """
        bo_obj = self.bo_obj
        model_to_plot = bo_obj.model

        def _plot_acquisition(bounds, input_dim, model, Xdata, Ydata, acquisition_function, suggested_sample,
                         filename=None, label_x=None, label_y=None, plot_data=False, color_by_step=True):
            '''
            Plots of the model and the acquisition function in 1D and 2D examples.
            '''

            if input_dim == 2:
                if not label_x:
                    label_x = 'X1'
                if not label_y:
                    label_y = 'X2'
                # preparation
                n = Xdata.shape[0]
                colors = np.linspace(0, 1, n)
                cmap = plt.cm.coolwarm                          # see here: https://matplotlib.org/stable/users/explain/colors/colormaps.html
                norm = plt.Normalize(vmin=0, vmax=1)
                points_var_color = lambda X: plt.scatter(
                    X[:,0], X[:,1], c=colors, label=u'Observations', cmap=cmap, norm=norm)
                points_one_color = lambda X: plt.plot(
                    X[:,0], X[:,1], 'r.', markersize=10, label=u'Observations')
                # evaluate the GP mean on a 200 x 200 grid over the bounds
                X1 = np.linspace(bounds[0][0], bounds[0][1], 200)
                X2 = np.linspace(bounds[1][0], bounds[1][1], 200)
                x1, x2 = np.meshgrid(X1, X2)
                X = np.hstack((x1.reshape(200*200,1),x2.reshape(200*200,1)))

                m, v = model.predict(X)
                if negative: m = -m
                plt.figure(figsize=(5,5))
                plt.contourf(X1, X2, m.reshape(200,200),100, cmap=cmap)
                plt.colorbar()
                if plot_data:
                    if color_by_step:
                        points_var_color(Xdata)
                    else:
                        points_one_color(Xdata)
                plt.xlabel(label_x)
                plt.ylabel(label_y)
                plt.tight_layout()
                #plt.title('Posterior mean')
                plt.axis((bounds[0][0],bounds[0][1],bounds[1][0],bounds[1][1]))
                if filename is not None:
                    # was a bare `savefig(...)`, an undefined name (NameError)
                    plt.savefig(filename)
            return plt

        return _plot_acquisition(bo_obj.acquisition.space.get_bounds(),
                                model_to_plot.model.X.shape[1],
                                model_to_plot.model,
                                model_to_plot.model.X,
                                model_to_plot.model.Y,
                                bo_obj.acquisition.acquisition_function,
                                bo_obj.suggest_next_locations(),
                                filename,
                                label_x,
                                label_y)

    def plot_designs(self, plot, points, names):
        """Overlay 2-D design points on ``plot`` with distinct hollow markers.

        ``points[i]`` is drawn at ``(point[0], point[1])`` labelled
        ``names[i]``; markers cycle if there are more than four points.
        Returns ``plot`` after adding a legend.
        """
        markers = ['^', 'd', 'o', 's']
        for i, point in enumerate(points):
            plot.scatter(point[0], point[1], s=50, edgecolors='k', facecolors='none',
                         marker=markers[i % len(markers)], label=names[i])
        plot.legend()
        return plot

    # ---------------------------- save & load facilities ---------------------------- #

    def save_BO_model(self, fn):
        """Write the BO evaluations (iteration, Y, design variables) to ``fn``."""
        self.bo_obj.save_evaluations(fn)

    def load_BO_model(self, fn):
        """Rebuild ``self.bo_obj`` from an evaluations file written by ``save_BO_model``.

        Reads the tab-separated file, uses its ``Y`` column and ``var_*``
        columns as the initial data, and runs zero BO iterations so GPyOpt
        fits the GP surrogate without new objective evaluations.
        """
        evals = pd.read_csv(fn, index_col=0, delimiter="\t")
        Y = np.array([[x] for x in evals["Y"]])
        # GPyOpt names design columns 'var_1', 'var_2', ...; anchor the regex
        # (the old "var*" meant "va" followed by any number of 'r').
        X = np.array(evals.filter(regex=r"^var_"))

        print('X', X.shape)
        print('Y', Y.shape)

        self.bo_obj = BayesianOptimization(
            f=self._objective, domain=self.domain,
            constraints=self.constraints, model_type='GP',
            acquisition_type='EI', normalize_Y=False,
            initial_design_numdata=5, acquisition_jitter=0.01,
            maximize=True, X=X, Y=Y)
        self.bo_obj.run_optimization(max_iter=0, verbosity=True, eps=1e-5)     # <-- dummy run to create the GP inside

        

        
        
        
        
        
        


    
    
    

    
#     def explore_loss_landscape(self, domains, n_per_dim=8):
#         dim = len(domains)
#         x_array = []
#         for domain in domains:
#             l, u = domain[0], domain[1]
#             x_array.append(np.linspace(l, u, n_per_dim))
            
#         designs = []
#         utilities = []
#         for x1 in x_array[0]:
#             for x2 in x_array[1]:
#                 d = np.array([x1, x2])
#                 if x1>=x2: continue
#                 print('d=', d)
#                 mi = self._objective(d)
#                 designs.append(d)
#                 utilities.append(mi)
                
#         X = np.array(designs)
#         Y = np.array(utilities).reshape(-1, 1)
    
#         print('X', X.shape)
#         print('Y', Y.shape)
        
#         self.bo_obj = BayesianOptimization(
#             f=self._objective, domain=self.domain,
#             constraints=self.constraints, model_type='GP',
#             acquisition_type='EI', normalize_Y=False,
#             initial_design_numdata=5, acquisition_jitter=0.01,
#             maximize=True, X=X, Y=Y)
#         self.bo_obj.run_optimization(max_iter=0, verbosity=True, eps=1e-5)     # <-- dummy run to create the GP inside
    
    
    
    
    

