from typing import Any, Dict, Optional, Type, Tuple, Union

import torch as th
from stable_baselines3.common.policies import (ActorCriticPolicy,
                                               get_policy_from_name)
from stable_baselines3.common.type_aliases import GymEnv

from ..policy import MaskPolicy


def ActorCriticPolicyPreConditionsWrapper(
    policy: Union[str, Type[ActorCriticPolicy]]
):
    """Build an actor-critic policy class whose actions are masked by env preconditions.

    :param policy: A policy class, or a policy name that is resolved with
        ``get_policy_from_name(ActorCriticPolicy, policy)``.
    :return: The resolved class unchanged if it already subclasses
        ``MaskPolicy``; otherwise a new class deriving from ``MaskPolicy`` and
        the resolved class (``MaskPolicy`` first in the MRO) that computes
        action masks from ``env.is_applicable``.
    """
    _policy_class = (
        get_policy_from_name(ActorCriticPolicy, policy)
        if isinstance(policy, str)
        else policy
    )

    # Already mask-aware: nothing to wrap.
    if issubclass(_policy_class, MaskPolicy):
        return _policy_class

    class policy_class(MaskPolicy, _policy_class):
        """``_policy_class`` with action masks derived from ``env.is_applicable``."""

        def __init__(
            self,
            *args,
            env: Optional[Union[GymEnv, str]] = None,
            threshold=0.5,
            **kwargs
        ):
            """
            :param env: Environment queried for per-action applicability; required.
                NOTE(review): the annotation allows ``str``, but ``_extract_mask``
                calls ``env.is_applicable``, so a plain string would fail there
                unless something converts it first. Confirm.
            :param threshold: An action is allowed when ``env.is_applicable``
                returns a value strictly greater than this.
            :raises RuntimeError: If ``env`` is not provided.
            """
            if env is None:
                raise RuntimeError("`env` must be passed to the policy class")
            # Set before the parent constructors run, presumably so that code
            # executed during their __init__ can already see them. Confirm
            # against MaskPolicy.
            self.env = env
            self.threshold = threshold

            super().__init__(*args, **kwargs)

        def _get_constructor_parameters(self) -> Dict[str, Any]:
            """Return the parent's constructor parameters plus ``env`` and ``threshold``.

            Presumably consumed when the policy is saved and re-instantiated,
            so the extra keyword arguments must be included. Confirm.
            """
            data = super()._get_constructor_parameters()
            data.update(dict(env=self.env, threshold=self.threshold))
            return data

        def evaluate_actions(
            self, obs: th.Tensor, actions: th.Tensor
        ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
            """Evaluate ``actions`` under the masked action distribution for ``obs``.

            :return: ``(values, log_prob, entropy)``.
            """
            # Preprocess the observation if needed
            features = self.extract_features(obs)
            mask = self._extract_mask(obs)
            latent_pi, latent_vf = self.mlp_extractor(features)
            # NOTE(review): assumes MaskPolicy overrides
            # _get_action_dist_from_latent to accept a mask argument. Confirm.
            distribution = self._get_action_dist_from_latent(latent_pi, mask)
            log_prob = distribution.log_prob(actions)
            values = self.value_net(latent_vf)
            return values, log_prob, distribution.entropy()

        def _extract_mask(self, obs: th.Tensor) -> th.Tensor:
            """Return a boolean action mask of shape ``(obs.shape[0], action_dim)``.

            Entry ``[i, a]`` is True when ``env.is_applicable(obs[i], a)`` is
            strictly greater than ``threshold``. The mask is returned on
            ``obs.device``.
            """
            n_actions = self.action_dist.action_dim
            # Filled element-wise on the CPU, then moved to obs' device so it
            # can be combined with latents/logits that live there (e.g. on GPU).
            masks = th.ones((obs.shape[0], n_actions))
            for i, o in enumerate(obs):
                for a in range(n_actions):
                    masks[i, a] = self.env.is_applicable(o, a)
            return (masks > self.threshold).to(obs.device)

    return policy_class
