import math
from copy import deepcopy

import numpy as np
from _test_functions.objective_function import ObjectiveFunction
from _test_functions.estimate_ref_point import ref_points
from _test_functions.synthetic.dtlz import DTLZ1 as PymooDTLZ1
from _test_functions.synthetic.dtlz import DTLZ2 as PymooDTLZ2
from _test_functions.synthetic.dtlz import DTLZ3 as PymooDTLZ3
from _test_functions.synthetic.dtlz import DTLZ4 as PymooDTLZ4
from _test_functions.synthetic.dtlz import DTLZ7 as PymooDTLZ7
from _test_functions.synthetic.oka import OKA1, OKA2
from _test_functions.synthetic.zdt import ZDT1 as PymooZDT1
from _test_functions.synthetic.zdt import ZDT2 as PymooZDT2
from _test_functions.synthetic.zdt import ZDT3 as PymooZDT3
from _test_functions.synthetic.vlmop import VLMOP2, VLMOP3


class BaseDTLZ(ObjectiveFunction):
    """Adapter exposing a DTLZ problem object as an ``ObjectiveFunction``.

    The wrapped ``core_function`` is read for its ``n_var``, ``n_obj``, ``xl``
    and ``xu`` attributes and is called directly on a 2-D array of inputs.
    """

    def __init__(self, base_name, core_function):
        """Store the problem and derive bounds, reference point and name.

        Args:
            base_name: Key looked up in ``ref_points``; also used as the prefix
                of ``self.name``.
            core_function: Problem object to wrap.
        """
        n_var = core_function.n_var
        n_obj = core_function.n_obj
        super().__init__(input_dims=n_var, num_objectives=n_obj, num_constraints=0)

        self.core_function = core_function
        self.num_constraints = 0

        # One (xl, xu) pair repeated once per input variable.
        lower, upper = core_function.xl, core_function.xu
        self.bounds = np.array([(lower, upper)] * n_var)

        self.input_dims = n_var
        self.num_objectives = n_obj

        # The same reference value is repeated for every objective; it is
        # None when ``base_name`` has no entry in ``ref_points``.
        reference_value = ref_points.get(base_name, None)
        self.ref_point = np.array([reference_value] * n_obj)

        self.name = f'{base_name}-d{self.input_dims}-m{self.num_objectives}'

    def evaluate_objectives(self, X_input):
        """Evaluate the wrapped problem on a copy of ``X_input``.

        Args:
            X_input: ``np.ndarray`` whose last axis has ``self.input_dims``
                entries.

        Returns:
            The output of ``core_function`` on the 2-D view of the input,
            passed through ``np.squeeze`` when the input was 1-D.

        Raises:
            ValueError: If ``X_input`` is not an ``np.ndarray``.
        """
        X = deepcopy(X_input)
        if not isinstance(X, np.ndarray):
            raise ValueError('Input type not supported')
        assert X.shape[-1] == self.input_dims, f'X is {X.shape}, expected {self.input_dims}'

        # Remember whether a single point was given before promoting to 2-D.
        single_point = X.ndim == 1
        batch = np.atleast_2d(X)
        values = self.core_function(batch)
        return np.squeeze(values) if single_point else values

    def __call__(self, X):
        """Return ``(objectives, None)``; the second slot is always ``None``."""
        objectives = self.evaluate_objectives(X)
        return objectives, None

class BaseDTLZ1(BaseDTLZ):
    """``BaseDTLZ`` wrapper around ``PymooDTLZ1`` registered as ``'dtlz1'``."""

    def __init__(self, dim, num_objectives):
        """Build the problem with ``n_var=dim`` and ``n_obj=num_objectives``."""
        problem = PymooDTLZ1(n_var=dim, n_obj=num_objectives)
        super().__init__('dtlz1', problem)

class BaseDTLZ2(BaseDTLZ):
    """``BaseDTLZ`` wrapper around ``PymooDTLZ2`` registered as ``'dtlz2'``."""

    def __init__(self, dim, num_objectives):
        """Build the problem with ``n_var=dim`` and ``n_obj=num_objectives``."""
        problem = PymooDTLZ2(n_var=dim, n_obj=num_objectives)
        super().__init__('dtlz2', problem)

