import sys
import os

# Make a local SMAC3 checkout importable for the `smac` imports below.
# NOTE(review): the path is built from os.getcwd() (the process's working
# directory), not from this file's location, so the `smac` imports only
# resolve to ./SMAC3 when the script is launched from the directory that
# contains it.
current_dir = os.getcwd()
smac_dir = current_dir + '/SMAC3'

sys.path.append(smac_dir)
import pandas as pd
import json
import numpy as np
from os.path import join, dirname
from scipy.stats import gaussian_kde,beta, norm, multivariate_normal
from smac.optimizer.acquisition import EI, PI, LogEI, LCB
from smac.epm.base_epm import AbstractEPM
from smac.configspace import ConfigurationSpace, Configuration

class Prior:
    """Product of independent per-dimension priors over a hyperparameter space.

    Each dimension ``i`` is described by a density ``pdfs[i]`` (evaluated on
    the original, unnormalized scale), a sampler ``sampling_funcs[i]`` (called
    as ``func(size=n)``), bounds ``ranges[i]`` and, for categorical dimensions,
    a list of choices ``values[i]``. The joint density is the product of the
    per-dimension densities, i.e. independence between dimensions is assumed.
    """

    def __init__(self,
                 cs,
                 names,
                 types,
                 pdfs,
                 sampling_funcs,
                 ranges,
                 modes,
                 values,
                 use_log_evaluation=True  # not implemented yet
                 ):
        """
        Parameters
        ----------
        cs : ConfigurationSpace
            Space used when converting arrays into ``Configuration`` objects.
        names : sequence of str
            Hyperparameter names, one per dimension.
        types : sequence of str
            Per-dimension type. ``'integer'`` and ``'categorical'`` are handled
            specially; any other value is scaled linearly within ``ranges``.
        pdfs : sequence of callables
            Per-dimension densities, evaluated on the unnormalized scale.
        sampling_funcs : sequence of callables
            Per-dimension samplers, called as ``func(size=n)``.
        ranges : sequence of (lower, upper)
            Per-dimension bounds on the unnormalized scale.
        modes : sequence of float
            Unnormalized location whose density is used as the normalizing
            maximum.
        values : sequence of (list or None)
            For categorical dimensions, the choices indexed by the rounded
            sample. ``None`` for numeric dimensions.
        use_log_evaluation : bool
            Currently unused.
        """
        # Added to normalized densities (see __call__) so the prior never
        # evaluates to exactly 0; also reported by get_min().
        self.prior_floor = 1e-12

        # Per-dimension descriptions, consumed by sample(), _as_config() and __call__.
        self.cs = cs
        self.names = names
        self.types = types
        self.values = values
        self.pdfs = pdfs
        self.sampling_funcs = sampling_funcs
        self.ranges = ranges
        # TODO fix this, it scales wrong for integers
        self.mode = np.array(modes)
        self.dims = len(self.mode)
        # Unnormalized density at the mode, shape (1, 1). __call__ divides by it
        # when normalize=True, so this must be computed with normalize=False.
        self.max = self(self.mode.reshape(1, -1), normalize=False, normalized_input=False)

    def get_pdfs(self):
        """Return the list of per-dimension density functions."""
        return self.pdfs

    def sample(self, size, normalize=True, as_config=True, seed=42):
        """Draw up to ``size`` samples from the prior by rejection.

        ``size * 10000`` candidates are drawn for every dimension. Candidates
        outside ``ranges`` in any non-categorical dimension are discarded, and
        the first ``size`` survivors are returned. If too few candidates fall
        in bounds, fewer than ``size`` rows are returned.

        Parameters
        ----------
        size : int
            Number of samples requested.
        normalize : bool
            If True, non-categorical dimensions are mapped linearly to [0, 1].
        as_config : bool
            If True, return a list of ``Configuration``; otherwise an array of
            shape (n, dims).
        seed : int
            Currently unused; randomness comes entirely from ``sampling_funcs``.
        """
        oversampling_factor = 10000
        n_candidates = size * oversampling_factor
        samples = np.zeros((n_candidates, self.dims))
        for dim, func in enumerate(self.sampling_funcs):
            samples[:, dim] = func(size=n_candidates)

        in_bounds = np.ones(n_candidates, dtype=bool)
        # normalize samples to return values in 0, 1
        norm_samples = np.zeros(samples.shape)
        for dim, (range_, type_) in enumerate(zip(self.ranges, self.types)):
            if type_ != 'categorical':
                lower, upper = range_
                in_bounds &= (samples[:, dim] >= lower) & (samples[:, dim] <= upper)
                # NOTE(review): integers are scaled linearly here, whereas
                # __call__ inverts a half-cell-offset scaling for integers.
                norm_samples[:, dim] = (samples[:, dim] - lower) / (upper - lower)
            else:
                # Categorical dimensions are passed through unchanged.
                norm_samples[:, dim] = samples[:, dim]

        source = norm_samples if normalize else samples
        return_samples = source[in_bounds][:size]

        if as_config:
            return self._as_config(return_samples)
        return return_samples

    def _as_config(self, array):
        """Convert each row of ``array`` into a ``Configuration``.

        Numeric dimensions (``values[i] is None``) are passed through.
        Categorical dimensions are rounded to an index into ``values[i]``.
        """
        all_configs_list = []
        for row in array:
            config_dict = {}
            for i, entry in enumerate(row):
                if self.values[i] is None:
                    config_dict[self.names[i]] = entry
                else:
                    config_dict[self.names[i]] = self.values[i][np.round(entry).astype(int)]
            all_configs_list.append(Configuration(self.cs, config_dict))
        return all_configs_list

    def __call__(self, X, normalize=True, normalized_input=True):
        """Evaluate the prior density at each row of ``X``.

        Parameters
        ----------
        X : np.ndarray, shape (n, dims)
            Points to evaluate.
        normalize : bool
            If True, divide by the density at the mode and add ``prior_floor``.
        normalized_input : bool
            If True, ``X`` is in [0, 1] and is first mapped back onto
            ``ranges``. Integer dimensions also undo an offset of
            ``1 / (2 * n_values)``.

        Returns
        -------
        np.ndarray, shape (n, 1)
        """
        if normalized_input:
            X_scaled = np.zeros(X.shape)
            for i, type_ in enumerate(self.types):
                lower, upper = self.ranges[i][0], self.ranges[i][1]
                if type_ == 'integer':
                    tot_numbers = upper - lower + 1
                    # since integers are input from SMAC with distance to the boundary
                    X_scaled[:, i] = (X[:, i] - 1 / (2 * tot_numbers)) * tot_numbers + lower
                else:
                    X_scaled[:, i] = X[:, i] * (upper - lower) + lower
        else:
            X_scaled = X

        # Independence between dimensions: multiply the per-dimension densities.
        probabilities = np.ones(len(X))
        for i in range(self.dims):
            probabilities = probabilities * self.pdfs[i](X_scaled[:, i])

        if normalize:
            return probabilities.reshape(-1, 1) / self.max + self.prior_floor
        return probabilities.reshape(-1, 1)

    def get_max_location(self, as_config=True):
        """Return the mode as a one-row array, or as a one-element list of configs."""
        if as_config:
            return self._as_config(self.mode.reshape(1, -1))
        return self.mode.reshape(1, -1)

    def get_max(self):
        """Return the unnormalized density at the mode as a scalar."""
        return self.max[0][0]

    def get_min(self):
        """Return the floor added to normalized densities."""
        return self.prior_floor

