
import json
import os
from os.path import dirname, join
import sys
from functools import partial
import numpy as np
from scipy.stats import norm
from scipy.linalg import norm as eucl # euclidian distance
import pandas as pd
from prior_acq import Prior
from warping import WarpedSpace
from benchmarks.branin.branin import setup_branin
from benchmarks.branin.branin_debug import setup_branin as setup_branin_debug
from smac.configspace import ConfigurationSpace
from ConfigSpace.hyperparameters import UniformFloatHyperparameter
emukit_dir = os.path.join(os.getcwd(), 'benchmarks/emukit/')
sys.path.append(emukit_dir)
from benchmarks.emukit.emukit.examples.profet.meta_benchmarks import (meta_forrester, meta_fcnet, meta_svm, meta_xgboost)


# Sets up the function, retrieves the prior and the parameters for all benchmarks and priors
def define_function_and_prior(function_id, function_family, path_to_files, method, prior_strength='strong', prior_index=0):
    """
    Get the black box function, its parameter space and the prior description of a benchmark.

    Families handled here: branin, bran_debug, svm, xgboost, fcnet, unet.
    (meta_forrester is imported at file level but has no branch in this function.)

    Args:
        function_id: integer index of the sampled task, used to build the sample file names.
        function_family: name of the benchmark family (see the list above).
        path_to_files: directory that holds the ``samples/<family>/`` pickle files.
        method: not used by this function; kept so existing call sites keep working.
        prior_strength: sub-directory of the prior to load, or 'noprior' to load none.
        prior_index: index of the prior file (``prior<index>.json``).

    Returns:
        Tuple ``(fcn, parameter_space, prior_config)``. ``prior_config`` is the parsed
        JSON content of the prior file, or None when prior_strength is 'noprior'.

    Raises:
        ValueError: if function_family is not one of the families handled here.
    """
    fname_objective = "%s/samples/%s/sample_objective_%d.pkl" % (path_to_files, function_family, function_id)
    fname_cost = "%s/samples/%s/sample_cost_%d.pkl" % (path_to_files, function_family, function_id)
    # the prior path is relative to the current working directory
    prior_data_path = f'noisy_priors/{function_family}/{prior_strength}/prior{prior_index}.json'

    if prior_strength != 'noprior':
        with open(prior_data_path, 'r') as f:
            prior_config = json.load(f)
    else:
        prior_config = None

    if function_family == "branin":
        fcn, parameter_space = setup_branin()
    elif function_family == "bran_debug":
        fcn, parameter_space = setup_branin_debug()
    elif function_family == "svm":
        fcn, parameter_space = meta_svm(fname_objective=fname_objective, fname_cost=fname_cost, noise=False)
    elif function_family == "xgboost":
        fcn, parameter_space = meta_xgboost(fname_objective=fname_objective, fname_cost=fname_cost, noise=False)
    elif function_family == "fcnet":
        fcn, parameter_space = meta_fcnet(fname_objective=fname_objective, fname_cost=fname_cost, noise=False)
    elif function_family == "unet":
        # NOTE(review): setup_unet is not imported in the visible part of this file -
        # confirm it is defined somewhere, otherwise this branch raises NameError
        fcn, parameter_space = setup_unet()
    else:
        # without this branch an unknown family ends in an UnboundLocalError at the return
        raise ValueError(
            f"Unknown function_family '{function_family}', expected one of: "
            "branin, bran_debug, svm, xgboost, fcnet, unet"
        )

    return fcn, parameter_space, prior_config


