from typing import Any, Dict, List, Optional, Tuple, Type, Union
from gymnasium import spaces
from stable_baselines3.common.torch_layers import (
    BaseFeaturesExtractor,
    FlattenExtractor,
)
from stable_baselines3.common.type_aliases import Schedule
import torch as th
import torch.nn as nn
import numpy as np

from stable_baselines3.common.policies import ActorCriticPolicy
from torch import nn

from config import Config
from popart import PopArtLayer
from impala_cnn import ImpalaCNN


class EAPOActorCritic(ActorCriticPolicy):
    """Actor-critic policy with two ``PopArtLayer`` heads on the critic latent.

    On top of the standard ``ActorCriticPolicy`` this class:

    * replaces ``value_net`` with a single-output ``PopArtLayer`` and adds a
      second single-output ``PopArtLayer`` called ``entropy_net``; both read
      the critic latent produced by ``mlp_extractor``;
    * returns the ``entropy_net`` output as an extra element from
      :meth:`forward`, :meth:`evaluate_actions` and :meth:`predict_values`;
    * offers ``maybe_*`` helpers that forward return targets / predictions to
      the ``PopArtLayer`` methods ``update_and_normalise``,
      ``update_stats_and_params`` and ``denormalise``.
    """

    # Switches read by the ``maybe_*`` PopArt helpers. They are class-level
    # defaults so the helpers no longer raise ``AttributeError`` when nothing
    # assigned them; instances (or subclasses) can still override them.
    # The defaults mirror what ``_build`` constructs: two separate
    # single-output PopArt heads.
    # NOTE(review): confirm these defaults against the training code.
    pop_art: bool = True
    use_dual_head_net: bool = False

    def __init__(
        self,
        config: Config,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Schedule,
        net_arch: List[int] | Dict[str, List[int]] | None = None,
        activation_fn: type[nn.Module] = nn.Tanh,
        ortho_init: bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Dict[str, Any] | None = None,
        share_features_extractor: bool = True,
        normalize_images: bool = True,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Dict[str, Any] | None = None,
    ):
        """Store ``config`` and delegate everything else to the base policy.

        Args:
            config: Project configuration; ``config.pop_art_beta`` is read in
                :meth:`_build`.

        All remaining arguments are passed through unchanged to
        ``ActorCriticPolicy.__init__``.
        """
        # Must be assigned before the base constructor runs: the base
        # constructor triggers `_build`, which reads `self.config`.
        self.config = config

        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            lr_schedule=lr_schedule,
            net_arch=net_arch,
            activation_fn=activation_fn,
            ortho_init=ortho_init,
            use_sde=use_sde,
            log_std_init=log_std_init,
            full_std=full_std,
            use_expln=use_expln,
            squash_output=squash_output,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            share_features_extractor=share_features_extractor,
            normalize_images=normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
        )

    def _build(self, lr_schedule: Schedule) -> None:
        """Build the base networks, then install the two PopArt heads.

        Args:
            lr_schedule: Learning-rate schedule; ``lr_schedule(1)`` is used as
                the initial learning rate of the optimizer.
        """
        super()._build(lr_schedule)

        critic_dim = self.mlp_extractor.latent_dim_vf
        beta = self.config.pop_art_beta
        self.value_net = PopArtLayer(critic_dim, 1, beta)
        self.entropy_net = PopArtLayer(critic_dim, 1, beta)

        # The base `_build` already created the optimizer, i.e. before the
        # heads above existed. Re-create it so that it tracks the parameters
        # of the new heads instead of those of the discarded value head.
        self.optimizer = self.optimizer_class(
            self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs
        )

    def _actor_critic_latents(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """Return ``(latent_pi, latent_vf)`` for ``obs``.

        Handles both the shared and the separate features-extractor layouts.
        """
        # Preprocess the observation if needed
        features = self.extract_features(obs)
        if self.share_features_extractor:
            latent_pi, latent_vf = self.mlp_extractor(features)
        else:
            # With separate extractors `extract_features` yields one feature
            # tensor per branch.
            pi_features, vf_features = features
            latent_pi = self.mlp_extractor.forward_actor(pi_features)
            latent_vf = self.mlp_extractor.forward_critic(vf_features)
        return latent_pi, latent_vf

    def forward(
        self, obs: th.Tensor, deterministic: bool = False
    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor]:
        """Run actor and critic on ``obs``.

        Args:
            obs: Observation batch.
            deterministic: Forwarded to the action distribution's
                ``get_actions``.

        Returns:
            ``(actions, values, log_prob, entropy_preds)`` where ``actions``
            is reshaped to ``(-1, *action_space.shape)`` and ``values`` /
            ``entropy_preds`` are the raw outputs of ``value_net`` /
            ``entropy_net``.
        """
        latent_pi, latent_vf = self._actor_critic_latents(obs)
        # Evaluate both critic heads for the given observations
        values = self.value_net(latent_vf)
        entropy_preds = self.entropy_net(latent_vf)
        distribution = self._get_action_dist_from_latent(latent_pi)
        actions = distribution.get_actions(deterministic=deterministic)
        log_prob = distribution.log_prob(actions)
        actions = actions.reshape((-1, *self.action_space.shape))
        return actions, values, log_prob, entropy_preds

    def evaluate_actions(
        self, obs: th.Tensor, actions: th.Tensor
    ) -> Tuple[th.Tensor, th.Tensor, th.Tensor | None, th.Tensor]:
        """Evaluate given ``actions`` under the current policy.

        Args:
            obs: Observation batch.
            actions: Actions whose log-probability is evaluated.

        Returns:
            ``(values, log_prob, entropy, entropy_preds)``; ``entropy`` is
            whatever the action distribution's ``entropy()`` returns (the
            annotation allows ``None``).
        """
        latent_pi, latent_vf = self._actor_critic_latents(obs)
        distribution = self._get_action_dist_from_latent(latent_pi)
        log_prob = distribution.log_prob(actions)
        values = self.value_net(latent_vf)
        entropy_preds = self.entropy_net(latent_vf)
        entropy = distribution.entropy()
        return values, log_prob, entropy, entropy_preds

    def predict_values(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """Return ``(value_net, entropy_net)`` outputs for ``obs``.

        Only the critic path is evaluated (``vf_features_extractor`` followed
        by ``mlp_extractor.forward_critic``). Unlike the base class, this
        returns a pair rather than a single tensor.
        """
        features = super().extract_features(obs, self.vf_features_extractor)
        latent_vf = self.mlp_extractor.forward_critic(features)
        return self.value_net(latent_vf), self.entropy_net(latent_vf)

    def maybe_update_and_normalise_popart(
        self,
        returns: np.ndarray,
        entropy_returns: np.ndarray,
        advantages: np.ndarray,
        entropy_advantages: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Pass return targets through ``update_and_normalise`` when enabled.

        Args:
            returns: Return targets for ``value_net``.
            entropy_returns: Return targets for ``entropy_net``.
            advantages: Returned unchanged.
            entropy_advantages: Returned unchanged.

        Returns:
            ``(returns, entropy_returns, advantages, entropy_advantages)``
            where the first two entries are replaced by the output of
            ``PopArtLayer.update_and_normalise`` if ``self.pop_art`` is truthy
            and the corresponding head is a ``PopArtLayer``; otherwise all
            inputs are returned as given.
        """
        if not self.pop_art:
            return returns, entropy_returns, advantages, entropy_advantages

        if self.use_dual_head_net:
            # One PopArt head handles both streams: stack them on a new last
            # axis, process them together, then split again.
            assert isinstance(self.value_net, PopArtLayer)
            stacked = np.stack((returns, entropy_returns), axis=-1)
            stacked = self.value_net.update_and_normalise(stacked)
            normalised_returns = stacked[..., 0]
            normalised_entropy_returns = stacked[..., 1]
        else:
            normalised_returns = returns
            if isinstance(self.value_net, PopArtLayer):
                normalised_returns = self.value_net.update_and_normalise(returns)

            normalised_entropy_returns = entropy_returns
            if isinstance(self.entropy_net, PopArtLayer):
                normalised_entropy_returns = self.entropy_net.update_and_normalise(
                    entropy_returns
                )

        return (
            normalised_returns,
            normalised_entropy_returns,
            advantages,
            entropy_advantages,
        )

    def maybe_update_popart(
        self, returns: np.ndarray, entropy_returns: np.ndarray
    ) -> None:
        """Call ``update_stats_and_params`` on whichever heads are PopArt layers.

        In dual-head mode ``value_net`` receives both streams stacked on a new
        last axis and ``entropy_net`` is left alone, matching
        :meth:`maybe_update_and_normalise_popart`.

        Args:
            returns: Return targets for ``value_net``.
            entropy_returns: Return targets for ``entropy_net``.
        """
        handled_jointly = False
        if isinstance(self.value_net, PopArtLayer):
            if self.use_dual_head_net:
                returns = np.stack((returns, entropy_returns), axis=-1)
                handled_jointly = True
            self.value_net.update_stats_and_params(returns)
        if not handled_jointly and isinstance(self.entropy_net, PopArtLayer):
            self.entropy_net.update_stats_and_params(entropy_returns)

    def maybe_popart_denormalise(
        self,
        predicted_returns: th.Tensor,
        predicted_entropy_returns: Optional[th.Tensor],
    ) -> Tuple[th.Tensor, Optional[th.Tensor]]:
        """Pass predictions through ``PopArtLayer.denormalise`` where applicable.

        Args:
            predicted_returns: Output of ``value_net``.
            predicted_entropy_returns: Output of ``entropy_net``; may be
                ``None`` unless dual-head mode is active.

        Returns:
            ``(predicted_returns, predicted_entropy_returns)``; each entry is
            denormalised at most once, and left as given when its head is not
            a ``PopArtLayer``.
        """
        handled_jointly = False
        if isinstance(self.value_net, PopArtLayer):
            if self.use_dual_head_net:
                assert predicted_entropy_returns is not None
                stacked = th.stack(
                    (predicted_returns, predicted_entropy_returns), dim=-1
                )
                stacked = self.value_net.denormalise(stacked)
                # Split on the axis that was stacked (the last one), so this
                # also works when the inputs have more than one dimension.
                predicted_returns = stacked[..., 0]
                predicted_entropy_returns = stacked[..., 1]
                handled_jointly = True
            else:
                predicted_returns = self.value_net.denormalise(predicted_returns)

        # Skip the separate entropy head if the joint head already
        # denormalised the entropy stream; doing both would apply the
        # transformation twice.
        if (
            not handled_jointly
            and predicted_entropy_returns is not None
            and isinstance(self.entropy_net, PopArtLayer)
        ):
            predicted_entropy_returns = self.entropy_net.denormalise(
                predicted_entropy_returns
            )
        return predicted_returns, predicted_entropy_returns
