import numpy as np

from k_level_policy_gradients.src.policy.policy import ParametricPolicy


class GaussianPolicy(ParametricPolicy):
    """
    Deterministic policy with optional additive Gaussian exploration noise.

    This policy is commonly used in the Deep Deterministic Policy Gradient
    algorithm. The action mean is produced by ``self._approximator``. In
    ``"train"`` mode zero-mean Gaussian noise with scale ``sigma`` is added to
    that mean, in ``"test"`` mode the mean is returned unchanged, and in
    ``"random"`` mode an action is sampled from the action space instead.
    """

    # Modes accepted by ``draw_action``; anything else raises ValueError.
    _VALID_MODES = ("random", "train", "test")

    def __init__(self, sigma, action_space):
        """
        Constructor.

        Args:
            sigma (np.ndarray): scale (standard deviation) of the Gaussian
                noise added to the action mean in ``"train"`` mode;
            action_space: action space object; its ``sample()`` method is
                used in ``"random"`` mode and its ``shape`` attribute sets the
                shape of the sampled noise.
        """
        super().__init__()

        self._sigma = sigma
        self._action_space = action_space
        self._mode = "train"  # options are random, train, test

        self._add_save_attr(
            _sigma="primitive",
            _action_space="mushroom",
            _mode="primitive",
        )

    def draw_action(self, state):
        """
        Return an action for ``state`` according to the current mode.

        Args:
            state: input forwarded to ``self._approximator.predict``.

        Returns:
            - ``"random"``: the result of ``self._action_space.sample()``;
            - ``"test"``: the squeezed approximator prediction;
            - ``"train"``: the squeezed prediction plus Gaussian noise of
              shape ``self._action_space.shape`` and scale ``self._sigma``.

        Raises:
            ValueError: if ``self._mode`` is not one of ``_VALID_MODES``.
        """
        # Validate first so an invalid mode fails fast, before paying for a
        # forward pass of the approximator.
        if self._mode not in self._VALID_MODES:
            raise ValueError("Invalid policy mode given")

        if self._mode == "random":
            return self._action_space.sample()

        # NOTE(review): ``self._approximator`` is not assigned in this class;
        # presumably the base class or the caller sets it -- confirm.
        mu = self._approximator.predict(state).squeeze()
        if self._mode == "test":
            return mu

        # Remaining valid mode is "train": add exploration noise.
        # NOTE(review): ``squeeze`` can produce a 0-d mean for a size-1 action
        # space, in which case "test" returns shape () while "train" returns
        # the action-space shape after broadcasting -- confirm callers accept
        # both.
        noise = np.random.normal(
            scale=self._sigma, size=self._action_space.shape
        )
        return mu + noise
