import numpy as np

from _test_functions.objective_function import ObjectiveFunction
from _test_functions.mazda_cars.mazda_cars import BaseMazdaCarsOptimization
from _test_functions.reproblems.reproblem import RE21 as Source_FourBarTruss
from _test_functions.reproblems.reproblem import RE22 as Source_ReinforcedConcreteBeam
from _test_functions.reproblems.reproblem import RE23 as Source_PressureVessel
from _test_functions.reproblems.reproblem import RE24 as Source_HatchCover
from _test_functions.reproblems.reproblem import RE25 as Source_CoilCompressionSpring
from _test_functions.reproblems.reproblem import RE31 as Source_TwoBarTruss
from _test_functions.reproblems.reproblem import RE32 as Source_WeldedBeam
from _test_functions.reproblems.reproblem import RE33 as Source_DiscBrake
from _test_functions.reproblems.reproblem import RE34 as Source_VehicleDesign
from _test_functions.reproblems.reproblem import RE35 as Source_SpeedReducer
from _test_functions.reproblems.reproblem import RE36 as Source_GearTrain
from _test_functions.reproblems.reproblem import RE37 as Source_RocketInjector
from _test_functions.reproblems.reproblem import RE41 as Source_CarSideImpact
from _test_functions.reproblems.reproblem import RE42 as Source_ConceptualMarineDesign
from _test_functions.reproblems.reproblem import RE61 as Source_WaterResourcePlanning
from _test_functions.reproblems.reproblem import RE91 as Source_CarCabDesign
from _test_functions.estimate_ref_point import ref_points

class BaseREProblem(ObjectiveFunction):
    """Adapter that exposes a wrapped problem on the unit box ``[0, 1]^d``.

    Inputs are given in normalised coordinates and mapped affinely onto the
    wrapped problem's ``[lbound, ubound]`` box before
    ``core_function.evaluate`` is called.
    """

    def __init__(self, base_name, core_function, **kwargs):
        """Wrap ``core_function`` and register it under ``base_name``.

        Args:
            base_name: Key looked up in ``ref_points``; also stored as
                ``self.name``.
            core_function: Object providing ``n_variables``, ``n_objectives``,
                ``n_constraints``, ``lbound``, ``ubound`` and ``evaluate``.
            **kwargs: Accepted for signature compatibility; not used here.

        Raises:
            AssertionError: If ``ref_points`` has no (non-``None``) entry for
                ``base_name``.
        """
        super().__init__(input_dims=core_function.n_variables, num_objectives=core_function.n_objectives, num_constraints=core_function.n_constraints)
        self.core_function = core_function
        # The parent is told core_function.n_constraints, but the attribute is
        # then forced to 0 on this instance.
        # NOTE(review): presumably because constraints are not evaluated
        # separately (``__call__`` returns ``None`` in that slot) -- confirm.
        self.num_constraints = 0
        # Native box of the wrapped problem, one row per variable: [lower, upper].
        self.__base_bounds = np.hstack((self.core_function.lbound.reshape(-1, 1), self.core_function.ubound.reshape(-1, 1)))
        # Public bounds are the normalised unit box, one row per variable.
        self.bounds = np.hstack((np.zeros((self.core_function.n_variables, 1)), np.ones((self.core_function.n_variables, 1))))
        self.input_dims = self.core_function.n_variables
        self.num_objectives = self.core_function.n_objectives
        self.ref_point = ref_points.get(base_name, None)
        assert self.ref_point is not None, f'no reference point registered for {base_name!r}'
        self.name = f'{base_name}'

    def evaluate_objectives(self, X_batch):
        """Evaluate the wrapped problem on every row of ``X_batch``.

        Args:
            X_batch: 2-D array of shape ``(n, input_dims)`` with every entry
                inside ``self.bounds``.

        Returns:
            ``np.ndarray`` stacking the per-row results of
            ``evaluate_objectives_single``; an empty batch yields an array of
            shape ``(0, num_objectives)``.

        Raises:
            AssertionError: On wrong shape or out-of-bounds entries.
        """
        assert X_batch.ndim == 2 and X_batch.shape[1] == self.input_dims, \
            f'expected a 2-D array with {self.input_dims} columns'
        assert np.all(X_batch <= self.bounds[:, 1]) and np.all(X_batch >= self.bounds[:, 0]), \
            'inputs must lie inside self.bounds'
        if X_batch.shape[0] == 0:
            # np.array([]) would be 1-D with shape (0,); keep the result 2-D.
            # Assumes each evaluation yields num_objectives values -- TODO confirm.
            return np.empty((0, self.num_objectives))
        return np.array([self.evaluate_objectives_single(row) for row in X_batch])

    def evaluate_objectives_single(self, X_1d):
        """Evaluate the wrapped problem at one normalised point.

        Args:
            X_1d: 1-D array of length ``input_dims`` with every entry inside
                ``self.bounds``.

        Returns:
            Whatever ``core_function.evaluate`` returns for the rescaled point.

        Raises:
            AssertionError: On wrong shape or out-of-bounds entries.
        """
        assert X_1d.ndim == 1 and X_1d.shape[0] == self.input_dims, \
            f'expected a 1-D array of length {self.input_dims}'
        # Validate in normalised space: the rescaled value below is subject to
        # rounding, so checking it against the native bounds can reject valid
        # boundary points.
        assert np.all(X_1d <= self.bounds[:, 1]) and np.all(X_1d >= self.bounds[:, 0]), \
            'inputs must lie inside self.bounds'
        lower = self.__base_bounds[:, 0]
        upper = self.__base_bounds[:, 1]
        X_1d_scaled = lower + (upper - lower) * X_1d
        # lower + (upper - lower) * 1.0 is not guaranteed to equal upper in
        # floating point; snap any rounding overshoot back onto the native box
        # so the wrapped problem never sees an out-of-range value.
        X_1d_scaled = np.clip(X_1d_scaled, self.core_function.lbound, self.core_function.ubound)
        return self.core_function.evaluate(X_1d_scaled)

    def __call__(self, X):
        """Evaluate a batch and return ``(objectives, None)``.

        Args:
            X: 2-D array of shape ``(n, input_dims)`` with every entry inside
                ``self.bounds``.

        Returns:
            Tuple of the array from ``evaluate_objectives(X)`` and ``None``.

        Raises:
            AssertionError: On wrong shape or out-of-bounds entries.
        """
        assert X.ndim == 2 and X.shape[1] == self.input_dims, \
            f'expected a 2-D array with {self.input_dims} columns'
        assert np.all(X <= self.bounds[:, 1]) and np.all(X >= self.bounds[:, 0]), \
            'inputs must lie inside self.bounds'
        return (self.evaluate_objectives(X), None)
    
