from typing import Any, Dict, Optional, Type, Union, Tuple

import numpy as np
import torch as th
import gym
from sklearn.exceptions import NotFittedError
from stable_baselines3.common.policies import (ActorCriticPolicy,
                                               get_policy_from_name)
from stable_baselines3.common.type_aliases import GymEnv
from stable_baselines3.common.utils import get_schedule_fn

from ..policy import MaskPolicy


def ActorCriticPolicyCombinedWrapper(
    policy: Union[str, Type[ActorCriticPolicy]], is_nn_classifier: bool = False
):
    """
    Build an actor-critic policy class whose action mask combines a learned
    classifier with the environment's own applicability check.

    :param policy: Policy class, or the name of a registered
        ``ActorCriticPolicy`` (resolved with ``get_policy_from_name``)
    :param is_nn_classifier: If True, return the variant that calls the
        classifier directly with ``actions``/``observations`` keyword
        arguments and applies a sigmoid to its output; otherwise return the
        variant that uses the classifier's ``predict_proba`` method
    :return: ``policy`` itself when it already subclasses ``MaskPolicy``,
        otherwise a new class deriving from ``MaskPolicy`` and the policy
    """
    _policy_class = (
        get_policy_from_name(ActorCriticPolicy, policy)
        if isinstance(policy, str)
        else policy
    )

    # Already a masking policy: nothing to wrap.
    if issubclass(_policy_class, MaskPolicy):
        return _policy_class

    class policy_class(MaskPolicy, _policy_class):
        """Masking policy that uses a ``predict_proba``-style classifier."""

        def __init__(
            self,
            *args,
            env: Optional[Union[GymEnv, str]] = None,
            classifier=None,
            use_policy_features=False,
            exploration_rate=0.5,
            threshold=0.5,
            classifier_proba_weight=0.5,
            **kwargs
        ):
            """
            :param env: Environment; must expose ``is_applicable(obs, action)``
            :param classifier: Model estimating whether an action is valid
            :param use_policy_features: Feed the classifier with the policy's
                extracted features instead of the raw observation
            :param exploration_rate: Constant or schedule; compared against a
                uniform random draw to decide whether the classifier is
                ignored when a mask is computed
            :param threshold: Actions whose combined score is strictly above
                this value are allowed
            :param classifier_proba_weight: Constant or schedule; weight of
                the classifier score against the environment check
            :raises RuntimeError: If ``env`` or ``classifier`` is missing
            """
            if env is None:
                raise RuntimeError("`env` must be passed to the policy class")
            if classifier is None:
                raise RuntimeError("`classifier` must be provided")
            self.env = env
            self.use_policy_features = use_policy_features
            self.exploration_rate = exploration_rate
            self.threshold = threshold
            self.classifier_proba_weight = classifier_proba_weight

            # Current values are initialised at progress_remaining == 1.
            self._exploration_rate_schedule = get_schedule_fn(self.exploration_rate)
            self._exploration_rate = self._exploration_rate_schedule(1.)
            self._classifier_proba_weight_schedule = get_schedule_fn(self.classifier_proba_weight)
            self._classifier_proba_weight = self._classifier_proba_weight_schedule(1.)

            super().__init__(*args, **kwargs)

            # Assigned only after the parent constructor has run
            # (the classifier may itself be a torch module -- TODO confirm).
            self.classifier = classifier

        def _get_constructor_parameters(self) -> Dict[str, Any]:
            """Return the parent's constructor parameters plus the masking ones."""
            data = super()._get_constructor_parameters()

            data.update(
                dict(
                    env=self.env,
                    classifier=self.classifier,
                    use_policy_features=self.use_policy_features,
                    exploration_rate=self.exploration_rate,
                    threshold=self.threshold,
                    classifier_proba_weight=self.classifier_proba_weight,
                )
            )
            return data

        def update_classifier_proba_weight(self, current_progress_remaining) -> None:
            """Re-evaluate the classifier weight schedule at the given progress."""
            self._classifier_proba_weight = self._classifier_proba_weight_schedule(current_progress_remaining)

        def update_exploration_rate(self, current_progress_remaining) -> None:
            """Re-evaluate the exploration rate schedule at the given progress."""
            self._exploration_rate = self._exploration_rate_schedule(current_progress_remaining)

        def forward(
            self, obs: th.Tensor, deterministic: bool = False
        ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor]:
            """
            Forward pass in all the networks (actor and critic)

            :param obs: Observation
            :param deterministic: Whether to sample or use deterministic actions
            :return: action, value, log probability of the action and the
                boolean action mask that was applied
            """
            # Preprocess the observation if needed
            features = self.extract_features(obs)
            mask = self._extract_mask(obs)
            latent_pi, latent_vf = self.mlp_extractor(features)
            # Evaluate the values for the given observations
            values = self.value_net(latent_vf)
            distribution = self._get_action_dist_from_latent(latent_pi, mask)
            actions = distribution.get_actions(deterministic=deterministic)
            log_prob = distribution.log_prob(actions)
            return actions, values, log_prob, mask

        def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor, masks: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
            """
            Evaluate actions under the given masks.

            :param obs: Observation
            :param actions: Actions to evaluate
            :param masks: Action masks; converted to boolean before use
            :return: estimated value, log probability of the actions and
                entropy of the action distribution
            """
            # Preprocess the observation if needed
            features = self.extract_features(obs)
            latent_pi, latent_vf = self.mlp_extractor(features)
            distribution = self._get_action_dist_from_latent(latent_pi, masks.bool())
            log_prob = distribution.log_prob(actions)
            values = self.value_net(latent_vf)
            return values, log_prob, distribution.entropy()

        def _extract_mask(self, obs: th.Tensor) -> th.Tensor:
            """
            Compute the boolean action mask for a batch of observations.

            :param obs: Observation batch (first dimension is the batch size)
            :return: Boolean tensor of shape ``(batch, action_dim)``; True
                where the combined score is strictly above ``self.threshold``
            """
            batch_size = obs.shape[0]
            n_actions = self.action_dist.action_dim
            # Scores default to 1.0 so that an action is allowed when no
            # prediction is available (for any threshold below 1).
            masks = th.ones((batch_size, n_actions), device=obs.device)

            # With probability `_exploration_rate` the classifier is ignored
            # and only the environment check contributes to the score.
            if self._exploration_rate > np.random.random():
                classifier_proba_weight = 0.
            else:
                classifier_proba_weight = self._classifier_proba_weight

            for a in range(n_actions):
                actions = th.full((batch_size, 1), a, dtype=th.long, device=obs.device)
                try:
                    masks[:, a] = self._predict_proba(obs, actions, classifier_proba_weight)
                except NotFittedError:
                    # Deliberate best effort: while the classifier is not
                    # fitted the column keeps its default score of 1.0.
                    pass

            return masks > self.threshold

        def _predict_proba(self, obs, actions, classifier_proba_weight) -> th.Tensor:
            """
            Returns the probability for the action to be valid

            The score is ``w * classifier + (1 - w) * env`` where ``w`` is
            ``classifier_proba_weight``.

            :param obs: Observation batch
            :param actions: One action per observation, shape ``(batch, 1)``
            :param classifier_proba_weight: Weight of the classifier score
            :return: Per-sample score on the same device as ``obs``
            """
            classifier_proba = self._predict_classifier_proba(obs, actions)

            # environment: filled on the CPU, then moved to the device of the
            # observations so that it can be combined with the classifier
            # output (which lives on `obs.device`).
            env_proba = th.ones(obs.shape[0])
            for i, (o, a) in enumerate(zip(obs, actions)):
                env_proba[i] = self.env.is_applicable(o, a.item())
            env_proba = env_proba.to(obs.device)

            return (
                classifier_proba_weight * classifier_proba
                + (1.0 - classifier_proba_weight) * env_proba
            )

        def _predict_classifier_proba(self, obs, actions) -> th.Tensor:
            """
            Returns the classifier's probability for the action to be valid

            :param obs: Observation batch
            :param actions: One action per observation, shape ``(batch, 1)``
            :return: Column ``1`` of ``classifier.predict_proba`` as a tensor
                on the device of the classifier input features
            """
            if isinstance(self.action_space, gym.spaces.Discrete):
                actions = th.nn.functional.one_hot(
                    actions.long(), self.action_space.n
                ).squeeze(1)

            with th.no_grad():
                # TODO move obs.reshape in the classifier
                obs = self.extract_features(obs) if self.use_policy_features else obs.reshape(obs.shape[0], -1)
                X = th.cat([actions, obs], dim=1)

            # `predict_proba` is indexed NumPy-style below, so hand the
            # estimator a host-side array; a tensor on an accelerator cannot
            # be converted implicitly.
            probas = self.classifier.predict_proba(X.detach().cpu().numpy())
            return th.tensor(probas[:, 1], device=obs.device)  # proba of class `1` = action is valid

    class cnn_policy_class(policy_class):
        """Masking policy whose classifier is called directly on tensors."""

        def _predict_classifier_proba(self, obs, actions) -> th.Tensor:
            """Returns the probability for the action to be valid"""
            with th.no_grad():
                obs = self.extract_features(obs) if self.use_policy_features else obs.float()
                pred = self.classifier(actions=actions, observations=obs)
            # The classifier output is squashed with a sigmoid here.
            y_test_pred = th.sigmoid(pred)
            return y_test_pred.squeeze(-1)

    return cnn_policy_class if is_nn_classifier else policy_class
