import copy

from algorithms.convergence_algorithms.egl import EGL


class EGLRestart(EGL):
    """EGL variant that restarts its trust region when a budget threshold is hit.

    When the environment's used budget reaches ``max_budget``, epsilon is
    reset to the value it had right after construction, the input mapping is
    replaced by a fresh copy of a template mapping whose
    ``dim_proportion_factor`` is 2, and the threshold is doubled so that the
    next restart happens later.
    """

    ALGORITHM_NAME = "egl_restart"

    def __init__(self, *args, max_budget: int = 30_000, **kwargs) -> None:
        """Initialise the base algorithm and capture the restart state.

        Args:
            *args: Positional arguments forwarded unchanged to ``EGL``.
            max_budget: Used-budget threshold that triggers the first
                restart; it is doubled after every restart.
            **kwargs: Keyword arguments forwarded unchanged to ``EGL``.
        """
        super().__init__(*args, **kwargs)
        # Remember the initial epsilon so a restart can restore it.
        self.epsilon_start = self.epsilon
        self.max_budget = max_budget
        # Template mapping used for restarts: a snapshot of the initial input
        # mapping with its dimension proportion factor overridden.
        self.second_tr = copy.deepcopy(self.input_mapping)
        self.second_tr.dim_proportion_factor = 2

    def before_shrinking_hook(self) -> None:
        """Restart the trust region once the budget threshold is reached.

        NOTE(review): presumably invoked by ``EGL`` right before it shrinks
        the trust region -- confirm against the base class.
        """
        if self.max_budget <= self.env.used_budget:
            self.logger.info("Restarting tr")
            self.epsilon = self.epsilon_start
            # Install a copy rather than the template itself. Otherwise the
            # live mapping and the template become the same object after the
            # first restart: in-place changes to the live mapping would leak
            # into the template, and later restarts would just re-assign the
            # same object instead of resetting the trust region.
            self.input_mapping = copy.deepcopy(self.second_tr)
            self.max_budget *= 2
