import numpy as np

from .md import MD


class OMD(MD):
    """Optimistic mirror descent learner built on top of ``MD``.

    Keeps a secondary iterate ``strategy_hat``. On every call to
    ``calc_strategy`` the most recently observed gradient is applied twice:
    once to advance ``strategy_hat``, and once more on top of the new
    ``strategy_hat`` as the optimistic guess for the next gradient. The
    result becomes ``strategy``. Steps are taken along ``+gradient``
    (ascent direction).
    """

    def __init__(self, strategy_space, regularizer, learning_rate, **kwargs):
        super().__init__(strategy_space, regularizer, learning_rate, **kwargs)
        # Secondary ("hat") iterate; starts from the initial strategy set up by MD.
        self.strategy_hat = self.strategy.copy()

    def _calc_gradient(self):
        """Return ``(cum_gradient, gradient)``: accumulated and most recent gradient."""
        return self.cum_gradient, self.gradient

    def _current_learning_rate(self):
        """Return the step size for the upcoming update (round ``self.n + 1``).

        ``self.learning_rate`` is either a plain number (used as-is, NumPy
        scalars included) or a sequence ``(schedule, *params)``. Here
        ``schedule`` is a callable, or a string that evaluates to one, and it
        is invoked as ``schedule(self.n + 1, *params)``.
        """
        lr = self.learning_rate
        if isinstance(lr, (int, float, np.number)):
            return lr
        schedule = lr[0]
        if not callable(schedule):
            # SECURITY: eval() executes arbitrary code. Only trusted
            # configuration strings may reach this point; prefer passing a
            # callable instead.
            schedule = eval(schedule)
        return schedule(self.n + 1, *(lr[1:]))

    @staticmethod
    def _exp_update(weights, step):
        """Multiply ``weights`` by ``exp(step)`` and renormalize to sum to 1.

        The exponent is shifted by its maximum before ``np.exp``. The shift
        cancels in the normalization, but it prevents overflow to ``inf``
        (and the resulting ``nan`` weights) when ``step`` has large entries.
        """
        scaled = weights * np.exp(step - np.max(step))
        return scaled / np.sum(scaled)

    def calc_strategy(self):
        """Advance one optimistic step and return the new ``strategy``.

        The last gradient recorded by ``add_gradient`` is used both to update
        ``strategy_hat`` and as the prediction for the next gradient.
        ``self.n`` is incremented on success.

        Raises:
            RuntimeError: if ``self.regularizer`` is neither ``'l2'`` nor
                ``'entropy'``.
        """
        learning_rate = self._current_learning_rate()
        gradient = self._calc_gradient()[1]
        step = learning_rate * gradient
        if self.regularizer == 'l2':
            # NOTE(review): presumably a Euclidean projection onto the feasible
            # set; semantics live in strategy_space.projection.
            self.strategy_hat = self.strategy_space.projection(self.strategy_hat + step)
            self.strategy = self.strategy_space.projection(self.strategy_hat + step)
        elif self.regularizer == 'entropy':
            # Multiplicative-weights form; each result is normalized to sum to 1.
            self.strategy_hat = self._exp_update(self.strategy_hat, step)
            self.strategy = self._exp_update(self.strategy_hat, step)
        else:
            raise RuntimeError('Illegal regularizer')
        self.n += 1
        return self.strategy

    def add_gradient(self, gradient):
        """Record a newly observed gradient.

        Adds it to ``cum_gradient`` and keeps a private copy in ``gradient``,
        which ``calc_strategy`` uses for its next step.
        """
        # Copy so that later mutation of the caller's buffer cannot alter the
        # stored gradient. Converting also makes list inputs scale correctly in
        # ``learning_rate * gradient``; a list would otherwise be repeated.
        gradient = np.array(gradient)
        self.cum_gradient += gradient
        self.gradient = gradient
