import numpy as np

class RegretMatchingPolicy:
    """
    A tabular policy whose improvement step uses regret matching
    instead of the typical greedy improvement step.

    Every call to ``policy_improvement`` adds each action's advantage
    (relative to the current policy) to a running sum. The new policy
    in each state is proportional to the positive part of those sums,
    or uniform when no action has a positive accumulated advantage.

    Used to stabilize training a little more when doing policy
    iteration.
    """
    def __init__(
            self,
            num_states: int,
            num_actions: int,
    ):
        """
        :param num_states: The number of states in the environment.
        :param num_actions: The number of actions in the environment.
        :attr sum: Array of shape (num_states, num_actions) keeping track
        of the running sums of the advantage values. Used to calculate
        the probability distribution via regret matching.
        :attr pi: Matrix of shape (num_states, num_actions) representing
        the policy; row ``s`` is the action distribution in state ``s``.
        It starts out uniform.
        """
        self.sum = np.zeros((num_states, num_actions))
        self.pi = np.full(self.sum.shape, 1/self.sum.shape[1])

    def __getitem__(self, k: int) -> np.ndarray:
        """
        Return the action distribution for state ``k``, giving this
        class a similar flavor as a normal tabular policy.
        """
        return self.pi[k]

    def __len__(self) -> int:
        """
        Return the number of states, giving this class a similar
        flavor as a normal tabular policy.
        """
        return len(self.pi)

    def policy_improvement(self, q: np.ndarray) -> np.ndarray:
        """
        Updates the sum with new advantage values, and
        computes the new policy.

        :param q: The q-values to update with. They should come from
        policy evaluation done on the current policy. It's an array where
        the first dimension refers to the state index, and the second
        dimension refers to the action.
        :return: The updated policy matrix (also stored in ``self.pi``).
        """
        # Value of each state under the current policy, kept as a column
        # so it broadcasts across that state's row of q-values.
        state_values = np.sum(self.pi * q, axis=1, keepdims=True)
        self.sum += q - state_values
        pos_adv_sum = np.maximum(self.sum, 0)
        denominator = np.sum(pos_adv_sum, axis=1, keepdims=True)
        # Rows with no positive accumulated advantage (or NaN sums) keep
        # the uniform distribution. `where` skips them entirely, so no
        # 0/0 is evaluated and no RuntimeWarning is emitted.
        uniform = np.full(pos_adv_sum.shape, 1/pos_adv_sum.shape[1])
        self.pi = np.divide(pos_adv_sum, denominator, out=uniform,
                            where=denominator > 0)
        return self.pi

class EpsGreedyPolicy(RegretMatchingPolicy):
    """
    An epsilon-greedy policy. Used mainly for debugging.

    Reuses the tabular storage of ``RegretMatchingPolicy`` but replaces
    its improvement step. Each state spreads ``epsilon`` probability
    uniformly over all actions and puts the remaining ``1 - epsilon``
    on its greedy action.
    """
    def __init__(self,
                 num_states: int,
                 num_actions: int,
                 epsilon: float = 0.05,
                 ):
        """
        :param num_states: The number of states in the environment.
        :param num_actions: The number of actions in the environment.
        :param epsilon: Total probability mass spread uniformly over all
        actions; the remaining ``1 - epsilon`` goes to the greedy action.
        """
        super().__init__(num_states, num_actions)
        self.epsilon = epsilon

    def policy_improvement(self, q: np.ndarray) -> np.ndarray:
        """
        Makes the policy epsilon-greedy with respect to ``q``.

        :param q: Array of q-values where the first dimension refers to
        the state index and the second dimension refers to the action.
        :return: The updated policy matrix (also stored in ``self.pi``),
        matching the return value of the parent class's method.
        """
        num_actions = self.pi.shape[1]
        # Ties are broken toward the lowest action index (np.argmax).
        best_actions = np.argmax(q, axis=1)
        self.pi = np.full(self.pi.shape, self.epsilon / num_actions)
        # One fancy-indexed update per state; each row index occurs once,
        # so `+=` applies exactly one increment per state.
        rows = np.arange(len(best_actions))
        self.pi[rows, best_actions] += 1 - self.epsilon
        return self.pi
