from typing import Optional

import numpy as np

from modules.utils.Log import Logger
from modules.agents.Probability import ProbabilityEmpiricalMeasure
from modules.agents.Policy import Policy
from modules.agents.tabular.DiscreteAgent import DiscreteAgent
import modules.agents.DiscreteOTTRPOFunctions as WPO

# OTTRPO
class DiscreteOTTRPOAgent(DiscreteAgent):
    """Discrete agent whose policy update is delegated to the OTTRPO solvers.

    Each call to :meth:`update_policy` solves a dual problem and then a primal
    problem through ``modules.agents.DiscreteOTTRPOFunctions`` (imported as
    ``WPO``) and writes the resulting per-state policy back into the agent's
    policy object.

    The pairwise distances derived from the user-supplied ``dist`` callable
    are computed once at construction time and cached on the instance.
    """

    def __init__(self, policy: Optional[Policy], parameters: dict) -> None:
        """Build the agent and pre-compute the cached distance structures.

        Args:
            policy: Initial policy, or ``None``; forwarded unchanged to
                ``DiscreteAgent.__init__``.
            parameters: Configuration dictionary. This class reads
                ``parameters["p"]`` (exponent, must be >= 1) and
                ``parameters["dist"]`` (callable taking two arguments, see
                ``_dist_p``). The dictionary is also forwarded to the parent
                constructor.

        Raises:
            KeyError: If ``"p"`` or ``"dist"`` is missing from ``parameters``.
            AssertionError: If ``p < 1`` or ``dist`` is not callable.
        """
        super().__init__(policy, parameters)

        # NOTE: assert-based validation is skipped when Python runs with -O.
        assert parameters["p"] >= 1, 'p must be at least 1'
        self._p = parameters["p"]

        assert callable(parameters["dist"]), "dist must be callable"
        self._dist = parameters["dist"]

        self._name = 'OTTRPO'

        # Caches filled by _init_distances(); declared first so the attributes
        # always exist on the instance.
        self._distances = None
        self._distances_p = None
        self._init_distances()

        # optimization
        # Overwritten with the solver's multiplier after every update_policy()
        # call. NOTE(review): this value is not passed to WPO.solve_dual here,
        # so it does not act as a warm start in this class.
        self.lambda_star = 10  # initial guess

    def p(self) -> float:
        """Return the exponent ``p`` supplied at construction time."""
        return self._p

    def save(self) -> dict:
        """Return the parent's saved data with ``p`` added to its parameters.

        The value is stored under ``data["parameters"]["p"]``. The ``dist``
        callable is not added by this method.
        """
        data = super().save()
        data["parameters"]["p"] = self.p()
        return data

    def update_policy(self, advantage, rho: ProbabilityEmpiricalMeasure, logger: Logger) -> None:
        """Run one policy update: solve the dual, then the primal, then commit.

        Args:
            advantage: Advantage estimates, forwarded unchanged to
                ``WPO.solve_dual`` (presumably indexed by state and action;
                confirm against the solver).
            rho: Empirical measure forwarded to both solvers.
            logger: Logger forwarded to both solvers.

        Side effects:
            Stores the dual multiplier in ``self.lambda_star``, re-applies the
            trust region via ``set_trustregion`` and overwrites the policy of
            every state in the state space.
        """
        # Disabled alternatives kept for reference (the advantage is currently
        # used exactly as given):
        # if np.abs(advantage).max() < 1e-2:
        #     return
        # normalize per state
        # advantage = np.divide(advantage - np.mean(advantage, axis=1).reshape((self.n_state(), 1)),
        #                      np.repeat(np.std(advantage, axis=1).reshape((self.n_state(), 1)), self.n_action(),axis=1) + 1e-3)
        # normalize overall
        # advantage = (advantage - advantage.mean())/(advantage.std() + 1e-3)

        # solve dual problem to find multiplier and minimizers of regularized function
        # NOTE(review): ``distances_p`` is handed over as the bound method (a
        # callable taking a state), not as the cached array; presumably the
        # solvers call it per state. Confirm against WPO.
        lambda_star, mins = WPO.solve_dual(policy=self.policy(copy=False),
                                           rho=rho,
                                           advantage=advantage,
                                           epsilon_p=self._epsilon_p(),
                                           distances_p=self.distances_p,
                                           state_space=self.state_space(copy=False),
                                           action_space=self.action_space(copy=False),
                                           logger=logger)
        # use dual solution (via minimizers) to compute the new optimal policy
        new_policy = WPO.solve_primal(policy=self.policy(copy=False),
                                      rho=rho,
                                      mins=mins,
                                      epsilon_p=self._epsilon_p(),
                                      distances_p=self.distances_p,
                                      state_space=self.state_space(copy=False),
                                      action_space=self.action_space(copy=False),
                                      logger=logger)
        self.lambda_star = lambda_star
        # Dividing by 1.0 keeps the value of epsilon as it is; the call is
        # retained because set_trustregion() may do more than store the value.
        self.set_trustregion(self.epsilon()/1.0)
        # new_policy is called with a state and its result is installed as
        # that state's policy.
        for s in self.state_space(copy=False):
            self.policy(copy=False).set_policy(s, new_policy(s))

    def distances(self, state: int):
        """Return the cached distances; ``state`` is accepted but not used."""
        return self._distances

    def distances_p(self, state: int):
        """Return the cached p-th-power distances; ``state`` is not used."""
        return self._distances_p

    def _epsilon_p(self) -> float:
        """Return ``epsilon() ** p``."""
        return np.power(self.epsilon(), self.p())

    def _dist_p(self, a, b) -> float:
        """Return ``dist(a, b) ** p`` for the configured ``dist`` callable."""
        return np.power(self._dist(a, b), self.p())

    def _init_distances(self) -> None:
        """Fill both distance caches from ``dist`` and the action space."""
        self._distances = WPO.compute_distances(self._dist, self.action_space(copy=False))
        self._distances_p = WPO.compute_distances_p(self._dist, self.action_space(copy=False), self.p())