import math
from typing import Tuple, Type, Union

import torch as th
from stable_baselines3.common.distributions import (
    BernoulliDistribution, CategoricalDistribution, DiagGaussianDistribution,
    Distribution, MultiCategoricalDistribution,
    StateDependentNoiseDistribution)
from stable_baselines3.common.policies import (ActorCriticPolicy,
                                               get_policy_from_name)

from ...utils.features_extractor import ResizeFeatureExtractors


def ActorCriticPolicyWarmStartWrapper(
    policy: Union[str, Type[ActorCriticPolicy]]
):
    """Build an actor-critic policy class that can be warm-started from another policy.

    :param policy: Either a policy class, or a policy name that is resolved
        through ``get_policy_from_name(ActorCriticPolicy, policy)``.
    :return: A subclass of the resolved policy class that adds
        ``load_policy_reuse`` and applies an action mask when building the
        action distribution. Only ``CategoricalDistribution`` is supported.
    """

    _policy_class = (
        get_policy_from_name(ActorCriticPolicy, policy)
        if isinstance(policy, str)
        else policy
    )

    class policy_class(_policy_class):
        """Policy whose feature extractor and heads can be taken from a loaded policy."""

        # Default used until ``load_policy_reuse`` is called: no mask means
        # that every action of ``self.action_dist`` is allowed.
        mask = None

        def load_policy_reuse(self, policy):
            """Reuse the feature extractor, action head and value head of ``policy``.

            The modules are assigned by reference (they are not copied).
            NOTE(review): an optimizer created before this call is not touched
            here -- confirm that the caller rebuilds it if training continues.

            :param policy: Loaded policy providing ``features_extractor``,
                ``action_net``, ``value_net`` and ``action_dist``.
            """
            self.features_dim = policy.features_extractor.features_dim
            self.features_extractor = ResizeFeatureExtractors(policy.features_extractor)

            self.action_net = policy.action_net
            self.value_net = policy.value_net

            # Read both sizes up front: ``self.action_dist`` may be replaced below.
            own_dim = self.action_dist.action_dim
            loaded_dim = policy.action_dist.action_dim

            self._rebalance_distribution_function = self._default_rebalance_distribution_function

            # Specific code: the smaller action space is assumed to be a prefix
            # of the larger one (the first ``min(own_dim, loaded_dim)`` actions
            # are shared).
            # TODO make more generic
            if own_dim < loaded_dim:
                # The loaded head emits more logits than we have actions:
                # adopt its distribution and mask out the extra actions.
                self.action_dist = policy.action_dist
            elif own_dim > loaded_dim:
                # The loaded head emits fewer logits than we have actions:
                # pad the logits for the actions it does not know about.
                self._rebalance_distribution_function = self._rebalance_distribution_function_builder(
                    own_dim - loaded_dim
                )
            self.mask = self._prefix_mask(min(own_dim, loaded_dim), max(own_dim, loaded_dim))

        @staticmethod
        def _prefix_mask(num_valid, total):
            """Return a 1-D float tensor of length ``total``: ones for the first ``num_valid`` entries, zeros after."""
            return (th.arange(total) < num_valid).to(th.get_default_dtype())

        @staticmethod
        def _default_rebalance_distribution_function(mean_actions, mask):
            """Push the logits of masked-out actions to a very large negative value.

            :param mean_actions: Logits produced by the action head.
            :param mask: Boolean tensor, ``True`` for allowed actions.
            :return: ``mean_actions`` with a huge penalty subtracted wherever
                ``mask`` is ``False``; allowed entries are unchanged.
            """
            # ``False * inf`` is NaN and ``True * inf`` is +inf. ``nan_to_num``
            # turns NaN into 0.0 and, by default, +inf into the largest finite
            # value of the dtype, so allowed actions get a penalty of 0 and
            # masked actions a huge finite penalty.
            penalty = (mask.bitwise_not() * math.inf).nan_to_num(0.0)
            return mean_actions - penalty

        # Default used until ``load_policy_reuse`` selects a function.
        _rebalance_distribution_function = _default_rebalance_distribution_function

        @staticmethod
        def _rebalance_distribution_function_builder(num_unknown_actions):
            """Build a function that appends ``num_unknown_actions`` logits to each row.

            The appended logits are all equal to the mean of the row's existing
            logits. The returned function ignores its mask argument.
            """

            def _func(mean_actions, _mask):
                # fill missing values with mean values
                row_mean = mean_actions.mean(dim=1, keepdim=True)
                return th.cat(
                    [mean_actions, row_mean.expand(-1, num_unknown_actions)],
                    dim=1,
                )

            return _func

        def _extract_mask(self, obs: th.Tensor) -> th.Tensor:
            """Return the boolean action mask with one row per observation in ``obs``.

            When no mask has been set yet, every action is allowed.
            """
            batch_size = obs.shape[0]
            if self.mask is None:
                return th.ones(
                    (batch_size, self.action_dist.action_dim),
                    dtype=th.bool,
                    device=obs.device,
                )
            return self.mask.repeat(batch_size, 1).to(obs.device).bool()

        def forward(
            self, obs: th.Tensor, deterministic: bool = False
        ) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
            """
            Forward pass in all the networks (actor and critic)

            :param obs: Observation
            :param deterministic: Whether to sample or use deterministic actions
            :return: action, value and log probability of the action
            """
            # Preprocess the observation if needed
            features = self.extract_features(obs)
            mask = self._extract_mask(obs)
            latent_pi, latent_vf = self.mlp_extractor(features)
            # Evaluate the values for the given observations
            values = self.value_net(latent_vf)
            distribution = self._get_action_dist_from_latent(latent_pi, mask)
            actions = distribution.get_actions(deterministic=deterministic)
            log_prob = distribution.log_prob(actions)
            return actions, values, log_prob

        def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
            """
            Evaluate actions according to the current policy, given the observations.

            :param obs: Observation
            :param actions: Actions
            :return: estimated value, log likelihood of taking those actions
                and entropy of the action distribution.
            """
            # Preprocess the observation if needed
            features = self.extract_features(obs)
            masks = self._extract_mask(obs)
            latent_pi, latent_vf = self.mlp_extractor(features)
            distribution = self._get_action_dist_from_latent(latent_pi, masks)
            log_prob = distribution.log_prob(actions)
            values = self.value_net(latent_vf)
            return values, log_prob, distribution.entropy()

        def _get_action_dist_from_latent(
            self, latent_pi: th.Tensor, mask: Union[th.Tensor, None] = None
        ) -> Distribution:
            """
            Retrieve action distribution given the latent codes.

            :param latent_pi: Latent code for the actor
            :param mask: Boolean mask of allowed actions, one row per sample;
                when ``None`` every action is allowed
            :return: Action distribution
            :raises NotImplementedError: for every supported SB3 distribution
                other than ``CategoricalDistribution``
            :raises ValueError: for an unknown distribution type
            """
            if mask is None:
                mask = th.ones((latent_pi.shape[0], self.action_dist.action_dim), device=latent_pi.device).bool()
            mean_actions = self.action_net(latent_pi)

            # Mask or pad the logits, depending on what was selected when the
            # policy was loaded (masking by default).
            mean_actions = self._rebalance_distribution_function(mean_actions, mask)

            if isinstance(self.action_dist, DiagGaussianDistribution):
                raise NotImplementedError("DiagGaussianDistribution")
            elif isinstance(self.action_dist, CategoricalDistribution):
                # Here mean_actions are the logits before the softmax
                return self.action_dist.proba_distribution(
                    action_logits=mean_actions
                )
            elif isinstance(self.action_dist, MultiCategoricalDistribution):
                # Here mean_actions are the flattened logits
                raise NotImplementedError("MultiCategoricalDistribution")
            elif isinstance(self.action_dist, BernoulliDistribution):
                # Here mean_actions are the logits (before rounding to get the binary actions)
                raise NotImplementedError("BernoulliDistribution")
            elif isinstance(self.action_dist, StateDependentNoiseDistribution):
                raise NotImplementedError("StateDependentNoiseDistribution")
            else:
                raise ValueError("Inapplicable action distribution")

        def get_distribution(self, obs: th.Tensor) -> Distribution:
            """
            Get the current policy distribution given the observations.

            :param obs: Observation
            :return: the action distribution.
            """
            mask = self._extract_mask(obs)
            features = self.extract_features(obs)
            latent_pi = self.mlp_extractor.forward_actor(features)
            return self._get_action_dist_from_latent(latent_pi, mask)

    return policy_class