class BaseFourBarTruss(BaseREProblem):
    """Wraps ``Source_FourBarTruss`` (``RE21``) under the name ``'four-bar-truss'``."""

    def __init__(self, **kwargs):
        # Build the wrapped problem first, then hand it to the shared adapter.
        problem = Source_FourBarTruss()
        super().__init__('four-bar-truss', problem, **kwargs)

class BaseReinforcedConcreteBeam(BaseREProblem):
    """Wraps ``Source_ReinforcedConcreteBeam`` (``RE22``) under the name ``'concrete-beam'``."""

    def __init__(self, **kwargs):
        # Build the wrapped problem first, then hand it to the shared adapter.
        problem = Source_ReinforcedConcreteBeam()
        super().__init__('concrete-beam', problem, **kwargs)

class BasePressureVessel(BaseREProblem):
    """Wraps ``Source_PressureVessel`` (``RE23``) under the name ``'pressure-vessel'``."""

    def __init__(self, **kwargs):
        # Build the wrapped problem first, then hand it to the shared adapter.
        problem = Source_PressureVessel()
        super().__init__('pressure-vessel', problem, **kwargs)

class BaseHatchCover(BaseREProblem):
    """Wraps ``Source_HatchCover`` (``RE24``) under the name ``'hatch-cover'``."""

    def __init__(self, **kwargs):
        # Build the wrapped problem first, then hand it to the shared adapter.
        problem = Source_HatchCover()
        super().__init__('hatch-cover', problem, **kwargs)

class BaseCoilCompressionSpring(BaseREProblem):
    """Wraps ``Source_CoilCompressionSpring`` (``RE25``) under the name ``'coil-spring'``."""

    def __init__(self, **kwargs):
        # Build the wrapped problem first, then hand it to the shared adapter.
        problem = Source_CoilCompressionSpring()
        super().__init__('coil-spring', problem, **kwargs)

