from typing import Optional

from torch import Tensor

from algorithms.convergence_algorithms.base import Configurable


class InputMapping(Configurable):
    """Base interface for a mapping applied to input tensors.

    Subclasses must implement ``map``, ``inverse`` and (if sampling is used)
    ``sample_from_unbounded``; the base versions raise ``NotImplementedError``.
    The hooks ``squeeze``, ``move_center`` and ``unsqueeze`` do nothing here,
    and ``stop_condition`` always returns ``False``, so subclasses only need to
    override the ones they actually use.
    """

    def map(self, tensor: Tensor) -> Tensor:
        """Map ``tensor`` through this mapping.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()

    def inverse(self, tensor: Tensor) -> Tensor:
        """Apply the inverse of ``map`` to ``tensor``.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()

    def squeeze(self, best_result: Tensor, **kwargs) -> None:
        """Hook receiving ``best_result`` (plus arbitrary keyword arguments).

        Does nothing in the base class.
        NOTE(review): the name suggests narrowing the mapping around
        ``best_result`` — confirm against concrete subclasses.
        """

    def move_center(self, best_result: Tensor) -> None:
        """Hook receiving ``best_result``; does nothing in the base class.

        NOTE(review): presumably re-centres the mapping on ``best_result`` in
        subclasses — verify.
        """

    def unsqueeze(self) -> None:
        """Hook with no arguments; does nothing in the base class.

        NOTE(review): presumably undoes ``squeeze`` in subclasses — verify.
        """

    def sample_from_unbounded(self, sample_size: int, device: Optional[int] = None):
        """Draw ``sample_size`` samples; must be implemented by subclasses.

        Args:
            sample_size: number of samples requested.
            device: target device for the samples. ``None`` is the default;
                presumably it means "use the default device" — TODO confirm.
                The annotation is ``Optional[int]`` because ``None`` is an
                accepted value (PEP 484 disallows implicit Optional).

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()

    def stop_condition(self) -> bool:
        """Return whether the mapping signals a stop; always ``False`` here."""
        return False


class OutputMapping(Configurable):
    """Base interface for a mapping applied to output tensors.

    Concrete subclasses supply ``map`` and ``inverse``. ``adapt`` is an
    optional hook that ignores its argument unless a subclass overrides it.
    """

    def map(self, tensor: Tensor) -> Tensor:
        """Transform ``tensor``; the base class raises ``NotImplementedError``."""
        raise NotImplementedError

    def inverse(self, tensor: Tensor) -> Tensor:
        """Undo ``map`` on ``tensor``; the base class raises ``NotImplementedError``."""
        raise NotImplementedError

    def adapt(self, new_data: Tensor) -> None:
        """Hook receiving ``new_data``; the base implementation ignores it."""
        return None


class DefaultMapping(OutputMapping):
    """Identity output mapping.

    Both ``map`` and ``inverse`` hand back the very same tensor object they
    receive (no copy is made), so mutating the result also mutates the input.
    ``adapt`` is inherited from ``OutputMapping``.
    """

    def map(self, tensor: Tensor) -> Tensor:
        """Return ``tensor`` unchanged (identity)."""
        return tensor

    def inverse(self, tensor: Tensor) -> Tensor:
        """Return ``tensor`` unchanged; the identity is its own inverse."""
        return tensor