# Moves the means of the prior to create a reasonable noise level - either as a set percentage of the
# parameter range away from the optimum ('warping', implemented in Prior BO with Input Warping, the
# competing paper) or as in BOPrO, where the noise is sampled from a normal distribution centered at
# the optimum.
# Writes the offset means into the given prior dictionary (in place) and returns that dictionary.
def move_prior_means(prior_dict, method, noise_level=0):
    """
    Offset the mean of every parameter prior by a random amount that stays inside its range.

    Offsets are redrawn until every new mean lies within the closed range of its parameter.
    NOTE: there is no retry limit, so the loop does not end if no admissible offset exists
    (for example an original mean outside its range combined with zero noise).

    Args:
        prior_dict: mapping from parameter name to a dict with the entries
            ``['params']['mean'][0]`` and ``['range']`` (lower, upper); for method 'bopro'
            also ``['noise']``. Parameters are processed in sorted key order.
        method: 'warping' moves the means by ``noise_level`` times the parameter range along a
            random unit direction; 'bopro' adds standard normal noise scaled by each
            parameter's ``'noise'`` entry.
        noise_level: size of the move for 'warping'; not used by 'bopro'.

    Returns:
        The same ``prior_dict`` object; its ``['params']['mean']`` entries are overwritten
        in place with one-element lists holding the new means.

    Raises:
        ValueError: if method is neither 'warping' nor 'bopro'.
    """
    if method not in ('warping', 'bopro'):
        # previously an unknown method ended in an UnboundLocalError inside the loop
        raise ValueError(f"Unknown method '{method}', expected 'warping' or 'bopro'")

    keys = sorted(prior_dict)
    dims = len(keys)
    means = np.array([prior_dict[key]['params']['mean'][0] for key in keys], dtype=float)
    ranges = np.array([prior_dict[key]['range'] for key in keys], dtype=float).reshape(dims, 2)
    lower, upper = ranges[:, 0], ranges[:, 1]

    if method == 'bopro':
        # only BOPrO-style noise needs the per-parameter noise entry
        noise = np.array([prior_dict[key]['noise'] for key in keys], dtype=float)
    else:
        param_scale = upper - lower

    while True:
        if method == 'warping':
            move_direction = np.random.uniform(-1, 1, size=dims)
            # unit direction, scaled per dimension by the width of the parameter range
            movements = noise_level * move_direction / eucl(move_direction) * param_scale
        else:
            movements = noise * norm.rvs(size=dims)

        new_means = means + movements
        # accept once no mean is strictly outside its range (same test as the
        # element-wise `mean < low or mean > high` check, NaN included)
        if not np.any((new_means < lower) | (new_means > upper)):
            break

    for key, new_mean in zip(keys, new_means):
        prior_dict[key]['params']['mean'] = [new_mean]

    return prior_dict


# This is a rather confusing one - it first calls define_function_and_prior to get
# the right benchmark, parameters and priors.
# Then, it defines the black box function wrapper around the benchmark. This is needed
# because the emukit benchmarks take a numpy array as input and also return a cost,
# which we're not interested in.
# Lastly, we define the prior and the configuration space and return all relevant things
# to start_smac
def setup_experiment(function_id, function_family, method='bopro', prior_strength='strong', prior_index=0, smac_run=False):
    """
    Build the configuration space, the prior and the objective wrapper for one experiment.

    Args:
        function_id: index of the sampled task, forwarded to define_function_and_prior.
        function_family: benchmark family, forwarded to define_function_and_prior.
        method: 'warping' optimizes through a WarpedSpace over the unit cube; every other
            value takes the prior-based (BOPrO) path.
        prior_strength: forwarded to define_function_and_prior ('noprior' gives no prior).
        prior_index: forwarded to define_function_and_prior.
        smac_run: not used by this function.

    Returns:
        Tuple ``(cs, prior, bb_function)``: the ConfigurationSpace to optimize over, the
        Prior (None for 'warping' and when no prior config was loaded) and a callable
        that takes a configuration and returns the objective value as a float.
    """
    path_to_files = join(dirname(os.getcwd()), 'profet_data') # update here
    fcn, parameter_space, prior_config = define_function_and_prior(
        function_id,
        function_family,
        path_to_files,
        method=method,
        prior_strength=prior_strength,
        prior_index=prior_index)

    # here, we check the method - but do not act on the prior config, since that's rather inconsistent
    if method == 'warping':
        dist = WarpedSpace.create_distribution(prior_config)

        # To WarpedSpace, we pass the distribution created and the black-box function;
        # to evaluate, we just do __call__ on warped space
        warped_space = WarpedSpace(dist, parameter_space.get_bounds(), fcn)

        # in the warping case the optimizer searches the unit interval in every dimension
        cs = ConfigurationSpace()
        cs.add_hyperparameters([
            UniformFloatHyperparameter("x" + str(param_idx), 0, 1, default_value=0.5)
            for param_idx in range(len(parameter_space.get_bounds()))
        ])

        return cs, None, _array_objective(warped_space)

    # every other method takes the BOPrO path
    # this includes the HPO benchmark and spatial
    # still unclear how the spatial stuff has to be formatted
    if isinstance(parameter_space, ConfigurationSpace):
        # SMAC is going to optimize over the exact same parameter space
        # The tricky part is what we need to give the prior
        # but that is defined in the prior json and passed to process_prior
        cs = parameter_space
        bb_function = _config_objective(fcn)
    else:
        # This is the regular BB definition for emukit and branin
        bb_function = _array_objective(fcn)

        # one uniform float parameter per bound pair, defaulting to the middle of the range
        cs = ConfigurationSpace()
        cs.add_hyperparameters([
            UniformFloatHyperparameter("x" + str(param_idx), lower, upper, default_value=(lower+upper) / 2)
            for param_idx, (lower, upper) in enumerate(parameter_space.get_bounds())
        ])

    if prior_config is not None:
        names, types, pdfs, sampling_funcs, ranges, modes, values = process_prior(prior_config)
        prior = Prior(cs, names, types, pdfs, sampling_funcs, ranges, modes, values)
    else:
        prior = None

    return cs, prior, bb_function


