import abc
import math
from typing import Any, Dict, Optional, Tuple

import torch as th
from stable_baselines3.common.distributions import (
    BernoulliDistribution,
    CategoricalDistribution,
    DiagGaussianDistribution,
    Distribution,
    MultiCategoricalDistribution,
    StateDependentNoiseDistribution,
)

class MaskPolicy(abc.ABC):
    """
    Abstract mixin adding action masking to an actor-critic policy.

    The methods below use ``self.extract_features``, ``self.mlp_extractor``,
    ``self.value_net``, ``self.action_net`` and ``self.action_dist``, none of
    which are defined here; the concrete class must provide them.
    NOTE(review): presumably this is combined with Stable-Baselines3's
    ``ActorCriticPolicy`` -- confirm against the concrete subclasses.

    Only ``CategoricalDistribution`` action distributions are supported.
    """

    @abc.abstractmethod
    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """
        Return the constructor parameters of the policy.

        Must be implemented by the concrete class.
        """
        raise NotImplementedError("_get_constructor_parameters")

    @abc.abstractmethod
    def _extract_mask(self, obs: th.Tensor) -> th.Tensor:
        """
        Extract the action mask from an observation.

        Must be implemented by the concrete class.

        :param obs: Observation
        :return: Action mask; truthy entries mark actions that stay
            available, falsy entries mark actions to be masked out
        """
        raise NotImplementedError("_extract_mask")

    def forward(
        self, obs: th.Tensor, deterministic: bool = False
    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """
        Forward pass in all the networks (actor and critic)

        :param obs: Observation
        :param deterministic: Whether to sample or use deterministic actions
        :return: action, value and log probability of the action
        """
        # Preprocess the observation if needed
        features = self.extract_features(obs)
        mask = self._extract_mask(obs)
        latent_pi, latent_vf = self.mlp_extractor(features)
        # Evaluate the values for the given observations
        values = self.value_net(latent_vf)
        distribution = self._get_action_dist_from_latent(latent_pi, mask)
        actions = distribution.get_actions(deterministic=deterministic)
        log_prob = distribution.log_prob(actions)
        return actions, values, log_prob

    def _get_action_dist_from_latent(
        self, latent_pi: th.Tensor, mask: Optional[th.Tensor] = None
    ) -> Distribution:
        """
        Retrieve the masked action distribution given the latent codes.

        :param latent_pi: Latent code for the actor
        :param mask: Action mask broadcastable to the logits; truthy entries
            keep the action, falsy entries suppress it. Non-boolean masks are
            converted with ``to(th.bool)``. ``None`` disables masking.
        :return: Action distribution
        :raises NotImplementedError: If the action distribution is a known
            type for which masking is not implemented
        :raises ValueError: If the action distribution type is not recognised
        """
        action_dist = self.action_dist

        # Reject unsupported distributions before running the action network.
        # The check order mirrors the precedence of the distribution types.
        if isinstance(action_dist, DiagGaussianDistribution):
            raise NotImplementedError("DiagGaussianDistribution")
        if not isinstance(action_dist, CategoricalDistribution):
            for unsupported in (
                MultiCategoricalDistribution,
                BernoulliDistribution,
                StateDependentNoiseDistribution,
            ):
                if isinstance(action_dist, unsupported):
                    raise NotImplementedError(unsupported.__name__)
            raise ValueError("Inapplicable action distribution")

        # Here the network outputs are the logits before the softmax
        logits = self.action_net(latent_pi)
        if mask is not None:
            # Coerce to bool so that float or integer 0/1 masks behave like
            # boolean ones (bitwise negation is only meaningful on bool).
            allowed = mask.to(th.bool)
            # Use the most negative *finite* value of the logits' dtype rather
            # than -inf: masked actions get (numerically) zero probability,
            # while a row with every action masked still has finite logits.
            lowest = th.finfo(logits.dtype).min
            logits = logits.masked_fill(~allowed, lowest)

        return action_dist.proba_distribution(action_logits=logits)

    def get_distribution(self, obs: th.Tensor) -> Distribution:
        """
        Get the current policy distribution given the observations.

        :param obs: Observation
        :return: the action distribution.
        """
        mask = self._extract_mask(obs)
        features = self.extract_features(obs)
        latent_pi = self.mlp_extractor.forward_actor(features)
        return self._get_action_dist_from_latent(latent_pi, mask)