from localglobal.bo.localbo_utils import ordinal2onehot, onehot2ordinal
from localglobal.test_funcs.base import TestFunction
from copy import deepcopy
import numpy as np
from utils import get_dim_info


class ModifiedObjectiveFunc:
    """Wrap a test function for optimisers that have no special treatment for categorical variables.

    For non-continuous problems the categorical part of the input is expected in a
    one-hot layout occupying the first ``one_hot_dims`` columns; on every call it is
    passed through ``onehot2ordinal`` before the wrapped function is evaluated.
    Any attribute not defined on the wrapper is looked up on the wrapped function.

    Attributes set by ``__init__``:
        problem_type: copied from ``f.problem_type``.
        f: the wrapped function.
        one_hot_dims, lb, ub: only set when ``problem_type != 'continuous'``.
            ``lb``/``ub`` are 0/1 for every one-hot column, followed by ``f.lb``/``f.ub``
            when the wrapped function defines them.
    """

    def __init__(self, f: 'TestFunction'):
        """Store ``f`` and, for non-continuous problems, build the one-hot bounds.

        :param f: the function to wrap; must expose ``problem_type`` and, unless the
            problem is continuous, ``n_vertices``.
        """
        self.problem_type = f.problem_type
        self.f = f
        if self.problem_type != 'continuous':
            # Total width of the one-hot block: the sum over f.n_vertices.
            self.one_hot_dims = np.sum(f.n_vertices)
            if hasattr(f, 'lb'):
                self.lb = np.hstack((np.zeros(self.one_hot_dims), f.lb))
            else:
                self.lb = np.zeros(self.one_hot_dims)
            if hasattr(f, 'ub'):
                self.ub = np.hstack((np.ones(self.one_hot_dims), f.ub))
            else:
                self.ub = np.ones(self.one_hot_dims)

    def __call__(self, X, **kwargs):
        """Evaluate the wrapped function at ``X``.

        :param X: array with an ``ndim`` attribute; a 1-d input is reshaped to a single row.
        :param kwargs: accepted for interface compatibility.
            NOTE(review): these are currently not forwarded to ``self.f`` - confirm intent.
        :return: whatever ``self.f`` returns for the converted input.
        :raises ValueError: if ``problem_type`` is not 'categorical', 'mixed' or 'continuous'.
        """
        if X.ndim == 1:
            X = X.reshape(1, -1)
        if self.problem_type == 'categorical':
            X_ = onehot2ordinal(X, get_dim_info(self.f.config))
        elif self.problem_type == 'mixed':
            # Layout: [one-hot categorical columns | continuous columns].
            X_cont = X[:, self.one_hot_dims:]
            X_cat = onehot2ordinal(X[:, :self.one_hot_dims], get_dim_info(self.f.config))
            X_ = np.hstack((X_cat, X_cont))
        elif self.problem_type == 'continuous':
            X_ = X
        else:
            raise ValueError(self.problem_type)
        return self.f(X_)

    def __getattr__(self, item):
        """Delegate lookups that failed on the wrapper to the wrapped function.

        :raises AttributeError: if neither the wrapper nor the wrapped function has ``item``.
        """
        # __getattr__ only runs after normal lookup has failed. If 'f' itself is the
        # missing attribute (e.g. an instance created through __new__ while being
        # copied or unpickled), reading self.f below would re-enter this method and
        # recurse until RecursionError. Fail fast with a plain AttributeError instead.
        if item == 'f':
            raise AttributeError(item)
        try:
            return getattr(self.f, item)
        except AttributeError:
            raise AttributeError(item + ' is not a valid attribute for either the modified obj class or the class '
                                        'derived.')

    def parse_bound(self):
        """Parse the bound into a format understood by Ax for experiments using HeSBO and ALEBO.

        :return: a list with one dict per one-hot dimension, each holding the keys
            'name' (``x_<index>``), 'type', 'bounds' (``[lb, ub]``) and 'value_type'.

        NOTE(review): only the first ``one_hot_dims`` entries of ``lb``/``ub`` are used, so
        the continuous dimensions of a mixed problem are not included - confirm this is intended.
        """
        return [
            {
                'name': f'x_{dim}',
                'type': 'range',
                'bounds': [self.lb[dim], self.ub[dim]],
                'value_type': 'float',
            }
            for dim in range(int(self.one_hot_dims))
        ]