def _array_objective(evaluate):
    """Wrap a callable that takes a (1, d) numpy array so it accepts a configuration object."""
    def bb_function(config):
        config = config.get_dictionary()
        # values are ordered by sorted parameter name
        X = np.array([config[name] for name in sorted(config.keys())]).reshape(1, -1)
        result = evaluate(X)
        try:
            y, c = result # function value, cost for all functions except forrester
        except ValueError:
            y = result

        return float(y[0,0])

    return bb_function


def _config_objective(evaluate):
    """Wrap a callable that takes the configuration itself and may also return a cost."""
    def bb_function(config):
        result = evaluate(config)
        try:
            y, c = result # function value, cost for all functions except forrester
        except TypeError:
            y = result
        return float(y)

    return bb_function

def process_prior(prior_file):
    """
    Turn a prior description into the per-parameter lists the Prior is built from.

    Parameters are processed in sorted key order, as SMAC does the same for its configspace.
    Entries without a 'dist' key are treated as 'gaussian'; that default is also written
    back into ``prior_file``.

    Args:
        prior_file: dict mapping parameter name to its prior description. Depending on
            'dist' the entry is read as follows:
            - 'gaussian': ``['params']['mean'][0]``, ``['params']['std'][0]``, ``['range']``
            - 'categorical': ``['params']['values']``, ``['params']['probs']``
            - 'integer': ``['range']`` (lower, upper) and ``['params']['probs']`` with one
              probability per integer in the inclusive range

    Returns:
        Tuple ``(names, types, pdfs, sampling_funcs, ranges, modes, values)``. All are lists
        with one entry per parameter, except ``ranges`` which is converted to a numpy array.

    Raises:
        ValueError: if a parameter has a 'dist' other than the three listed above.
        AssertionError: if the number of probabilities does not match the number of values.
    """

    def from_array(X, probs=None, range_offset=None):
        # if there are no values to consider, only return the element in the order
        # as with categoricals
        if range_offset is None:
            return probs[np.round(X).astype(int)]

        # otherwise, consider the value of the element and return the probability
        # in that spot
        return probs[np.round(X).astype(int) - range_offset]

    supported = ('gaussian', 'categorical', 'integer')

    names = []
    types = []
    pdfs = []
    sampling_funcs = []
    ranges = []
    modes = []
    values = []

    for key in sorted(prior_file.keys()):
        spec = prior_file[key]
        # a missing distribution type is assumed to be gaussian (and stored as such)
        dist = spec.setdefault('dist', 'gaussian')
        if dist not in supported:
            # an unknown type used to be added to names/types only, which left the
            # other lists one entry short
            raise ValueError(
                f"{key} has unsupported prior distribution '{dist}', expected one of {supported}"
            )

        names.append(key)
        types.append(dist)

        if dist == 'gaussian':
            mean = spec['params']['mean'][0]
            std = spec['params']['std'][0]
            frozen = norm(mean, std)
            values.append(None)
            ranges.append(spec['range'])
            pdfs.append(frozen.pdf)
            sampling_funcs.append(frozen.rvs)
            modes.append(mean)

        # don't use the actual values, but an integer array with the same length
        # only used for sampling anyway
        # give a range of 0,1 to prior so that it doesn't scale the input
        elif dist == 'categorical':
            values.append(spec['params']['values'])
            probs = np.array(spec['params']['probs'])
            assert len(values[-1]) == len(probs),\
                f'{key} has values of length {len(values[-1])}\nValues:{values[-1]}\nand probabilities of length{len(probs)}\nProbabilities{probs}'

            ranges.append([0, 1])
            modes.append(np.argmax(probs))
            # partial binds this key's probs right away; a lambda closing over `probs`
            # would look it up at call time and see the array of the last key processed
            pdfs.append(partial(from_array, probs=probs))
            sampling_funcs.append(partial(np.random.choice, len(probs), p=probs))

        # used for integer when you want to set the probabilities yourself
        else:
            lower, upper = spec['range'][0], spec['range'][1]
            values.append(list(range(lower, upper+1)))
            probs = np.array(spec['params']['probs'])
            assert len(values[-1]) == len(probs),\
                f'{key} has values of length {len(values[-1])}\nValues:{values[-1]}\nand probabilities of length{len(probs)}\nProbabilities{probs}'

            ranges.append([lower, upper])
            modes.append(np.argmax(probs) + lower)

            # the pdf is indexed by value, so shift by the lower end of the range
            pdfs.append(partial(from_array, probs=probs, range_offset=lower))
            # NOTE(review): this samples an index in [0, len(probs)), not a value in
            # [lower, upper] - confirm that the consumer of sampling_funcs expects that
            sampling_funcs.append(partial(np.random.choice, len(probs), p=probs))

    ranges = np.array(ranges)
    return names, types, pdfs, sampling_funcs, ranges, modes, values


