import functools
from typing import Optional

from torch import Tensor

from algorithms.convergence_algorithms.egl import EGL
from algorithms.stopping_condition.quadratic_model import QuadModelMinimumReach


class EGLEpsilonEnlarger(EGL):
    """EGL variant that re-enlarges epsilon and the trust region when needed.

    Whenever ``after_shrinking_hook`` runs, the distance from the current point
    to the model minimum is measured with ``QuadModelMinimumReach``. Then:

    * If that distance exceeds ``epsilon``, enlargement is enabled, and the
      mean trust-region sigma is below both ``max_tr_radius`` and
      ``min_tr_radius``: epsilon is enlarged, and the input mapping is
      unsqueezed (enlarged) several times.
    * Otherwise, if ``too_close_to_min`` is set and the distance is below it:
      the input mapping is squeezed around ``mu`` and epsilon is reduced, but
      not below ``min_epsilon``.

    ``epsilon``, ``epsilon_factor``, ``min_epsilon``, ``input_mapping``,
    ``env`` and ``logger`` are not set here. They are assumed to come from the
    ``EGL`` base class (TODO confirm).
    """

    ALGORITHM_NAME = "ee_egl"

    def __init__(
        self,
        *args,
        min_distance_from_model_min: float = 5e-1,
        min_tr_radius: float = 1,
        enlarge_count: int = 3,
        should_enlarge: bool = True,
        too_close_to_min: Optional[float] = None,
        **kwargs,
    ):
        """Create the algorithm; ``args``/``kwargs`` are forwarded to ``EGL``.

        Args:
            min_distance_from_model_min: Second argument passed to
                ``QuadModelMinimumReach`` when estimating the distance to the
                model minimum.
            min_tr_radius: Enlargement only happens while the mean sigma of
                the input mapping is below this value.
            enlarge_count: Exponent used to enlarge epsilon and to derive
                ``max_tr_radius``. It is also the base number of unsqueeze
                steps per enlargement.
            should_enlarge: Set to False to disable the enlargement branch.
            too_close_to_min: Optional distance threshold. When set and the
                model-minimum distance falls below it, the trust region is
                shrunk again and epsilon is reduced.
        """
        super().__init__(*args, **kwargs)
        self.min_distance_from_model_min = min_distance_from_model_min
        self.min_tr_radius = min_tr_radius
        self.enlarge_count = enlarge_count
        self.should_enlarge = should_enlarge
        self.too_close_to_min = too_close_to_min

    def set_start_point(self, start_point: Tensor):
        """Not supported by this algorithm; always raises NotImplementedError."""
        raise NotImplementedError()

    @functools.cached_property
    def space_diagonal_distance(self):
        """Euclidean length of the search box diagonal (env bounds); computed once."""
        return (self.env.lower_bound - self.env.upper_bound).pow(2).sum().sqrt()

    @functools.cached_property
    def max_tr_radius(self):
        """Upper bound on mean sigma for enlargement: ``shrink_factor ** enlarge_count``."""
        return self.input_mapping.shrink_factor ** self.enlarge_count

    def after_shrinking_hook(self):
        """React to a trust-region shrink by enlarging or re-shrinking it.

        Presumably called by the ``EGL`` base class right after the trust
        region shrinks (confirm against ``EGL``). Mutates ``self.epsilon`` and
        ``self.input_mapping`` in place and returns nothing.
        """
        mean_radius = self.input_mapping.sigma.mean()
        # Scale the mean sigma by the box diagonal to get a radius in env
        # units; presumably sigma is normalized to the box (TODO confirm).
        tr_radius = mean_radius * self.space_diagonal_distance
        quad_model_estimator = QuadModelMinimumReach(
            tr_radius, self.min_distance_from_model_min, self.logger
        )
        real_distance_from_model_min = quad_model_estimator.distance_from_model_min(self)
        unreal_distance_from_model_min = quad_model_estimator.distance_from_model_min(
            self, real=False
        )
        self.logger.info(
            f"Distance from min {unreal_distance_from_model_min}, {real_distance_from_model_min}, "
            f"mean radius {mean_radius}"
        )

        if (
            unreal_distance_from_model_min > self.epsilon
            and self.should_enlarge
            and mean_radius < self.max_tr_radius
            and mean_radius < self.min_tr_radius
        ):
            new_eps = self.epsilon / (self.epsilon_factor ** self.enlarge_count)
            self.logger.info(
                f"Distance from model min is too large, enlarge epsilon to {new_eps} from {self.epsilon}"
            )
            self.epsilon = new_eps

            # More unsqueeze steps the farther the model minimum is. Stop early
            # once the mean sigma exceeds 1.
            max_steps = (
                int(unreal_distance_from_model_min / self.input_mapping.shrink_factor)
                + self.enlarge_count
            )
            times_enlarged = 0
            for _ in range(max_steps):
                self.input_mapping.unsqueeze()
                times_enlarged += 1
                if self.input_mapping.sigma.mean() > 1:
                    break
            # Report how many unsqueeze steps actually ran, not the configured count.
            self.logger.info(f"Enlarge tr {self.input_mapping}, {times_enlarged} times")
        elif (
            self.too_close_to_min is not None
            and unreal_distance_from_model_min < self.too_close_to_min
        ):
            self.input_mapping.squeeze(self.input_mapping.mu)
            # Reduce epsilon, but never below the configured floor.
            self.epsilon = max(self.epsilon * self.epsilon_factor, self.min_epsilon)
            self.logger.info(
                f"The point is too close to evaluated min, shrinking again {self.input_mapping} "
                f"epsilon {self.epsilon}"
            )
