from typing import Optional, Tuple

import torch as th
from stable_baselines3.common.distributions import Distribution
from torch import nn
from torch.distributions import Normal, Gumbel


class CategoricalDistribution(Distribution):
    """Discrete action distribution sampled with the Gumbel-max / Gumbel-softmax trick.

    ``proba_distribution`` perturbs the logits with Gumbel(0, 1) noise, divides by
    ``temperature`` and caches the resulting (normalized) perturbed distribution.
    The discrete sample is the argmax of that perturbed distribution, the mode is the
    argmax of the raw logits, and ``log_prob_from_params`` also exposes a
    straight-through one-hot sample (hard one-hot in the forward pass, gradient of the
    relaxed probabilities in the backward pass).

    :param action_dim: Number of discrete actions.
    :param temperature: Divisor applied to the perturbed logits before normalization;
        smaller values make the relaxed probabilities sharper.
    """

    def __init__(self, action_dim: int, temperature: float = 0.5):
        super(CategoricalDistribution, self).__init__()
        self.action_dim = action_dim
        self.temperature = temperature
        # NOTE(review): not used by any method of this class (noise is drawn inline in
        # ``proba_distribution``); kept so external code reading the attribute still works.
        self.gumbel = Gumbel(loc=th.zeros(self.action_dim), scale=th.ones(self.action_dim))
        # Caches populated by ``proba_distribution``.
        self._sample = None  # hard action indices (argmax of the perturbed distribution)
        self._log_prob = None  # log-probabilities of the perturbed, tempered distribution
        self._prob = None  # probabilities of that same perturbed distribution
        self._mode = None  # argmax of the unperturbed logits
        self._sample_one_hot = None  # straight-through one-hot of ``_sample``

    def proba_distribution_net(self, latent_dim: int) -> nn.Module:
        """Return the linear layer that maps a latent vector to ``action_dim`` logits.

        :param latent_dim: Size of the latent features fed to the layer.
        :return: ``nn.Linear(latent_dim, action_dim)``.
        """
        action_logits = nn.Linear(latent_dim, self.action_dim)
        return action_logits

    def proba_distribution(self, logits: th.Tensor) -> "CategoricalDistribution":
        """Draw a Gumbel-perturbed sample from ``logits`` and cache derived quantities.

        :param logits: Unnormalized action scores; the last dimension indexes actions.
        :return: ``self``, so the call can be chained like other SB3 distributions.
        """
        # Deterministic action: argmax of the raw logits (no noise).
        self._mode = logits.max(dim=-1).indices

        # Gumbel(0, 1) noise sampled as -log(E) with E ~ Exp(1).
        gumbels = -th.empty_like(logits, memory_format=th.legacy_contiguous_format).exponential_().log()
        gumbels = (logits + gumbels) / self.temperature
        # Normalized log-probabilities of the perturbed, tempered distribution.
        self._log_prob = gumbels - gumbels.logsumexp(dim=-1, keepdim=True)
        # softmax is shift-invariant, so this equals self._log_prob.exp().
        self._prob = self._log_prob.softmax(dim=-1)

        # Hard sample: argmax of the perturbed distribution (Gumbel-max trick).
        sample = self._prob.max(dim=-1, keepdim=True).indices
        hard_one_hot = th.zeros_like(logits, memory_format=th.legacy_contiguous_format).scatter_(-1, sample, 1.0)
        # Straight-through: forward value is the hard one-hot, gradient flows through self._prob.
        self._sample_one_hot = hard_one_hot - self._prob.detach() + self._prob
        # Squeeze the action axis (last dim) rather than dim 1 so unbatched 1-D logits work too.
        self._sample = sample.squeeze(dim=-1)
        return self

    def log_prob(self, actions: th.Tensor) -> th.Tensor:
        """Look up the cached log-probabilities of ``actions``.

        NOTE(review): these are log-probabilities of the noise-perturbed distribution
        built by the last ``proba_distribution`` call, not ``log_softmax(logits)``;
        confirm this is what the training objective expects.

        :param actions: Action indices, either with a trailing singleton axis (same rank
            as the logits, e.g. ``(batch, 1)``; the result keeps that axis) or without it
            (e.g. ``(batch,)``; the result has the same shape as ``actions``). Non-integer
            index tensors are cast to int64, as required by ``gather``.
        :return: Gathered log-probabilities.
        """
        index = actions.long()
        if index.dim() == self._log_prob.dim() - 1:
            # Missing action axis: add it for gather, then drop it again.
            return self._log_prob.gather(dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
        return self._log_prob.gather(dim=-1, index=index)

    def entropy(self) -> Optional[th.Tensor]:
        """Return the entropy of the cached distribution, reduced over the last axis.

        NOTE(review): computed from the Gumbel-perturbed probabilities of the last
        sample, not from ``softmax(logits)``, so it is itself a random quantity.
        """
        return -th.sum(self._prob * self._log_prob, dim=-1)

    def sample(self) -> th.Tensor:
        """Return the hard action indices drawn by the last ``proba_distribution`` call."""
        return self._sample

    def mode(self) -> th.Tensor:
        """Return the argmax of the logits given to the last ``proba_distribution`` call."""
        return self._mode

    def actions_from_params(self, **kwargs) -> th.Tensor:
        """Build the distribution from ``kwargs['mean_actions']`` and return an action.

        :param kwargs: Must contain ``mean_actions`` (the logits). When
            ``deterministic`` is passed and is exactly ``True``, the mode is returned;
            otherwise the Gumbel sample is returned.
        :return: Action indices.
        """
        self.proba_distribution(kwargs['mean_actions'])
        if kwargs.get('deterministic') is True:
            return self._mode
        return self._sample

    def log_prob_from_params(self, **kwargs) -> Tuple[th.Tensor, th.Tensor]:
        """Build the distribution, pick an action and return ``(one_hot, log_prob)``.

        NOTE(review): the first element is always the straight-through one-hot of the
        stochastic sample, even when ``deterministic=True``; only the log-prob lookup
        uses the mode in that case.

        :param kwargs: Forwarded to ``actions_from_params``.
        :return: Straight-through one-hot sample and the log-probability of the chosen
            action, with a trailing singleton axis.
        """
        actions = self.actions_from_params(**kwargs).unsqueeze(dim=-1)
        return self._sample_one_hot, self.log_prob(actions)
