import functools

import torch
from cocoex.interface import Problem

from algorithms.space.base_space import IdentifiableSpace
from algorithms.space.basic_callable_space import BasicCallableSpace
from algorithms.space.utils import coco_problem_from_funcnum_and_dim, is_multiple_points
from problems.types import Suites

# Display names for the COCO/BBOB benchmark functions, keyed by the COCO
# function id (``Problem.id_function``). ``CocoSpace.name`` looks names up
# here and falls back to "" for ids that are not listed, so every id 1..24
# must be present to get a name for each function.
PROBLEM_NAMES = {
    1: "Sphere",
    2: "Ellipsoid separable",
    3: "Rastrigin separable",
    4: "Skew Rastrigin-Bueche separ",
    5: "Linear slope",
    6: "Attractive sector",
    7: "Step-ellipsoid",
    8: "Rosenbrock original",
    9: "Rosenbrock rotated",
    10: "Ellipsoid",
    11: "Discus",
    12: "Bent cigar",
    13: "Sharp ridge",
    14: "Sum of different powers",
    15: "Rastrigin",
    16: "Weierstrass",
    17: "Schaffer F7, condition 10",
    18: "Schaffer F7, condition 1000",
    19: "Griewank-Rosenbrock F8F2",
    20: "Schwefel x*sin(x)",
    21: "Gallagher 101 peaks",
    22: "Gallagher 21 peaks",
    23: "Katsuura",
    24: "Lunacek bi-Rastrigin",
}


class CocoSpace(BasicCallableSpace, IdentifiableSpace):
    """Search space backed by a single ``cocoex`` problem.

    Evaluation, box bounds and identity (function id, dimension, instance)
    are all delegated to the wrapped ``coco_func``. Any extra constructor
    arguments are forwarded unchanged to the base-class initialisers.
    """

    def __init__(self, coco_func: Problem, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Stored only after the base initialisers have run.
        self.coco_func = coco_func

    def is_goal_reached(self):
        """Return the wrapped problem's ``final_target_hit`` flag."""
        return self.coco_func.final_target_hit

    @property
    def suite(self) -> Suites:
        """Benchmark suite of this space; always ``Suites.COCO``."""
        return Suites.COCO

    @property
    def func_id(self) -> int:
        """COCO function number, read from ``coco_func.id_function``."""
        return self.coco_func.id_function

    @property
    def func_instance(self) -> int:
        """COCO instance number, read from ``coco_func.id_instance``."""
        return self.coco_func.id_instance

    @property
    def callable_env(self):
        """Return a function that evaluates a tensor with the COCO problem.

        The returned function detaches its input and copies it to the CPU
        before calling ``coco_func``. When ``is_multiple_points`` reports a
        batch, each row is evaluated on its own. Otherwise the whole tensor
        is passed as a single point. The result is a tensor on the input's
        device with the input's dtype.
        """

        def evaluate(data):
            detached = data.detach()
            batched = is_multiple_points(detached)
            on_host = detached.cpu()
            if batched:
                values = [self.coco_func(row) for row in on_host]
            else:
                values = self.coco_func(on_host)
            return torch.tensor(values, device=detached.device, dtype=detached.dtype)

        return evaluate

    @functools.cached_property
    def upper_bound(self):
        """Upper box bound, converted once from ``coco_func.upper_bounds``."""
        return torch.from_numpy(self.coco_func.upper_bounds)

    @functools.cached_property
    def lower_bound(self):
        """Lower box bound, converted once from ``coco_func.lower_bounds``."""
        return torch.from_numpy(self.coco_func.lower_bounds)

    def __getstate__(self):
        """Pickle the ids needed to rebuild the COCO problem, plus the budget."""
        problem = self.coco_func
        return dict(
            func_id=problem.id_function,
            func_dim=problem.dimension,
            func_instance=problem.id_instance,
            budget=self.budget,
        )

    def __setstate__(self, state):
        """Rebuild the COCO problem from pickled ids and reset the sample counter."""
        func_id, func_dim, func_instance = (
            state["func_id"],
            state["func_dim"],
            state["func_instance"],
        )
        self.coco_func = coco_problem_from_funcnum_and_dim(func_id, func_dim, func_instance)
        self.budget = state["budget"]
        # NOTE(review): the counter restarts at 0 rather than being restored
        # from the pickle; confirm this is intended by the base class.
        self.sample_count = 0

    def __str__(self):
        """Human-readable summary: ``coco <repr>, remaining budget: <n>``."""
        # NOTE(review): ``num_of_samples`` is presumably provided by a base
        # class (``__setstate__`` resets ``sample_count``). Confirm it really
        # holds the *remaining* budget, as the label says.
        return f"coco {self!r}, remaining budget: {self.num_of_samples}"

    def __repr__(self):
        """Compact ``<function id>-<dimension>-<instance>`` identifier."""
        problem = self.coco_func
        fid, dim, inst = problem.id_function, problem.dimension, problem.id_instance
        return f"{fid}-{dim}-{inst}"

    @property
    def name(self) -> str:
        """Display name from ``PROBLEM_NAMES``, or ``""`` for an unknown id."""
        return PROBLEM_NAMES.get(self.func_id, "")
