import numpy as np
import interface


class MD(object):
    """Iterative strategy learner driven by externally supplied gradients.

    Each call to :meth:`calc_strategy` moves the current strategy one step
    along the most recently supplied gradient. The form of the step depends
    on ``regularizer``:

    * ``'l2'``      -- additive step followed by ``strategy_space.projection``.
    * ``'entropy'`` -- multiplicative (exponential) step followed by
      normalisation so the entries sum to one.

    Attributes:
        strategy_space: Object providing ``num_actions``,
            ``uniform_strategy()``, ``random_strategy()`` and (for ``'l2'``)
            ``projection(x)``.
        strategy: Current strategy vector.
        regularizer: ``'l2'`` or ``'entropy'``.
        cum_gradient: Running sum of all gradients passed to
            :meth:`add_gradient`.
        gradient: The most recent gradient passed to :meth:`add_gradient`.
        learning_rate: Either a real number (constant step size) or a
            sequence ``(schedule, *args)``; see :meth:`_md`.
        n: Number of completed :meth:`calc_strategy` calls.
    """

    def __init__(self, strategy_space, regularizer, learning_rate, **kwargs):
        """Initialise the learner.

        Args:
            strategy_space: Strategy space object (see class docstring).
            regularizer: ``'l2'`` or ``'entropy'``.
            learning_rate: Constant step size, or ``(schedule, *args)``.
            **kwargs: ``random_init`` (optional, default ``False``) selects
                ``strategy_space.random_strategy()`` as the starting point
                instead of ``strategy_space.uniform_strategy()``.

        Raises:
            ValueError: If ``regularizer`` is ``'entropy'`` but
                ``strategy_space`` is not a
                ``ProbabilitySimplexStrategySpace``.
        """
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if regularizer == 'entropy' and not isinstance(
                strategy_space,
                interface.strategy.ProbabilitySimplexStrategySpace):
            raise ValueError(
                "the 'entropy' regularizer requires a "
                "ProbabilitySimplexStrategySpace strategy space")
        self.strategy_space = strategy_space
        # ``random_init`` is optional; a missing key used to raise KeyError.
        if kwargs.get('random_init', False):
            self.strategy = strategy_space.random_strategy()
        else:
            self.strategy = strategy_space.uniform_strategy()
        num_actions = strategy_space.num_actions
        self.regularizer = regularizer
        self.cum_gradient = np.zeros(num_actions)
        self.gradient = np.zeros(num_actions)
        self.learning_rate = learning_rate
        self.n = 0

    def name(self):
        """Return an identifier built from class name, learning rate and regularizer."""
        alg_name = self.__class__.__name__
        alg_name += '_lr{}'.format(self.learning_rate)
        alg_name += '_{}'.format(self.regularizer)
        return alg_name

    def _md(self, cum_gradient, gradient):
        """Apply one update step to ``self.strategy``.

        Args:
            cum_gradient: Cumulative gradient. Not used by this
                implementation; it is part of the signature fed by
                :meth:`_calc_gradient`.
            gradient: Gradient used for this step.

        The step size is ``self.learning_rate`` when that is a real number
        (Python or NumPy scalar). Otherwise ``self.learning_rate`` is treated
        as ``(schedule, *args)`` and the step size is
        ``schedule(self.n + 1, *args)``; ``schedule`` may be a callable or a
        string expression evaluating to one.

        Raises:
            RuntimeError: If ``self.regularizer`` is neither ``'l2'`` nor
                ``'entropy'``.
        """
        if isinstance(self.learning_rate, (float, int, np.number)):
            learning_rate = self.learning_rate
        else:
            schedule = self.learning_rate[0]
            if not callable(schedule):
                # NOTE(security): ``eval`` executes arbitrary code. Only use
                # string schedules from trusted configuration; prefer passing
                # a callable, which bypasses ``eval`` entirely.
                schedule = eval(schedule)
            learning_rate = schedule(self.n + 1, *(self.learning_rate[1:]))

        if self.regularizer == 'l2':
            self.strategy = self.strategy_space.projection(
                self.strategy + learning_rate * gradient)
        elif self.regularizer == 'entropy':
            exponent = learning_rate * gradient
            # Shifting by the max leaves the normalised result unchanged but
            # keeps ``exp`` from overflowing to inf (which produced NaNs).
            exponent = exponent - np.max(exponent)
            self.strategy = self.strategy * np.exp(exponent)
            self.strategy /= np.sum(self.strategy)
        else:
            raise RuntimeError('Illegal regularizer')

    def _calc_gradient(self):
        """Return the ``(cum_gradient, gradient)`` pair passed to :meth:`_md`."""
        return self.cum_gradient, self.gradient

    def calc_strategy(self):
        """Perform one update step, increment ``n`` and return the new strategy."""
        self._md(*self._calc_gradient())
        self.n += 1
        return self.strategy

    def add_gradient(self, gradient):
        """Record a new gradient.

        Args:
            gradient: Array-like compatible with ``cum_gradient``. It is
                added in place to ``cum_gradient`` and becomes the gradient
                used by the next update. An ndarray is stored by reference
                (no copy).
        """
        # Coerce so plain lists/tuples work in the later ``lr * gradient``.
        gradient = np.asarray(gradient)
        self.cum_gradient += gradient
        self.gradient = gradient