class BaseDTLZ3(BaseDTLZ):
    """``BaseDTLZ`` wrapper around ``PymooDTLZ3`` registered as ``'dtlz3'``."""

    def __init__(self, dim, num_objectives):
        """Build the problem with ``n_var=dim`` and ``n_obj=num_objectives``."""
        problem = PymooDTLZ3(n_var=dim, n_obj=num_objectives)
        super().__init__('dtlz3', problem)

class BaseDTLZ4(BaseDTLZ):
    """``BaseDTLZ`` wrapper around ``PymooDTLZ4`` registered as ``'dtlz4'``."""

    def __init__(self, dim, num_objectives):
        """Build the problem with ``n_var=dim`` and ``n_obj=num_objectives``."""
        problem = PymooDTLZ4(n_var=dim, n_obj=num_objectives)
        super().__init__('dtlz4', problem)

class BaseDTLZ7(BaseDTLZ):
    """``BaseDTLZ`` wrapper around ``PymooDTLZ7`` registered as ``'dtlz7'``."""

    def __init__(self, dim, num_objectives):
        """Build the problem with ``n_var=dim`` and ``n_obj=num_objectives``."""
        problem = PymooDTLZ7(n_var=dim, n_obj=num_objectives)
        super().__init__('dtlz7', problem)

class BaseZDT(ObjectiveFunction):
    """Adapter exposing a ZDT problem object as an ``ObjectiveFunction``.

    The wrapped ``core_function`` is read for its ``n_var``, ``n_obj``, ``xl``
    and ``xu`` attributes and is called directly on a 2-D array of inputs.
    """

    def __init__(self, base_name, core_function):
        """Store the problem and derive bounds, reference point and name.

        Args:
            base_name: Key looked up in ``ref_points``; also used as the prefix
                of ``self.name``.
            core_function: Problem object to wrap.
        """
        n_var = core_function.n_var
        n_obj = core_function.n_obj
        super().__init__(input_dims=n_var, num_objectives=n_obj, num_constraints=0)

        self.core_function = core_function
        self.num_constraints = 0

        # One (xl, xu) pair repeated once per input variable.
        lower, upper = core_function.xl, core_function.xu
        self.bounds = np.array([(lower, upper)] * n_var)

        self.input_dims = n_var
        self.num_objectives = n_obj

        # The same reference value is repeated for every objective; it is
        # None when ``base_name`` has no entry in ``ref_points``.
        reference_value = ref_points.get(base_name, None)
        self.ref_point = np.array([reference_value] * n_obj)

        self.name = f'{base_name}-d{self.input_dims}-m{self.num_objectives}'

    def evaluate_objectives(self, X_input):
        """Evaluate the wrapped problem on a copy of ``X_input``.

        Args:
            X_input: ``np.ndarray`` whose last axis has ``self.input_dims``
                entries.

        Returns:
            The output of ``core_function`` on the 2-D view of the input,
            passed through ``np.squeeze`` when the input was 1-D.

        Raises:
            ValueError: If ``X_input`` is not an ``np.ndarray``.
        """
        X = deepcopy(X_input)
        if not isinstance(X, np.ndarray):
            raise ValueError('Input type not supported')
        assert X.shape[-1] == self.input_dims, f'X is {X.shape}, expected {self.input_dims}'

        # Remember whether a single point was given before promoting to 2-D.
        single_point = X.ndim == 1
        batch = np.atleast_2d(X)
        values = self.core_function(batch)
        return np.squeeze(values) if single_point else values

    def __call__(self, X):
        """Return ``(objectives, None)``; the second slot is always ``None``."""
        objectives = self.evaluate_objectives(X)
        return objectives, None
    
class BaseZDT1(BaseZDT):
    """``BaseZDT`` wrapper around ``PymooZDT1`` registered as ``'zdt1'``."""

    def __init__(self, dim, **kwargs):
        """Build the problem with ``n_var=dim``; extra keywords are ignored."""
        problem = PymooZDT1(n_var=dim)
        super().__init__('zdt1', problem)

