
import numpy as np

class OKA:
    """Base class for the OKA multi-objective test problems (Okabe et al., 2004).

    Callers supply decision vectors normalized to the unit box [0, 1]^n_var.
    ``__call__`` validates them, rescales them linearly into the problem's
    native box [oka_xl, oka_xu], and passes the result to ``_evaluate``,
    which concrete subclasses implement.

    Args:
        n_var: number of decision variables (columns of the input matrix).
        n_obj: number of objectives; stored as ``self.n_obj``.
        oka_xl: lower bound(s) of the native search box (scalar or an array
            broadcastable against a row of the input).
        oka_xu: upper bound(s) of the native search box.
        **kwargs: accepted for interface compatibility and ignored.
    """

    def __init__(self, n_var, n_obj, oka_xl, oka_xu, **kwargs):
        self.n_var = n_var
        # Store the caller's value; this used to be hard-coded to 2, which
        # silently discarded the argument.
        self.n_obj = n_obj
        # Bounds of the normalized space that __call__ accepts.
        self.xl = 0.0
        self.xu = 1.0
        # Bounds of the native space that _evaluate receives.
        self.oka_xl = oka_xl
        self.oka_xu = oka_xu

    def _evaluate(self, x, *args, **kwargs):
        """Compute objectives for ``x`` given in native coordinates.

        Subclasses must override this; the base implementation raises
        ``NotImplementedError``.
        """
        raise NotImplementedError()

    def __call__(self, x, *args, **kwargs):
        """Evaluate a batch of normalized decision vectors.

        Args:
            x: 2-D array-like of shape (n_points, n_var) with every entry in
                [0, 1].
            *args, **kwargs: forwarded unchanged to ``_evaluate``.

        Returns:
            Whatever ``_evaluate`` returns for the rescaled points.

        Raises:
            AssertionError: if ``x`` is not 2-D, has the wrong number of
                columns, or has entries outside [0, 1].
        """
        # Accept lists/tuples as well as ndarrays; an ndarray passes through
        # without copying.
        x = np.asarray(x)
        assert x.ndim == 2
        assert x.shape[1] == self.n_var
        assert np.all(x >= self.xl) and np.all(x <= self.xu)
        # Scale from [0, 1] to [oka_xl, oka_xu].
        x_scaled = self.oka_xl + (self.oka_xu - self.oka_xl) * x
        # Rounding in the affine map above can land one ulp outside the
        # native box at x == 0 or x == 1, which would break an exact bound
        # check in _evaluate. Clip so the result is always inside the box.
        x_scaled = np.clip(x_scaled, self.oka_xl, self.oka_xu)
        return self._evaluate(x_scaled, *args, **kwargs)

class OKA1(OKA):
    """OKA1 bi-objective test problem with two decision variables.

    From: Okabe, Tatsuya, et al. "On test functions for evolutionary
    multi-objective optimization." International Conference on Parallel
    Problem Solving from Nature. Springer, Berlin, Heidelberg, 2004.

    Objectives are computed on the input after rotating it by pi/12.
    """

    def __init__(self, **kwargs):
        # Extra keyword arguments are accepted and ignored.
        s = np.sin(np.pi / 12)
        c = np.cos(np.pi / 12)
        lower = np.array([6 * s, -2 * np.pi * s])
        upper = np.array([6 * s + 2 * np.pi * c, 6 * c])
        super().__init__(n_var=2, n_obj=2, oka_xl=lower, oka_xu=upper)

    def _evaluate(self, x):
        """Return an (n_points, 2) objective matrix for native-space points.

        With (u, v) the input rotated by pi/12:
            f1 = u
            f2 = sqrt(2 pi) - sqrt(|u|) + 2 |v - 3 cos(u) - 3|^(1/3)
        """
        assert x.ndim == 2
        assert x.shape[1] == self.n_var
        assert np.all(x >= self.oka_xl) and np.all(x <= self.oka_xu)
        s = np.sin(np.pi / 12)
        c = np.cos(np.pi / 12)
        first = x[:, 0]
        second = x[:, 1]
        u = c * first - s * second
        v = s * first + c * second
        cube_root_term = np.abs(v - 3 * np.cos(u) - 3) ** (1. / 3)
        f2 = np.sqrt(2 * np.pi) - np.sqrt(np.abs(u)) + 2 * cube_root_term
        return np.column_stack((u, f2))


class OKA2(OKA):
    """OKA2 bi-objective test problem with three decision variables.

    From: Okabe, Tatsuya, et al. "On test functions for evolutionary
    multi-objective optimization." International Conference on Parallel
    Problem Solving from Nature. Springer, Berlin, Heidelberg, 2004.

    Native search box: x1 in [-pi, pi]; x2 and x3 in [-5, 5].
    """

    def __init__(self, **kwargs):
        # The box is symmetric about the origin, so the lower bound is the
        # negated upper bound. Extra keyword arguments are ignored.
        upper = np.array([np.pi, 5.0, 5.0])
        super().__init__(n_var=3, n_obj=2, oka_xl=-upper, oka_xu=upper)

    def _evaluate(self, x):
        """Return an (n_points, 2) objective matrix for native-space points.

        f1 = x1
        f2 = 1 - (x1 + pi)^2 / (4 pi^2) + |x2 - 5 cos x1|^(1/3) + |x3 - 5 sin x1|^(1/3)
        """
        assert x.ndim == 2
        assert x.shape[1] == self.n_var
        assert np.all(x >= self.oka_xl) and np.all(x <= self.oka_xu)
        theta = x[:, 0]
        a = x[:, 1]
        b = x[:, 2]
        parabola = 1 - (theta + np.pi) ** 2 / (4 * np.pi ** 2)
        dev_a = np.abs(a - 5 * np.cos(theta)) ** (1. / 3)
        dev_b = np.abs(b - 5 * np.sin(theta)) ** (1. / 3)
        # Summed left to right, matching the original expression's float order.
        f2 = parabola + dev_a + dev_b
        return np.column_stack((theta, f2))