class PriorEI(EI):
    """Expected Improvement weighted by the prior raised to a decaying power.

    The acquisition value is ``prior(x) ** power * EI(x)``. ``power`` is
    ``beta / t``, where ``t`` counts ``update`` calls. With ``dim_decay`` it is
    ``beta / t ** log(1 + dims)`` instead.
    """

    def __init__(self,
                 model: AbstractEPM,
                 prior: Prior,
                 beta=1,
                 par: float = 0.0,
                 dim_decay=False,
                 discretize=True
                 ):
        """Constructor

        Parameters
        ----------
        model : AbstractEPM
            A model that implements at least
                 - predict_marginalized_over_instances(X)
        prior : Prior
            The user-defined prior, either as a set of point probabilities or an
            interpolated Gaussian KDE. Either way, it must implement
            ``__call__`` on normalized input.
        beta : float, default=1
            Numerator of the prior exponent ``beta / t``.
        par : float, default=0.0
            Controls the balance between exploration and exploitation of the
            acquisition function.
        dim_decay : bool, default=False
            If True, the exponent decays as ``t ** log(1 + dims)`` instead of ``t``.
        discretize : bool, default=True
            Currently unused.
        """
        super(PriorEI, self).__init__(model)
        self.long_name = 'Prior-guided Expected Improvement'
        self.par = par
        self.eta = None
        self._required_updates = ('model', 'eta')
        self.prior = prior
        self.beta = beta
        # Number of update() calls so far; drives the decay of the prior exponent.
        self.t = 0
        self.dim_decay = dim_decay

    def update(self, **kwargs):
        """Forward the update to the base class and advance the iteration counter."""
        # super(EI, self) starts the MRO lookup *after* EI, so any EI.update
        # override is bypassed.
        super(EI, self).update(**kwargs)
        self.t += 1

    def _compute(self, X: np.ndarray) -> np.ndarray:
        """Return ``prior(X) ** power * EI(X)`` for normalized inputs ``X``."""
        ei_X = super(PriorEI, self)._compute(X)
        power = self.compute_power(X)
        # here, the X input is normalized. For the __call__ method in Prior, this needs to be taken into account
        return np.power(self.prior(X), power) * ei_X

    def _compute_EI(self, X: np.ndarray) -> np.ndarray:
        """Return the plain (unweighted) EI of the base class."""
        return super(PriorEI, self)._compute(X)

    def _compute_prior(self, X: np.ndarray) -> np.ndarray:
        """Return the prior raised to the current exponent."""
        power = self.compute_power(X)
        return np.power(self.prior(X), power)

    def compute_power(self, X):
        """Return the exponent applied to the prior at the current iteration.

        Call this only after ``update``: while ``t == 0`` it divides by zero.
        """
        if self.dim_decay:
            return self.beta / (self.t ** np.log(1 + X.shape[1]))
        return self.beta / self.t

    def compute_and_save(self):
        """Dump the acquisition, EI and prior on a 1-D grid to CSV for debugging.

        This only runs when the second command-line argument is
        ``'bran_debug'``; otherwise it does nothing. It writes
        ``plot_data/beta{beta}iter{t}.csv`` and creates ``plot_data/`` if
        needed. The grid has a single column, so this presumably only works
        for one-dimensional priors.
        """
        # Guard the argv lookup: scripts run with fewer arguments must not crash.
        if len(sys.argv) <= 2 or sys.argv[2] != 'bran_debug':
            return
        res = {}
        X = np.linspace(0, 1, 1001).reshape(-1, 1)
        res['comb'] = self._compute(X).reshape(-1)
        res['ei'] = self._compute_EI(X).reshape(-1)
        res['prior'] = self._compute_prior(X).reshape(-1)
        # NOTE(review): 'prior' already has compute_power applied, so 'pp'
        # applies the exponent a second time.
        res['pp'] = np.power(res['prior'], self.compute_power(X)).reshape(-1)
        df = pd.DataFrame(res, index=list(range(len(res['comb']))))
        os.makedirs('plot_data', exist_ok=True)
        df.to_csv(f'plot_data/beta{self.beta}iter{self.t}.csv')


