import abc
import math

import torch
from torch import Tensor

from algorithms.space.base_space import (
    BudgetLimitedSpace,
    SamplerSpace,
    EvaluatedSpace,
    BoundedSpace,
    IdentifiableSpace,
)
from algorithms.space.exceptions import NoMoreBudgetError
from algorithms.space.utils import best_k_by_values


class BasicCallableSpace(
    BudgetLimitedSpace, SamplerSpace, EvaluatedSpace, BoundedSpace, IdentifiableSpace
):
    """Box-bounded search space whose objective is a callable environment.

    Subclasses provide ``callable_env``. ``lower_bound`` / ``upper_bound`` are
    read here but not defined in this class (presumably supplied by
    ``BoundedSpace`` or a subclass -- confirm). Calling the space clips the
    input into the bounds, charges the batch against the sample budget (unless
    ``debug_mode``), and evaluates ``callable_env`` on it. Lower values are
    treated as better (see ``best_k_indices`` / ``sample_for_optimum_points``).
    """

    def __init__(self, budget: float = math.inf):
        """
        Args:
            budget: Maximum number of points that may be charged through
                ``__call__``. Defaults to unlimited (``math.inf``).
        """
        self.budget = budget
        # Number of points charged against ``budget`` so far.
        self.sample_count = 0

    @property
    @abc.abstractmethod
    def callable_env(self):
        """The environment being optimized.

        ``__call__`` invokes it with a 2-D ``(N, dimension)`` tensor and reshapes
        its output into N values, so it must accept batched input.
        """
        raise NotImplementedError()

    @property
    def dimension(self):
        """Number of coordinates of a point (the length of ``lower_bound``)."""
        return len(self.lower_bound)

    def __call__(self, data_from_input_space: Tensor, debug_mode: bool = False) -> Tensor:
        """Evaluate the environment on one point or a batch of points.

        Args:
            data_from_input_space: Tensor whose last dimension has size
                ``dimension``; leading dimensions are flattened into one batch.
            debug_mode: When True, the evaluation is not charged to the budget.

        Returns:
            Values shaped like the input minus its last dimension, or shape
            ``(1,)`` when the input is a single 1-D point.

        Raises:
            NoMoreBudgetError: If charging this batch exceeds the budget.
        """
        original_shape = data_from_input_space.shape
        # Drop the coordinate axis; a lone 1-D point still yields a 1-element result.
        value_shape = list(original_shape)[:-1] if len(original_shape) > 1 else [1]
        data_in_single_batch = data_from_input_space.reshape(-1, self.dimension)
        data_in_single_batch = self.__preprocess_call_data(data_in_single_batch)
        if not debug_mode:
            self._pre_evaluation_action(data_in_single_batch)
        return self.__evaluate_env(data_in_single_batch).reshape(value_shape)

    def initialize(self):
        """No-op hook; subclasses may override to set up state."""
        pass

    @property
    def num_of_samples(self):
        """Number of points charged against the budget so far."""
        return self.sample_count

    def __preprocess_call_data(self, data):
        """Clip ``data`` into the bounds, matching its device and dtype."""
        return data.clip(
            self.lower_bound.to(device=data.device, dtype=data.dtype),
            self.upper_bound.to(device=data.device, dtype=data.dtype),
        )

    def __evaluate_env(self, data):
        """Run ``callable_env`` on ``data`` and cast the result to ``data``'s dtype."""
        return self.callable_env(data).to(dtype=data.dtype)

    @property
    def total_budget(self) -> int:
        """The configured budget (may be ``math.inf``)."""
        return self.budget

    @property
    def used_budget(self) -> int:
        """Points charged so far; same as ``num_of_samples``."""
        return self.num_of_samples

    def free_to_check_data(self, device: int) -> Tensor:
        """Unsupported here; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "From now on, the algorithm should be the one to remember this"
        )

    def sample_for_optimum_points(
        self, budget: int, num_of_samples: int, device: int = None
    ) -> Tensor:
        """Draw random candidates and return the ones with the lowest values.

        Candidates are evaluated directly on ``callable_env`` (no clipping is
        needed since they are sampled inside the bounds), and this is NOT
        charged against the space's sample budget.

        Args:
            budget: Number of random candidate points to draw and evaluate.
            num_of_samples: How many of the lowest-valued points to return.
            device: Device for the returned tensor.

        Returns:
            Tensor of shape ``(min(num_of_samples, budget), dimension)``,
            ordered from lowest to highest value; ties keep sampling order.
        """
        candidates = self.sample_from_space(budget)
        # Evaluate the whole batch at once, using the same 2-D calling
        # convention ``__call__`` uses for ``callable_env``.
        values = self.callable_env(candidates).reshape(-1).to(device=candidates.device)
        # A stable ascending sort reproduces the tie-breaking of Python's ``sorted``.
        order = torch.sort(values, stable=True).indices
        return candidates[order[:num_of_samples]].to(device=device)

    def sample_from_space(
        self, num_samples: int, dtype: torch.dtype = torch.float64, device: int = None
    ) -> Tensor:
        """Draw points uniformly at random from the box ``[lower_bound, upper_bound)``.

        Args:
            num_samples: Number of points to draw.
            dtype: Floating-point dtype of the returned tensor.
            device: Device of the returned tensor.

        Returns:
            Tensor of shape ``(num_samples, dimension)`` with ``dtype`` on ``device``.
        """
        lower = self.lower_bound.to(device=device, dtype=dtype)
        extent = self.upper_bound.to(device=device, dtype=dtype) - lower
        # Draw directly in ``dtype``; otherwise torch.rand yields float32 and
        # type promotion, not the caller, decides the output dtype.
        unit = torch.rand(num_samples, self.dimension, device=device, dtype=dtype)
        return unit * extent + lower

    def best_k_values(self, input_in_space: Tensor, k: int, debug_mode: bool = False) -> Tensor:
        """Evaluate ``input_in_space`` via ``__call__`` and delegate to ``best_k_by_values``.

        The evaluation is charged to the budget unless ``debug_mode`` is True.
        Selection semantics come from ``best_k_by_values`` (presumably the k
        lowest-valued inputs -- confirm in ``algorithms.space.utils``).
        """
        new_x_opt_env_value = self(input_in_space, debug_mode=debug_mode)
        return best_k_by_values(input_in_space, new_x_opt_env_value, k)

    def best_k_indices(
        self, input_in_space: Tensor, k: int, debug_mode: bool = False
    ) -> Tensor:
        """Return indices of the ``k`` lowest values along the last value axis.

        The evaluation is charged to the budget unless ``debug_mode`` is True.
        """
        new_x_opt_env_value = self(input_in_space, debug_mode=debug_mode)
        # topk of the negated values == the k smallest values.
        return (-new_x_opt_env_value).topk(k).indices

    def _pre_evaluation_action(self, data: Tensor):
        """Charge ``len(data)`` points against the budget.

        Raises:
            NoMoreBudgetError: If the updated count exceeds ``budget``.
        """
        self.sample_count += len(data)
        # The batch is counted even when it trips the limit, so the count stays
        # above the budget and every later charged call also raises.
        if self.num_of_samples > self.budget:
            raise NoMoreBudgetError(f"Exceeded budget of {self.budget}")

    def denormalize(self, data: Tensor) -> Tensor:
        """Map ``data`` from ``[-1, 1]`` coordinates to ``[lower_bound, upper_bound]``."""
        return 0.5 * (data + 1) * (self.upper_bound - self.lower_bound).to(
            device=data.device, dtype=data.dtype
        ) + self.lower_bound.to(device=data.device, dtype=data.dtype)

    def normalize(self, data: Tensor) -> Tensor:
        """Map ``data`` from ``[lower_bound, upper_bound]`` to ``[-1, 1]`` coordinates."""
        return (
            (data - self.lower_bound.to(device=data.device, dtype=data.dtype))
            / (self.upper_bound - self.lower_bound).to(device=data.device, dtype=data.dtype)
            * 2
        ) - 1

    @property
    def device(self):
        """Always ``None``; this space is not tied to a specific device."""
        return None