class BaseZDT2(BaseZDT):
    """``BaseZDT`` wrapper around ``PymooZDT2`` registered as ``'zdt2'``."""

    def __init__(self, dim, **kwargs):
        """Build the problem with ``n_var=dim``; extra keywords are ignored."""
        problem = PymooZDT2(n_var=dim)
        super().__init__('zdt2', problem)

class BaseZDT3(BaseZDT):
    """``BaseZDT`` wrapper around ``PymooZDT3`` registered as ``'zdt3'``."""

    def __init__(self, dim, **kwargs):
        """Build the problem with ``n_var=dim``; extra keywords are ignored."""
        problem = PymooZDT3(n_var=dim)
        super().__init__('zdt3', problem)

class BaseOKA1(OKA1):
    # NOTE(review): this module defines ``class BaseOKA1(BaseOKA)`` again
    # further down, which rebinds the name ``BaseOKA1``; once the module has
    # finished importing, this OKA1-derived class is no longer reachable under
    # that name.
    # NOTE(review): the later definition constructs the problem as ``OKA1()``
    # with no arguments, whereas this constructor forwards
    # ``'oka1', n_var=2, n_obj=2`` -- confirm that ``OKA1.__init__`` accepts
    # these arguments, or remove this definition if it is a leftover.
    def __init__(self):
        super().__init__('oka1', n_var=2, n_obj=2)

class BaseOKA(ObjectiveFunction):
    """Adapter exposing an OKA problem object as an ``ObjectiveFunction``.

    The wrapped ``core_function`` is read for its ``n_var``, ``n_obj``, ``xl``
    and ``xu`` attributes and is called directly on a 2-D array of inputs.
    """

    def __init__(self, base_name, core_function):
        """Store the problem and derive bounds, reference point and name.

        Args:
            base_name: Key looked up in ``ref_points``; also used verbatim as
                ``self.name``.
            core_function: Problem object to wrap.
        """
        n_var = core_function.n_var
        n_obj = core_function.n_obj
        super().__init__(input_dims=n_var, num_objectives=n_obj, num_constraints=0)

        self.core_function = core_function
        self.num_constraints = 0

        # One (xl, xu) pair repeated once per input variable.
        lower, upper = core_function.xl, core_function.xu
        self.bounds = np.array([(lower, upper)] * n_var)

        self.input_dims = n_var
        self.num_objectives = n_obj

        # The ``ref_points`` entry is converted as-is (not repeated per
        # objective); a missing key yields ``np.array(None)``.
        reference_entry = ref_points.get(base_name, None)
        self.ref_point = np.array(reference_entry)

        self.name = f'{base_name}'

    def evaluate_objectives(self, X_input):
        """Evaluate the wrapped problem on a copy of ``X_input``.

        Args:
            X_input: ``np.ndarray`` whose last axis has ``self.input_dims``
                entries.

        Returns:
            The output of ``core_function`` on the 2-D view of the input,
            passed through ``np.squeeze`` when the input was 1-D.

        Raises:
            ValueError: If ``X_input`` is not an ``np.ndarray``.
        """
        X = deepcopy(X_input)
        if not isinstance(X, np.ndarray):
            raise ValueError('Input type not supported')
        assert X.shape[-1] == self.input_dims, f'X is {X.shape}, expected {self.input_dims}'

        # Remember whether a single point was given before promoting to 2-D.
        single_point = X.ndim == 1
        batch = np.atleast_2d(X)
        values = self.core_function(batch)
        return np.squeeze(values) if single_point else values

    def __call__(self, X):
        """Return ``(objectives, None)``; the second slot is always ``None``."""
        objectives = self.evaluate_objectives(X)
        return objectives, None

class BaseOKA1(BaseOKA):
    """``BaseOKA`` wrapper around a default ``OKA1()`` registered as ``'oka1'``."""

    def __init__(self, **kwargs):
        """Build the default ``OKA1`` problem; keyword arguments are ignored."""
        problem = OKA1()
        super().__init__('oka1', problem)

