import collections
from typing import Any, Dict, Optional, Tuple, Type, Union

import gym
import torch as th
from stable_baselines3.common.distributions import Distribution
from stable_baselines3.common.policies import BasePolicy


class PPOCustomPolicy(BasePolicy):
    """PPO policy built from an externally constructed feature extractor, actor and critic.

    The sub-networks are passed in ready-made instead of being built from
    ``*_class`` arguments. There are two critic modes:

    * ``v_critic=True``: ``critic(embedding)`` is used directly as the value.
    * ``v_critic=False``: the critic is called with the actor, and the first
      element of its output is summed over dim 1, weighted by the policy's action
      probabilities (presumably per-action Q-values averaged into V(s) -- confirm
      against the critic implementation).

    When ``use_wm_optimizer`` is true, parameters are split between two optimizers:
    ``self.optimizer`` gets the actor parameters plus the critic parameters whose
    names start with ``value_models``; ``self.optimizer_wm`` gets all remaining
    critic parameters.
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        features_extractor,
        actor,
        critic,
        lr,
        log_std_init: float = 0.0,
        squash_output: bool = False,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        use_wm_optimizer=True,
        v_critic: bool = False,
    ):
        """Store the sub-networks and create the optimizer(s).

        :param observation_space: Observation space, forwarded to ``BasePolicy``.
        :param action_space: Action space, forwarded to ``BasePolicy``.
        :param features_extractor: Module mapping (scaled) observations to embeddings.
        :param actor: Module providing ``proba_distribution``, ``set_training_mode``
            and a ``(embedding, deterministic)`` call.
        :param critic: Value / world-model module (see class docstring for modes).
        :param lr: Learning rate passed to every optimizer created here.
        :param log_std_init: Stored only so it can be saved via
            ``_get_constructor_parameters``; not otherwise used in this class.
        :param squash_output: Forwarded to ``BasePolicy``.
        :param optimizer_class: Optimizer type used for all optimizers.
        :param optimizer_kwargs: Extra optimizer kwargs. When ``None`` and the
            optimizer is Adam, ``eps=1e-5`` is used.
        :param use_wm_optimizer: Split parameters into policy/value and
            world-model optimizers. Defaults to ``True``, so callers enabling
            ``v_critic`` must explicitly pass ``False`` here.
        :param v_critic: Use a state-value critic instead of probability-weighted
            critic outputs.
        :raises ValueError: If both ``v_critic`` and ``use_wm_optimizer`` are set.
        """
        # Validate explicitly: an ``assert`` would be stripped under ``python -O``.
        if v_critic and use_wm_optimizer:
            raise ValueError('Cannot use V-critic and world model together!')

        if optimizer_kwargs is None:
            optimizer_kwargs = {}
            # Small values to avoid NaN in Adam optimizer
            if optimizer_class == th.optim.Adam:
                optimizer_kwargs["eps"] = 1e-5

        super().__init__(
            observation_space,
            action_space,
            features_extractor_class=None,
            features_extractor_kwargs=None,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            squash_output=squash_output,
        )

        self.features_extractor = features_extractor
        self.actor = actor
        self.critic = critic
        self.log_std_init = log_std_init
        self.use_wm_optimizer = use_wm_optimizer
        self.v_critic = v_critic
        self.optimizer = self.optimizer_class(self._get_parameters(), lr=lr, **self.optimizer_kwargs)
        if self.use_wm_optimizer:
            self.optimizer_wm = self.optimizer_class(self._get_parameters_wm(), lr=lr, **self.optimizer_kwargs)

    def _get_parameters(self):
        """Return the parameters trained by ``self.optimizer``.

        Without the world-model optimizer this is every parameter of the policy.
        With it, only the actor parameters plus the critic parameters whose names
        start with ``value_models`` are returned.

        NOTE(review): in the ``use_wm_optimizer`` case the features extractor's
        parameters are in neither optimizer -- presumably intentional (frozen
        extractor); confirm.
        """
        if not self.use_wm_optimizer:
            return self.parameters()

        actor_value_parameters = list(self.actor.parameters())
        actor_value_parameters.extend(
            parameter
            for name, parameter in self.critic.named_parameters()
            if name.startswith('value_models')
        )
        return actor_value_parameters

    def _get_parameters_wm(self):
        """Return the critic parameters NOT under ``value_models`` (world-model part).

        Returns an empty list when the world-model optimizer is disabled.
        """
        if not self.use_wm_optimizer:
            return []

        return [
            parameter
            for name, parameter in self.critic.named_parameters()
            if not name.startswith('value_models')
        ]

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Return the data SB3 stores alongside the policy when saving it.

        Records the actual ``squash_output`` flag (previously always ``None``)
        and the ``use_wm_optimizer`` / ``v_critic`` flags, since they change the
        policy's behavior.
        """
        data = super()._get_constructor_parameters()

        data.update(
            dict(
                log_std_init=self.log_std_init,
                squash_output=self.squash_output,
                # NOTE(review): ``__init__`` takes ``lr``, not ``lr_schedule``, and
                # also requires the sub-networks, so rebuilding the policy purely
                # from this dict is not supported as-is.
                lr_schedule=self._dummy_schedule,  # dummy lr schedule, not needed for loading policy alone
                optimizer_class=self.optimizer_class,
                optimizer_kwargs=self.optimizer_kwargs,
                use_wm_optimizer=self.use_wm_optimizer,
                v_critic=self.v_critic,
            )
        )
        return data

    def _compute_values(self, embedding: th.Tensor, distribution: Distribution) -> th.Tensor:
        """Compute state values for ``embedding`` under the given action distribution.

        In ``v_critic`` mode the critic output is returned unchanged. Otherwise the
        first element of ``critic(embedding, actor)`` is weighted by the
        distribution's ``probs`` and summed over dim 1. This assumes a categorical
        (discrete-action) distribution exposing ``.distribution.probs`` -- confirm.
        """
        if self.v_critic:
            return self.critic(embedding)
        prob = distribution.distribution.probs
        return th.sum(prob * self.critic(embedding, self.actor)[0], dim=1)

    def forward(self, embedding: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]:
        """Sample actions and compute their log-probabilities and the state values.

        :param embedding: Features produced by ``extract_features``.
        :param deterministic: Passed to ``distribution.get_actions``.
        :return: ``(actions, values, log_prob)``.
        """
        distribution = self.actor.proba_distribution(embedding)
        actions = distribution.get_actions(deterministic=deterministic)
        log_prob = distribution.log_prob(actions)
        values = self._compute_values(embedding, distribution)
        return actions, values, log_prob

    def extract_features(self, obs: th.Tensor, prev_slots=None) -> th.Tensor:
        """Move ``obs`` to the policy device, divide by 255 and run the features extractor.

        The division assumes pixel observations in ``[0, 255]`` -- confirm.
        ``prev_slots`` is forwarded to the extractor unchanged.
        """
        return self.features_extractor(obs.to(self.device) / 255., prev_slots=prev_slots)

    def _predict(self, embedding, deterministic: bool = False) -> th.Tensor:
        """Return the actor's action for ``embedding`` (used by ``BasePolicy.predict``)."""
        return self.actor(embedding, deterministic)

    def evaluate_actions(self, embedding: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor]:
        """Evaluate given actions for the PPO loss.

        :return: ``(values, log_prob, entropy, averaged_values)``. In ``v_critic``
            mode ``values`` and ``averaged_values`` are the same tensor. Otherwise
            ``values`` is the first element of the critic output called with
            ``actions`` (presumably per-action values), and ``averaged_values`` is
            that tensor weighted by the action probabilities and summed over dim 1.
        """
        if self.v_critic:
            values, distribution = self.predict_values(embedding, return_distribution=True)
            averaged_values = values
        else:
            # Critic is queried before the distribution is built, as in the original.
            values = self.critic(embedding, actor=self.actor, actions=actions)[0]
            distribution = self.actor.proba_distribution(embedding)
            prob = distribution.distribution.probs
            averaged_values = th.sum(prob * values, dim=1)
        log_prob = distribution.log_prob(actions)
        return values, log_prob, distribution.entropy(), averaged_values

    def get_distribution(self, embedding: th.Tensor) -> Distribution:
        """Return the actor's action distribution for ``embedding``."""
        return self.actor.proba_distribution(embedding)

    def predict_values(
        self, embedding: th.Tensor, return_distribution=False
    ) -> Union[Tuple[th.Tensor], Tuple[th.Tensor, Distribution]]:
        """Compute state values for ``embedding``.

        :param embedding: Features produced by ``extract_features``.
        :param return_distribution: Also return the action distribution that was built.
        :return: ``(values,)``, or ``(values, distribution)`` when
            ``return_distribution`` is true. A tuple is returned in both cases, so
            callers index ``[0]`` for the values.
        """
        distribution = self.actor.proba_distribution(embedding)
        values = self._compute_values(embedding, distribution)
        if return_distribution:
            return values, distribution

        return (values,)

    def set_training_mode(self, mode: bool) -> None:
        """Set train/eval mode on the actor and critic and record it on ``self.training``.

        NOTE(review): unlike ``nn.Module.train``, this does not touch the features
        extractor or other submodules -- presumably intentional; confirm.
        """
        self.actor.set_training_mode(mode)
        self.critic.set_training_mode(mode)
        self.training = mode