class PriorLogEI(LogEI):
    """Log Expected Improvement weighted by an (optionally binned) powered prior.

    With ``space_fraction == 0`` the acquisition value is
    ``prior(x) ** power * LogEI(x)``. Otherwise the unnormalized prior raised
    to ``power`` is first snapped to the upper edge of one of a shrinking
    number of bins (see ``_binned_prior``).
    """

    def __init__(self,
                 model: AbstractEPM,
                 prior: Prior,
                 beta=1,
                 par: float = 0.0,
                 dim_decay=False,
                 discretize=True,
                 space_fraction=50
                 ):
        """Constructor

        Parameters
        ----------
        model : AbstractEPM
            A model that implements at least
                 - predict_marginalized_over_instances(X)
        prior : Prior
            The user-defined prior, either as a set of point probabilities or an
            interpolated Gaussian KDE. Either way, it must implement
            ``__call__``.
        beta : float, default=1
            Numerator of the prior exponent ``beta / t``.
        par : float, default=0.0
            Controls the balance between exploration and exploitation of the
            acquisition function.
        dim_decay : bool, default=False
            If True, the exponent decays as ``t ** log(1 + dims)`` instead of ``t``.
        discretize : bool, default=True
            Currently unused.
        space_fraction : float, default=50
            Scales the number of prior bins, ``ceil(space_fraction * beta / t + 1)``
            edges. 0 disables binning.
        """
        super(PriorLogEI, self).__init__(model)
        self.long_name = 'Prior-guided Log Expected Improvement'
        self.par = par
        self.eta = None
        self._required_updates = ('model', 'eta')
        self.prior = prior
        self.beta = beta
        # Number of update() calls so far; drives the decay of the prior exponent.
        self.t = 0
        self.dim_decay = dim_decay
        self.space_fraction = space_fraction
        # Pads the top bin edge so the prior's maximum falls inside the last bin.
        self.SMALL_NUMBER = 1e-3

    def update(self, **kwargs):
        """Forward the update to the base class and advance the iteration counter."""
        # super(LogEI, self) starts the MRO lookup *after* LogEI, so any
        # LogEI.update override is bypassed.
        super(LogEI, self).update(**kwargs)
        self.t += 1

    def _binned_prior(self, X, power):
        """Return ``prior(X, normalize=False) ** power`` snapped to bin edges.

        The edges are ``linspace(get_min() ** power, get_max() ** power + SMALL_NUMBER, k)[1:]``
        with ``k = ceil(space_fraction * beta / t + 1)``, so the number of bins
        shrinks as ``t`` grows. Each value is replaced by the upper edge of the
        bin it falls into. Values at or beyond the last edge are clamped to it.
        """
        # NOTE(review): get_min() is the floor used on the *normalized* scale,
        # while the values binned here are unnormalized.
        n_edges = np.ceil(self.space_fraction * self.beta / self.t + 1).astype(int)
        bins = np.linspace(np.power(self.prior.get_min(), power),
                           np.power(self.prior.get_max(), power) + self.SMALL_NUMBER,
                           n_edges)[1:]
        # adaptive binning : np.ceil(self.space_fraction*self.beta/self.t).astype(int)
        vals = np.digitize(np.power(self.prior(X, normalize=False), power), bins)
        # digitize returns len(bins) for values >= the last edge, which would
        # index past the end of `bins`. Clamp those to the top bin instead.
        vals = np.minimum(vals, len(bins) - 1)
        return bins[vals.reshape(-1)].reshape(-1, 1)

    def _compute(self, X: np.ndarray) -> np.ndarray:
        """Return the (binned) powered prior times LogEI for normalized ``X``."""
        ei_X = super(PriorLogEI, self)._compute(X)
        negative = ei_X[ei_X < 0]
        if negative.size > 0:
            print(negative)
        power = self.compute_power(X)
        # here, the X input is normalized. For the __call__ method in Prior, this needs to be taken into account
        if self.space_fraction != 0:
            return self._binned_prior(X, power) * ei_X
        return np.power(self.prior(X), power) * ei_X

    def _compute_EI(self, X: np.ndarray) -> np.ndarray:
        """Return the plain (unweighted) LogEI of the base class."""
        return super(PriorLogEI, self)._compute(X)

    def _compute_prior(self, X: np.ndarray) -> np.ndarray:
        """Return the prior weight used by ``_compute`` (binned if enabled)."""
        power = self.compute_power(X)
        if self.space_fraction != 0:
            return self._binned_prior(X, power)
        return np.power(self.prior(X), power)

    def compute_power(self, X):
        """Return the exponent applied to the prior at the current iteration.

        Call this only after ``update``: while ``t == 0`` it divides by zero.
        """
        if self.dim_decay:
            return self.beta / (self.t ** np.log(1 + X.shape[1]))
        return self.beta / self.t

    def compute_and_save(self):
        """Dump the acquisition, LogEI and prior on a 1-D grid to CSV for debugging.

        This only runs when the second command-line argument is
        ``'bran_debug'``; otherwise it does nothing. It writes
        ``plot_data/beta{beta}iter{t}.csv`` and creates ``plot_data/`` if
        needed. The grid has a single column, so this presumably only works
        for one-dimensional priors.
        """
        # Guard the argv lookup: scripts run with fewer arguments must not crash.
        if len(sys.argv) <= 2 or sys.argv[2] != 'bran_debug':
            return
        res = {}
        X = np.linspace(0, 1, 1001).reshape(-1, 1)
        res['comb'] = self._compute(X).reshape(-1)
        res['ei'] = self._compute_EI(X).reshape(-1)
        res['prior'] = self._compute_prior(X).reshape(-1)
        # NOTE(review): 'prior' already has compute_power applied, so 'pp'
        # applies the exponent a second time.
        res['pp'] = np.power(res['prior'], self.compute_power(X)).reshape(-1)
        df = pd.DataFrame(res, index=list(range(len(res['comb']))))
        os.makedirs('plot_data', exist_ok=True)
        df.to_csv(f'plot_data/beta{self.beta}iter{self.t}.csv')