class BaseTwoBarTruss(BaseREProblem):
    """Wraps ``Source_TwoBarTruss`` (``RE31``) under the name ``'two-bar-truss'``."""

    def __init__(self, **kwargs):
        # Build the wrapped problem first, then hand it to the shared adapter.
        problem = Source_TwoBarTruss()
        super().__init__('two-bar-truss', problem, **kwargs)

class BaseWeldedBeam(BaseREProblem):
    """Wraps ``Source_WeldedBeam`` (``RE32``) under the name ``'welded-beam'``."""

    def __init__(self, **kwargs):
        # Build the wrapped problem first, then hand it to the shared adapter.
        problem = Source_WeldedBeam()
        super().__init__('welded-beam', problem, **kwargs)

class BaseDiscBrake(BaseREProblem):
    """Wraps ``Source_DiscBrake`` (``RE33``) under the name ``'disc-brake'``."""

    def __init__(self, **kwargs):
        # Build the wrapped problem first, then hand it to the shared adapter.
        problem = Source_DiscBrake()
        super().__init__('disc-brake', problem, **kwargs)

class BaseVehicleDesign(BaseREProblem):
    """Wraps ``Source_VehicleDesign`` (``RE34``) under the name ``'vehicle-design'``."""

    def __init__(self, **kwargs):
        # Build the wrapped problem first, then hand it to the shared adapter.
        problem = Source_VehicleDesign()
        super().__init__('vehicle-design', problem, **kwargs)

class BaseSpeedReducer(BaseREProblem):
    """Wraps ``Source_SpeedReducer`` (``RE35``) under the name ``'speed-reducer'``."""

    def __init__(self, **kwargs):
        # Build the wrapped problem first, then hand it to the shared adapter.
        problem = Source_SpeedReducer()
        super().__init__('speed-reducer', problem, **kwargs)

class BaseGearTrain(BaseREProblem):
    """Wraps ``Source_GearTrain`` (``RE36``) under the name ``'gear-train'``."""

    def __init__(self, **kwargs):
        # Build the wrapped problem first, then hand it to the shared adapter.
        problem = Source_GearTrain()
        super().__init__('gear-train', problem, **kwargs)

class BaseRocketInjector(BaseREProblem):
    """Wraps ``Source_RocketInjector`` (``RE37``) under the name ``'rocket-injector'``."""

    def __init__(self, **kwargs):
        # Build the wrapped problem first, then hand it to the shared adapter.
        problem = Source_RocketInjector()
        super().__init__('rocket-injector', problem, **kwargs)


class BaseCarSideImpact(BaseREProblem):
    """Wraps ``Source_CarSideImpact`` (``RE41``) under the name ``'car-impact'``."""

    def __init__(self, **kwargs):
        # Build the wrapped problem first, then hand it to the shared adapter.
        problem = Source_CarSideImpact()
        super().__init__('car-impact', problem, **kwargs)

class BaseConceptualMarineDesign(BaseREProblem):
    """Wraps ``Source_ConceptualMarineDesign`` (``RE42``) under the name ``'marine-design'``."""

    def __init__(self, **kwargs):
        # Build the wrapped problem first, then hand it to the shared adapter.
        problem = Source_ConceptualMarineDesign()
        super().__init__('marine-design', problem, **kwargs)

class BaseWaterResourcePlanning(BaseREProblem):
    """Wraps ``Source_WaterResourcePlanning`` (``RE61``) under the name ``'water-planning'``."""

    def __init__(self, **kwargs):
        # Build the wrapped problem first, then hand it to the shared adapter.
        problem = Source_WaterResourcePlanning()
        super().__init__('water-planning', problem, **kwargs)

class BaseCarCabDesign(BaseREProblem):
    """Wraps ``Source_CarCabDesign`` (``RE91``) under the name ``'car-cab-design'``."""

    def __init__(self, **kwargs):
        # Build the wrapped problem first, then hand it to the shared adapter.
        problem = Source_CarCabDesign()
        super().__init__('car-cab-design', problem, **kwargs)

if __name__ == '__main__':

    # Problem base names defined in this module:
    #   four-bar-truss, concrete-beam, pressure-vessel, hatch-cover,
    #   coil-spring, two-bar-truss, welded-beam, disc-brake, vehicle-design,
    #   speed-reducer, gear-train, rocket-injector, car-impact, marine-design,
    #   water-planning, car-cab-design
    ...