from .md import MD


class PMD(MD):
    """``MD`` learner whose gradients carry a perturbation toward an anchor strategy.

    Presumably "perturbed mirror descent" on top of the ``MD`` base class
    (confirm against ``.md``). Every gradient passed to :meth:`add_gradient` is
    augmented with ``-perturbation_strength * (strategy - anchoring_strategy)``,
    which is the gradient of
    ``-(perturbation_strength / 2) * ||strategy - anchoring_strategy||**2``.
    The anchor starts as a copy of the strategy present right after
    ``MD.__init__``. When ``update_anchoring_interval`` is set, it is
    periodically reset to the current strategy.

    Args:
        strategy_space: Forwarded unchanged to ``MD.__init__``.
        regularizer: Forwarded to ``MD.__init__``; also used by :meth:`name`.
        learning_rate: Forwarded to ``MD.__init__``; also used by :meth:`name`.
        perturbation_strength: Scale (``mu``) of the perturbation term.
        update_anchoring_interval: If not ``None``, :meth:`add_gradient` resets
            the anchor whenever ``self.n`` is a multiple of this value. If
            ``None``, this class never moves the anchor after construction.
        **kwargs: Forwarded unchanged to ``MD.__init__``.

    Raises:
        ValueError: If ``update_anchoring_interval`` is not ``None`` and is not
            positive.
    """

    def __init__(
        self,
        strategy_space,
        regularizer,
        learning_rate,
        perturbation_strength,
        update_anchoring_interval,
        **kwargs,
    ):
        # Fail fast, before the base class does any work. An interval of 0
        # would otherwise surface later as a ZeroDivisionError inside
        # add_gradient, and a negative period has no meaning.
        if update_anchoring_interval is not None and update_anchoring_interval <= 0:
            raise ValueError(
                "update_anchoring_interval must be positive or None, "
                f"got {update_anchoring_interval!r}"
            )
        super().__init__(strategy_space, regularizer, learning_rate, **kwargs)
        self.perturbation_strength = perturbation_strength
        # Snapshot (copy) so the anchor is not an alias of self.strategy that
        # later in-place updates (if MD performs any) would drag along.
        self.anchoring_strategy = self.strategy.copy()
        self.update_anchoring_interval = update_anchoring_interval

    def name(self):
        """Return an identifier string encoding this learner's hyperparameters.

        Format: ``<ClassName>[_tsig<interval>]_mu<strength>_lr<rate>_<regularizer>``.
        The ``_tsig`` component appears only when an anchoring interval is set.
        """
        parts = [type(self).__name__]
        if self.update_anchoring_interval is not None:
            parts.append(f"tsig{self.update_anchoring_interval}")
        parts.append(f"mu{self.perturbation_strength}")
        parts.append(f"lr{self.learning_rate}")
        parts.append(f"{self.regularizer}")
        return "_".join(parts)

    def add_gradient(self, gradient):
        """Record ``gradient`` plus the perturbation term at the current strategy.

        The perturbed gradient is added into ``self.cum_gradient`` and stored as
        ``self.gradient``. After that, if an anchoring interval is set and
        ``self.n`` is a multiple of it, the anchor is reset to a copy of the
        current strategy. The perturbation applied in this call therefore always
        uses the anchor as it was *before* any reset.

        Args:
            gradient: Gradient of the unperturbed utility at ``self.strategy``.
                Presumably array-like with the strategy's shape (confirm with
                callers).
        """
        # Reuse the helper so the perturbation formula lives in one place.
        perturbed = gradient + self._gradient_of_perturbed_utility(self.strategy)
        self.cum_gradient += perturbed
        self.gradient = perturbed

        interval = self.update_anchoring_interval
        if interval is not None and self.n % interval == 0:
            # Copy so the new anchor is a snapshot, not an alias of self.strategy.
            self.anchoring_strategy = self.strategy.copy()

    def _gradient_of_perturbed_utility(self, strategy):
        """Return the perturbation's contribution to the gradient at ``strategy``.

        Computes ``-perturbation_strength * (strategy - anchoring_strategy)``,
        which is the gradient with respect to ``strategy`` of
        ``-(perturbation_strength / 2) * ||strategy - anchoring_strategy||**2``.
        """
        return -self.perturbation_strength * (strategy - self.anchoring_strategy)