class PriorLCB(LCB):
    """Lower Confidence Bound whose uncertainty part is weighted by a powered prior.

    ``_compute`` combines the two parts returned by the base class's
    ``_compute(X, in_parts=True)`` as ``prior(X) ** power * std_part + mean_part``.
    """

    def __init__(self,
                 model: AbstractEPM,
                 prior: Prior,
                 beta=1,
                 par: float = 10.0,
                 dim_decay=False
                 ):
        """Constructor

        Parameters
        ----------
        model : AbstractEPM
            A model that implements at least
                 - predict_marginalized_over_instances(X)
        prior : Prior
            The user-defined prior, either as a set of point probabilities or an
            interpolated Gaussian KDE. Either way, it must implement
            ``__call__``.
        beta : float, default=1
            Numerator of the prior exponent ``beta / t``.
        par : float, default=10.0
            NOTE(review): currently ignored; ``self.par`` is always set to 1.
        dim_decay : bool, default=False
            If True, the exponent decays as ``t ** log(1 + dims)`` instead of ``t``.
        """
        super(PriorLCB, self).__init__(model)
        self.long_name = 'Prior-guided Lower Confidence Bound'
        self.eta = None
        self._required_updates = ('model', 'num_data')
        self.prior = prior
        self.beta = beta
        self.par = 1
        # Number of update() calls so far; drives the decay of the prior exponent.
        self.t = 0
        self.dim_decay = dim_decay
        self.dims = self.prior.dims
        # Random points in the unit cube (global NumPy RNG), used by update() to
        # estimate the minimum of the plain LCB for the debug dump.
        self.point_grid = np.random.uniform(size=(20000, self.dims))

    def update(self, **kwargs):
        """Forward the update, advance ``t`` and refresh ``self.min`` on the grid."""
        # super(LCB, self) starts the MRO lookup *after* LCB, so any LCB.update
        # override is bypassed.
        super(LCB, self).update(**kwargs)
        self.t += 1
        self.min = self._compute_EI(self.point_grid).min()
        print(self.min)

    def _compute(self, X: np.ndarray) -> np.ndarray:
        """Return ``prior(X) ** power * std_part + mean_part`` for normalized ``X``."""
        mean_X, std_X = super(PriorLCB, self)._compute(X, in_parts=True)
        power = self.compute_power(X)
        # here, the X input is normalized. For the __call__ method in Prior, this needs to be taken into account
        return np.power(self.prior(X), power) * std_X + mean_X

    def _compute_EI(self, X: np.ndarray) -> np.ndarray:
        """Return the plain (unweighted) LCB of the base class."""
        return super(PriorLCB, self)._compute(X)

    def _compute_prior(self, X: np.ndarray) -> np.ndarray:
        """Return the prior raised to the current exponent."""
        power = self.compute_power(X)
        return np.power(self.prior(X), power)

    def compute_power(self, X):
        """Return the exponent applied to the prior at the current iteration.

        Call this only after ``update``: while ``t == 0`` it divides by zero.
        """
        if self.dim_decay:
            return self.beta / (self.t ** np.log(1 + X.shape[1]))
        return self.beta / self.t

    def compute_and_save(self):
        """Dump the acquisition, shifted LCB and prior on a 1-D grid to CSV.

        This only runs when the second command-line argument is
        ``'bran_debug'``; otherwise it does nothing. It writes
        ``plot_data/beta{beta}iter{t}.csv`` and creates ``plot_data/`` if
        needed. It relies on ``self.min``, which is set by ``update``. The grid
        has a single column, so this presumably only works for one-dimensional
        priors.
        """
        # Guard the argv lookup: scripts run with fewer arguments must not crash.
        if len(sys.argv) <= 2 or sys.argv[2] != 'bran_debug':
            return
        res = {}
        X = np.linspace(0, 1, 10001).reshape(-1, 1)
        res['comb'] = self._compute(X).reshape(-1)
        res['ei'] = self._compute_EI(X).reshape(-1) - self.min
        res['prior'] = self._compute_prior(X).reshape(-1)
        # NOTE(review): 'prior' already has compute_power applied, so 'pp'
        # applies the exponent a second time.
        res['pp'] = np.power(res['prior'], self.compute_power(X)).reshape(-1)
        df = pd.DataFrame(res, index=list(range(len(res['comb']))))
        os.makedirs('plot_data', exist_ok=True)
        df.to_csv(f'plot_data/beta{self.beta}iter{self.t}.csv')