class BaseOKA2(BaseOKA):
    """``BaseOKA`` wrapper around a default ``OKA2()`` registered as ``'oka2'``."""

    def __init__(self, **kwargs):
        """Build the default ``OKA2`` problem; keyword arguments are ignored."""
        problem = OKA2()
        super().__init__('oka2', problem)

class BaseVLMOP(ObjectiveFunction):
    """Adapter exposing a VLMOP problem object as an ``ObjectiveFunction``.

    The wrapped ``core_function`` is read for its ``n_var``, ``n_obj``, ``xl``
    and ``xu`` attributes and is called directly on a 2-D array of inputs.
    This constructor does not assign ``ref_point``.
    """

    def __init__(self, base_name, core_function):
        """Store the problem and derive bounds and name.

        Args:
            base_name: Used verbatim as ``self.name``.
            core_function: Problem object to wrap.
        """
        n_var = core_function.n_var
        n_obj = core_function.n_obj
        super().__init__(input_dims=n_var, num_objectives=n_obj, num_constraints=0)

        self.core_function = core_function
        self.num_constraints = 0

        # One (xl, xu) pair repeated once per input variable.
        lower, upper = core_function.xl, core_function.xu
        self.bounds = np.array([(lower, upper)] * n_var)

        self.input_dims = n_var
        self.num_objectives = n_obj
        self.name = f'{base_name}'

    def evaluate_objectives(self, X_input):
        """Evaluate the wrapped problem on a copy of ``X_input``.

        Args:
            X_input: ``np.ndarray`` whose last axis has ``self.input_dims``
                entries.

        Returns:
            The output of ``core_function`` on the 2-D view of the input,
            passed through ``np.squeeze`` when the input was 1-D.

        Raises:
            ValueError: If ``X_input`` is not an ``np.ndarray``.
        """
        X = deepcopy(X_input)
        if not isinstance(X, np.ndarray):
            raise ValueError('Input type not supported')
        assert X.shape[-1] == self.input_dims, f'X is {X.shape}, expected {self.input_dims}'

        # Remember whether a single point was given before promoting to 2-D.
        single_point = X.ndim == 1
        batch = np.atleast_2d(X)
        values = self.core_function(batch)
        return np.squeeze(values) if single_point else values

    def __call__(self, X):
        """Return ``(objectives, None)``; the second slot is always ``None``."""
        objectives = self.evaluate_objectives(X)
        return objectives, None
    
class BaseVLMOP2(BaseVLMOP):
    """``BaseVLMOP`` wrapper around ``VLMOP2`` registered as ``'vlmop2'``."""

    def __init__(self, dim, **kwargs):
        """Build the problem with ``n_var=dim``; extra keywords are ignored."""
        problem = VLMOP2(n_var=dim)
        super().__init__('vlmop2', problem)
        # The 'vlmop2' entry of ``ref_points`` is repeated once per objective.
        reference_value = ref_points.get('vlmop2', None)
        self.ref_point = np.array([reference_value] * self.num_objectives)

class BaseVLMOP3(BaseVLMOP):
    """``BaseVLMOP`` wrapper around a default ``VLMOP3()`` registered as ``'vlmop3'``."""

    def __init__(self, **kwargs):
        """Build the default ``VLMOP3`` problem; keyword arguments are ignored."""
        problem = VLMOP3()
        super().__init__('vlmop3', problem)
        # The 'vlmop3' entry of ``ref_points`` is converted as-is (not repeated).
        reference_entry = ref_points.get('vlmop3', None)
        self.ref_point = np.array(reference_entry)

