import torch
from torch import Tensor

from algorithms.space.callable_function import CallableSpace


class OptimizationTestSpace(CallableSpace):
    """``CallableSpace`` that also records whether the known optimum was ever queried.

    On every call, until the goal has been reached, the queried point is compared
    element-wise with ``self.func.x_best`` (when the wrapped function exposes a
    non-``None`` ``x_best``). Once a match is seen, the flag stays ``True`` for the
    lifetime of the instance. Evaluation itself is delegated unchanged to
    ``CallableSpace.__call__``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Latch: flips to True the first time a query equals ``self.func.x_best``
        # and is never reset by this class.
        self.has_reached_goal: bool = False

    def __call__(self, data_from_input_space: Tensor, debug_mode: bool = False) -> Tensor:
        """Check the query against the optimum, then evaluate it via the base class.

        Args:
            data_from_input_space: Point(s) to evaluate.
            debug_mode: Forwarded unchanged to ``CallableSpace.__call__``.

        Returns:
            Whatever ``CallableSpace.__call__`` returns for the same arguments.
        """
        # Skip the comparison once the goal is latched; it can no longer change.
        if not self.has_reached_goal:
            self.has_reached_goal = self._matches_optimum(data_from_input_space)
        return super().__call__(data_from_input_space, debug_mode)

    def _matches_optimum(self, data_from_input_space: Tensor) -> bool:
        """Return True if the query equals ``self.func.x_best`` element-wise."""
        x_best = getattr(self.func, "x_best", None)
        if x_best is None:
            # The wrapped function publishes no known optimum, so no goal to reach.
            return False
        # ``as_tensor`` accepts numpy arrays, tensors and lists. It keeps the source
        # dtype and places the target on the query's device, so CUDA inputs do not
        # fail against a CPU target.
        target = torch.as_tensor(x_best, device=data_from_input_space.device)
        # NOTE(review): this is exact equality. Float optima are rarely hit exactly,
        # so consider torch.allclose if that is the intent. If the query is batched,
        # broadcasting makes this require *every* row to match. Confirm the expected
        # shape with callers.
        # bool(...) turns the 0-dim result tensor into a plain Python bool.
        return bool((data_from_input_space == target).all())

    def is_goal_reached(self) -> bool:
        """Return whether any call so far has queried ``self.func.x_best`` exactly."""
        return self.has_reached_goal

    @property
    def name(self) -> str:
        """Class name of the wrapped function object ``self.func``."""
        return self.func.__class__.__name__
