from typing import Callable

import torch
from torch import Tensor

from algorithms.space.basic_callable_space import BasicCallableSpace
from problems.types import Suites, PYTHONIC_SUITE


class CallableSpace(BasicCallableSpace):
    """Search space backed by a NumPy-based Python callable.

    The wrapped ``func`` is exposed through ``callable_env`` as a function that
    takes and returns torch tensors; the conversion to and from NumPy is
    handled there.

    ``dimension`` and ``num_of_samples`` (used by ``__repr__`` / ``__str__``)
    are not defined in this class; they are expected to be provided by
    ``BasicCallableSpace``.
    """

    def __init__(
        self,
        func: Callable,
        input_lower_bounds: Tensor,
        input_upper_bounds: Tensor,
        **kwargs,
    ) -> None:
        """Store the objective callable and its input bounds.

        Args:
            func: Callable evaluated on NumPy input by ``callable_env``.
            input_lower_bounds: Value returned by ``lower_bound``.
            input_upper_bounds: Value returned by ``upper_bound``.
            **kwargs: Forwarded unchanged to ``BasicCallableSpace.__init__``.
        """
        super().__init__(**kwargs)
        self.func = func
        self.input_upper_bounds = input_upper_bounds
        self.input_lower_bounds = input_lower_bounds

    def is_goal_reached(self) -> bool:
        """Always return ``False``: this space defines no goal condition."""
        return False

    @property
    def callable_env(self) -> Callable:
        """Return a wrapper around ``self.func`` that produces torch tensors.

        For a ``Tensor`` input the wrapper converts it to a NumPy array, calls
        ``self.func`` and returns the result as a tensor with the input's
        device and dtype. Any other input is handed to ``self.func`` as-is and
        the result is returned as a tensor with its native dtype.
        """

        def func(data):
            if not isinstance(data, Tensor):
                # Non-tensor input carries no torch device/dtype to restore.
                return torch.as_tensor(self.func(data))

            # ``Tensor.numpy()`` only works for CPU tensors that do not require
            # grad, so detach and move to host memory before converting.
            result = self.func(data.detach().cpu().numpy())
            # ``as_tensor`` behaves like ``from_numpy`` for ndarrays and
            # additionally accepts NumPy/Python scalars.
            return torch.as_tensor(result).to(
                device=data.device,
                dtype=data.dtype,
            )

        return func

    @property
    def suite(self) -> Suites:
        """Suite this space belongs to (always ``Suites.PYTHONIC``)."""
        return Suites.PYTHONIC

    @property
    def func_id(self) -> int:
        """Position of ``type(self.func)`` within ``PYTHONIC_SUITE``."""
        # NOTE(review): if the type is not registered, ``.index`` fails
        # (ValueError, assuming PYTHONIC_SUITE is a list/tuple) — confirm.
        return PYTHONIC_SUITE.index(type(self.func))

    @property
    def func_instance(self) -> int:
        """Function instance number; constant ``0`` for this space."""
        return 0

    def __repr__(self) -> str:
        """Return ``"<type name of func>-<dimension>"``."""
        return f"{type(self.func).__name__}-{self.dimension}"

    def __str__(self) -> str:
        """Return the ``repr`` followed by the value of ``num_of_samples``."""
        return f"{self.__repr__()}, remaining budget: {self.num_of_samples}"

    @property
    def upper_bound(self) -> Tensor:
        """Upper input bounds as passed to the constructor."""
        return self.input_upper_bounds

    @property
    def lower_bound(self) -> Tensor:
        """Lower input bounds as passed to the constructor."""
        return self.input_lower_bounds