class PriorPI(PI):
    """Probability of Improvement weighted by the prior raised to a decaying power.

    The acquisition value is ``prior(x) ** power * PI(x)``. ``power`` is
    ``beta / t``, where ``t`` counts ``update`` calls. With ``dim_decay`` it is
    ``beta / t ** log(1 + dims)`` instead.
    """

    def __init__(self,
                 model: AbstractEPM,
                 prior: Prior,
                 beta=1,
                 par: float = 0.0,
                 dim_decay=False,
                 discretize=True
                 ):
        """Constructor

        Parameters
        ----------
        model : AbstractEPM
            A model that implements at least
                 - predict_marginalized_over_instances(X)
        prior : Prior
            The user-defined prior, either as a set of point probabilities or an
            interpolated Gaussian KDE. Either way, it must implement
            ``__call__``.
        beta : float, default=1
            Numerator of the prior exponent ``beta / t``.
        par : float, default=0.0
            Controls the balance between exploration and exploitation of the
            acquisition function.
        dim_decay : bool, default=False
            If True, the exponent decays as ``t ** log(1 + dims)`` instead of ``t``.
        discretize : bool, default=True
            Currently unused.
        """
        super(PriorPI, self).__init__(model)
        self.long_name = 'Prior-guided Probability of Improvement'
        # Previously dropped on the floor; store it like PriorEI does.
        self.par = par
        self._required_updates = ('model', 'eta')
        self.prior = prior
        self.beta = beta
        # Number of update() calls so far; drives the decay of the prior exponent.
        self.t = 0
        self.dim_decay = dim_decay

    def update(self, **kwargs):
        """Forward the update to the base class and advance the iteration counter."""
        # super(PI, self) starts the MRO lookup *after* PI, so any PI.update
        # override is bypassed.
        super(PI, self).update(**kwargs)
        self.t += 1

    def _compute(self, X: np.ndarray) -> np.ndarray:
        """Return ``prior(X) ** power * PI(X)`` for normalized inputs ``X``."""
        pi_X = super(PriorPI, self)._compute(X)
        power = self.compute_power(X)
        # here, the X input is normalized. For the __call__ method in Prior, this needs to be taken into account
        return np.power(self.prior(X), power) * pi_X

    def _compute_EI(self, X: np.ndarray) -> np.ndarray:
        """Return the plain (unweighted) PI of the base class."""
        return super(PriorPI, self)._compute(X)

    def _compute_prior(self, X: np.ndarray) -> np.ndarray:
        """Return the prior raised to the current exponent."""
        power = self.compute_power(X)
        return np.power(self.prior(X), power)

    def compute_power(self, X):
        """Return the exponent applied to the prior at the current iteration.

        Call this only after ``update``: while ``t == 0`` it divides by zero.
        """
        if self.dim_decay:
            return self.beta / (self.t ** np.log(1 + X.shape[1]))
        return self.beta / self.t

    def compute_and_save(self):
        """Dump the acquisition, PI and prior on a 1-D grid to CSV for debugging.

        This only runs when the second command-line argument is
        ``'bran_debug'``; otherwise it does nothing. It writes
        ``plot_data/beta{beta}iter{t}.csv`` and creates ``plot_data/`` if
        needed. The grid has a single column, so this presumably only works
        for one-dimensional priors.
        """
        # Guard the argv lookup: scripts run with fewer arguments must not crash.
        if len(sys.argv) <= 2 or sys.argv[2] != 'bran_debug':
            return
        res = {}
        X = np.linspace(0, 1, 1001).reshape(-1, 1)
        res['comb'] = self._compute(X).reshape(-1)
        res['ei'] = self._compute_EI(X).reshape(-1)
        res['prior'] = self._compute_prior(X).reshape(-1)
        # NOTE(review): 'prior' already has compute_power applied, so 'pp'
        # applies the exponent a second time.
        res['pp'] = np.power(res['prior'], self.compute_power(X)).reshape(-1)
        df = pd.DataFrame(res, index=list(range(len(res['comb']))))
        os.makedirs('plot_data', exist_ok=True)
        df.to_csv(f'plot_data/beta{self.beta}iter{self.t}.csv')