class BaseBraninCurrin(ObjectiveFunction):
    r"""Two-objective problem composed of the Branin and Currin functions.

    Both objectives take inputs from the unit square ``[0, 1]^2``.

    Branin (rescaled), with ``u = 15 * x_0 - 5`` and ``v = 15 * x_1``:

        f(x) = (v - 5.1 / (4 * pi ** 2) * u ** 2 + 5 / pi * u - 6) ** 2
               + 10 * (1 - 1 / (8 * pi)) * cos(u) + 10

    Currin:

        f(x) = (1 - exp(-1 / (2 * x_1))) * (
            2300 * x_0 ** 3 + 1900 * x_0 ** 2 + 2092 * x_0 + 60
        ) / (100 * x_0 ** 3 + 500 * x_0 ** 2 + 4 * x_0 + 20)

    """

    def __init__(self, **kwargs) -> None:
        """Set up the fixed 2-input, 2-objective problem; keywords are ignored."""
        super().__init__(input_dims=2, num_objectives=2, num_constraints=0)
        self.num_constraints = 0
        self.bounds = np.array([(0.0, 1.0), (0.0, 1.0)])
        self.input_dims = 2
        self.num_objectives = 2
        base_name = 'branin-currin'
        self.ref_point = np.array(ref_points.get(base_name, None))
        self.name = f'{base_name}'

    def _rescaled_branin(self, X):
        """Evaluate Branin after mapping unit-square inputs to Branin's domain."""
        # Map x_0 from [0, 1] to [-5, 10] and x_1 from [0, 1] to [0, 15].
        x_0 = 15 * X[..., 0] - 5
        x_1 = 15 * X[..., 1]
        return self._branin(np.stack([x_0, x_1], axis=-1))

    def _branin(self, X):
        """Evaluate the Branin function on inputs given in its native domain."""
        t1 = (
            X[..., 1]
            - 5.1 / (4 * math.pi**2) * X[..., 0] ** 2
            + 5 / math.pi * X[..., 0]
            - 6
        )
        t2 = 10 * (1 - 1 / (8 * math.pi)) * np.cos(X[..., 0])
        return t1**2 + t2 + 10

    @staticmethod
    def _currin(X):
        """Evaluate the Currin function on unit-square inputs."""
        x_0 = X[..., 0]
        x_1 = X[..., 1]
        # x_1 == 0 lies inside the declared bounds. There -1 / (2 * x_1) is
        # -inf and exp(-inf) is 0, so the result is finite; only the spurious
        # divide-by-zero warning is silenced.
        with np.errstate(divide='ignore'):
            factor1 = 1 - np.exp(-1 / (2 * x_1))
        numer = 2300 * np.power(x_0, 3) + 1900 * np.power(x_0, 2) + 2092 * x_0 + 60
        denom = 100 * np.power(x_0, 3) + 500 * np.power(x_0, 2) + 4 * x_0 + 20
        return factor1 * numer / denom

    def evaluate_objectives(self, X):
        """Evaluate both objectives at ``X``.

        Args:
            X: ``np.ndarray`` whose last axis has ``self.input_dims`` (2)
                entries. A 1-D array is treated as a single point.

        Returns:
            Array with the Branin and Currin values stacked on the last axis,
            passed through ``np.squeeze`` when ``X`` was 1-D.

        Raises:
            ValueError: If ``X`` is not an ``np.ndarray`` or its last axis
                does not have ``self.input_dims`` entries.
        """
        # Validate before computing so that bad input fails with the same
        # error the other wrappers in this module raise, instead of a late
        # AttributeError on ``X.ndim``.
        if not isinstance(X, np.ndarray):
            raise ValueError('Input type not supported')
        if X.ndim == 0 or X.shape[-1] != self.input_dims:
            raise ValueError(f'X is {X.shape}, expected {self.input_dims}')

        single_point = X.ndim == 1
        _X = np.atleast_2d(X)
        branin = self._rescaled_branin(X=_X)
        currin = self._currin(X=_X)
        fx = np.stack([branin, currin], axis=-1)
        return np.squeeze(fx) if single_point else fx

    def __call__(self, X):
        """Return ``(objectives, None)``; the second slot is always ``None``."""
        return (self.evaluate_objectives(X), None)

if __name__ == '__main__':
    # Smoke test: evaluate one random point and print the (objectives, None)
    # tuple returned by the wrapper.
    # NOTE(review): ``objective_name`` is never read, and ``num_objectives``
    # is absorbed by ``BaseZDT1``'s ``**kwargs`` without being used -- confirm
    # whether a DTLZ2 instance was intended here.
    objective_name = 'dtlz2'
    dim, num_objectives = 10, 3
    function = BaseZDT1(dim=dim, num_objectives=num_objectives)
    X = np.random.rand(1, dim)
    result = function(X)
    print(result)