import numpy as np
import torch
import torch.nn.functional as F
from torch.distributions import Categorical
from k_level_policy_gradients.src.policy.policy import ParametricPolicy
from k_level_policy_gradients.src.utils.parameters import to_parameter


class GumbelSoftmaxPolicy(ParametricPolicy):
    """
    Categorical policy over a discrete action space with optional action masking.

    Actions are selected from the softmax of the approximator's logits after the
    invalid actions (where ``action_mask`` is false) have been set to ``-inf``.
    The selection rule depends on ``self._mode``:

    * ``"random"``: uniform choice among the valid actions;
    * ``"train"``: with probability ``epsilon`` a uniform valid action,
      otherwise a sample from the masked softmax distribution;
    * ``"test"``: the greedy (arg-max) valid action;
    * any other value: treated like ``"random"``.

    NOTE(review): ``self._approximator`` is not assigned in this class; it is
    presumably set by ``ParametricPolicy`` or by the caller — TODO confirm.
    """

    def __init__(self, epsilon, tau):
        """
        Constructor.

        Args:
            epsilon ([float, Parameter]): the exploration coefficient. It indicates
                the probability of performing a random action in the current
                step while in ``"train"`` mode.
            tau (float): temperature intended for the Gumbel-Softmax relaxation.
                It is stored and saved here but not read by the action-selection
                methods of this class.
        """
        super().__init__()

        self._epsilon = to_parameter(epsilon)
        self._gumbel = F.gumbel_softmax
        self._tau = tau
        self._mode = "train"  # options are random, train, test

        self._add_save_attr(_epsilon="mushroom", _gumbel="torch", _tau="primitive")

    @staticmethod
    def _valid_actions(action_mask, n_actions=None):
        """Return the indices of the actions allowed by ``action_mask``.

        When ``action_mask`` is None, all ``n_actions`` actions are valid.
        """
        if action_mask is None:
            return np.arange(n_actions)
        return np.flatnonzero(action_mask)

    def _random_action(self, action_mask, n_actions=None):
        """Draw one action uniformly among the valid ones, as a 1-element array."""
        return np.array([np.random.choice(self._valid_actions(action_mask, n_actions))])

    @staticmethod
    def _masked_probs(logits, action_mask):
        """Softmax of ``logits`` with invalid actions forced to probability zero."""
        logits = logits.squeeze()
        if action_mask is None:
            return F.softmax(logits, dim=-1)
        # Build the mask and the fill value on the logits' device so that
        # torch.where also works when the approximator runs on a GPU.
        mask = torch.as_tensor(action_mask, dtype=torch.bool, device=logits.device)
        neg_inf = torch.tensor(float("-inf"), dtype=logits.dtype, device=logits.device)
        return F.softmax(torch.where(mask, logits, neg_inf), dim=-1)

    def _select_action(self, action_probs, action_mask):
        """Pick an action from ``action_probs`` according to ``self._mode``.

        Always returns a one-element ``np.ndarray`` with the action index.
        """
        n_actions = action_probs.numel()
        if self._mode == "test":
            # torch.max without `dim` returns a single tensor (no indices), so
            # argmax is the correct way to obtain the greedy action index.
            return np.array([torch.argmax(action_probs).item()])
        if self._mode == "train":
            if np.random.uniform() < self._epsilon():
                return self._random_action(action_mask, n_actions)
            return np.array([Categorical(action_probs).sample().item()])
        # "random" mode or any unrecognised mode: uniform valid action.
        return self._random_action(action_mask, n_actions)

    def draw_action(self, state, action_mask=None):
        """
        Select an action for ``state``.

        Args:
            state: the observation forwarded to ``self._approximator.predict``.
            action_mask (array-like of bool, optional): ``True`` for the valid
                actions. When None, every action is considered valid.

        Returns:
            A one-element ``np.ndarray`` holding the chosen action index.
        """
        if self._mode == "random" and action_mask is not None:
            # A uniform valid action needs only the mask: skip the forward pass.
            return self._random_action(action_mask)
        logits = self._approximator.predict(state, output_tensor=True)
        action_probs = self._masked_probs(logits, action_mask)
        return self._select_action(action_probs, action_mask)

    def draw_action_hidden(self, state, hidden_state, action_mask=None):
        """
        Select an action for ``state`` with a recurrent approximator.

        Args:
            state: the observation forwarded to ``self._approximator.predict``.
            hidden_state: the recurrent state forwarded to the approximator.
            action_mask (array-like of bool, optional): ``True`` for the valid
                actions. When None, every action is considered valid.

        Returns:
            A tuple ``(action, h)``. ``action`` is a one-element ``np.ndarray``
            and ``h`` is the hidden state returned by the approximator. The
            forward pass always runs so that ``h`` is available in every mode.
        """
        logits, h = self._approximator.predict(
            state,
            hidden_state,
            output_hidden=True,
            output_tensor=True,
        )
        action_probs = self._masked_probs(logits, action_mask)
        return self._select_action(action_probs, action_mask), h
