
import copy
import gym
import numpy as np
import torch as th
from torch import nn
import collections
from abc import ABC, abstractmethod
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

from stable_baselines3.common.type_aliases import Schedule
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, MlpExtractor
from stable_baselines3.common.preprocessing import get_flattened_obs_dim

from stable_baselines3.common.distributions import (
    BernoulliDistribution,
    CategoricalDistribution,
    DiagGaussianDistribution,
    Distribution,
    MultiCategoricalDistribution,
    StateDependentNoiseDistribution,
    make_proba_distribution,)

from delphicORL.networks.transformer import TransformerEnc
    
class IdentityExtractor(BaseFeaturesExtractor):
    """
    Features extractor that hands observations back untouched.

    The reported feature dimension is the flattened size of the observation
    space, but ``forward`` itself performs no flattening or encoding (in
    particular, no one-hot encoding is applied here).
    """

    def __init__(self, observation_space: gym.Space):
        """
        :param observation_space: Space whose flattened dimension is reported
            as the feature dimension.
        """
        flat_dim = get_flattened_obs_dim(observation_space)
        super().__init__(observation_space, flat_dim)

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """
        :param observations: Input tensor
        :return: The very same tensor, unchanged.
        """
        return observations


class RecurrentActorCriticPolicy(ActorCriticPolicy):
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule = None,
        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = [32, 32],
        lstm_hidden_size = 32,
        lstm_num_layers = 2,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init:bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = IdentityExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Actor-critic policy whose features go through a recurrent module
        (``self.lstm``) before reaching the policy/value MLP heads.

        :param observation_space: Observation space
        :param action_space: Action space
        :param lr_schedule: Accepted for interface compatibility but ignored:
            the parent constructor always receives a constant schedule that
            returns the largest finite float32 value.
        :param net_arch: Architecture of the MLP extractor placed after the
            recurrent module. A deep copy is forwarded to the parent so the
            shared default list can never be mutated through the policy.
        :param lstm_hidden_size: Hidden size of the recurrent module; also the
            input width of the MLP extractor.
        :param lstm_num_layers: Number of stacked recurrent layers.

        All remaining arguments are forwarded unchanged to
        ``ActorCriticPolicy``.
        """
        # These must exist before the parent constructor runs, because the
        # network-building hooks (_build / _build_mlp_extractor) read them.
        self.lstm_hidden_size = lstm_hidden_size
        self.lstm_num_layers = lstm_num_layers

        super().__init__(
            observation_space,
            action_space,
            # NOTE(review): presumably a sentinel learning rate so that an
            # accidental use of ``self.optimizer`` is immediately visible -
            # confirm with the training code.
            lr_schedule=lambda _: th.finfo(th.float32).max,
            # Copy to avoid aliasing the mutable default argument.
            net_arch=copy.deepcopy(net_arch),
            activation_fn=activation_fn,
            ortho_init=ortho_init,
            use_sde=use_sde,
            log_std_init=log_std_init,
            full_std=full_std,
            use_expln=use_expln,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            squash_output=squash_output,
        )



    def _get_constructor_parameters(self):
        """
        Extend the parent's constructor parameters with the recurrent settings.

        :return: Dict of constructor arguments, including ``lstm_hidden_size``
            and ``lstm_num_layers``.
        """
        params = super()._get_constructor_parameters()
        params["lstm_hidden_size"] = self.lstm_hidden_size
        params["lstm_num_layers"] = self.lstm_num_layers
        return params

    def _build_mlp_extractor(self) -> None:
        """
        Create the recurrent module (``self.lstm``) and the MLP extractor
        that consumes its output.

        The LSTM is batch-first and maps ``features_dim`` inputs to
        ``lstm_hidden_size`` outputs, which is the input width of the MLP
        extractor.
        """
        recurrent_core = nn.LSTM(
            self.features_dim,
            self.lstm_hidden_size,
            num_layers=self.lstm_num_layers,
            batch_first=True,
        )
        self.lstm = recurrent_core.to(self.device)

        extractor_kwargs = dict(
            net_arch=self.net_arch,
            activation_fn=self.activation_fn,
            device=self.device,
        )
        self.mlp_extractor = MlpExtractor(self.lstm_hidden_size, **extractor_kwargs)

    def _build(self, lr_schedule: Schedule) -> None:
        """
        Assemble the recurrent module, MLP heads, output layers and optimizer.

        :param lr_schedule: Learning rate schedule;
            ``lr_schedule(1)`` is used as the optimizer's learning rate
        """
        self._build_mlp_extractor()

        pi_width = self.mlp_extractor.latent_dim_pi
        dist = self.action_dist
        logits_families = (
            CategoricalDistribution,
            MultiCategoricalDistribution,
            BernoulliDistribution,
        )

        # Action head: Gaussian-style distributions also yield a log-std parameter.
        if isinstance(dist, DiagGaussianDistribution):
            head = dist.proba_distribution_net(
                latent_dim=pi_width, log_std_init=self.log_std_init
            )
            self.action_net, self.log_std = head
        elif isinstance(dist, StateDependentNoiseDistribution):
            head = dist.proba_distribution_net(
                latent_dim=pi_width, latent_sde_dim=pi_width, log_std_init=self.log_std_init
            )
            self.action_net, self.log_std = head
        elif isinstance(dist, logits_families):
            self.action_net = dist.proba_distribution_net(latent_dim=pi_width)
        else:
            raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.")

        # Scalar value head on top of the critic latent.
        self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1)

        if self.ortho_init:
            # Gains applied module by module, in this order; the features
            # extractor is not part of the list.
            # NOTE(review): whether ``init_weights`` actually touches the
            # recurrent module's parameters depends on its implementation -
            # confirm.
            gain_table = (
                (self.lstm, np.sqrt(2)),
                (self.mlp_extractor, np.sqrt(2)),
                (self.action_net, 0.01),
                (self.value_net, 1),
            )
            for net, gain in gain_table:
                net.apply(partial(self.init_weights, gain=gain))

        # Optimizer over every parameter of the policy, using the schedule's
        # value at progress 1 as learning rate.
        initial_lr = lr_schedule(1)
        self.optimizer = self.optimizer_class(
            self.parameters(), lr=initial_lr, **self.optimizer_kwargs
        )

    def forward(self, obs: th.Tensor, deterministic: bool = False,
            lstm_states = None, return_states = False):
        """
        Forward pass in all the networks (actor and critic)
        :param obs: Observation; must be 3-dimensional (its shape is unpacked
            as ``(bs, seqlen, obs_dim)``)
        :param deterministic: Whether to sample or use deterministic actions
        :param lstm_states: Optional initial state handed to ``self.lstm``
        :param return_states: If True, also return the state produced by
            ``self.lstm``
        :return: action, value and log probability of the action
            (plus the recurrent state when ``return_states`` is True).
            Actions are reshaped to ``(bs, seqlen, -1)``.
        """
        # Preprocess the observation if needed
        features = self.extract_features(obs)
        bs, seqlen, _ = obs.shape
        features, lstm_states = self.lstm(features, lstm_states)
        latent_pi, latent_vf = self.mlp_extractor(features)
        # Evaluate the values for the given observations
        values = self.value_net(latent_vf)
        distribution = self._get_action_dist_from_latent(latent_pi)
        actions = distribution.get_actions(deterministic=deterministic)
        # Score the actions in the shape the distribution produced them.
        # Reshaping first adds a trailing axis for categorical actions, which
        # then broadcasts incorrectly (or fails) inside log_prob.
        log_prob = distribution.log_prob(actions)
        actions = actions.reshape(bs, seqlen, -1)

        if return_states:
            return actions, values, log_prob, lstm_states
        return actions, values, log_prob

    def _get_action_dist_from_latent(self, latent_pi):
        """
        Build the action distribution from the actor's latent code.

        :param latent_pi: Latent code for the actor
        :return: Action distribution
        :raises ValueError: if ``self.action_dist`` is of an unhandled type
        """
        dist = self.action_dist
        head_out = self.action_net(latent_pi)

        if isinstance(dist, DiagGaussianDistribution):
            # head_out are the mean actions
            return dist.proba_distribution(head_out, self.log_std)

        if isinstance(dist, StateDependentNoiseDistribution):
            # head_out are the mean actions; the latent is passed along as well
            return dist.proba_distribution(head_out, self.log_std, latent_pi)

        logits_families = (
            CategoricalDistribution,
            MultiCategoricalDistribution,
            BernoulliDistribution,
        )
        if isinstance(dist, logits_families):
            # head_out are logits (flattened for the multi-categorical case)
            return dist.proba_distribution(action_logits=head_out)

        raise ValueError("Invalid action distribution")

    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        lstm_states: Optional[Tuple[np.ndarray, ...]] = None,
        deterministic: bool = False,
        return_states = True
    ):
        """
        Get the policy action (and optionally the next recurrent state) for
        the given observation(s).

        :param observation: Observation(s); converted to a float32 tensor on
            the policy's device
        :param lstm_states: Recurrent state passed through unchanged to
            ``self.lstm``
            NOTE(review): the annotation says numpy arrays, but nothing here
            converts them to tensors - confirm what callers pass.
        :param deterministic: Whether to use stochastic or deterministic actions
        :param return_states: If True (default) return ``(actions, lstm_states)``,
            otherwise only ``actions``
        :return: Actions as a numpy array, clipped or unscaled for Box action
            spaces, plus the new recurrent state when requested
        """
        self.set_training_mode(False)

        # as_tensor accepts arrays as well as tensors already on the target device
        observation = th.as_tensor(observation, dtype=th.float32, device=self.device)

        with th.no_grad():
            actions, lstm_states = self._predict(observation, lstm_states=lstm_states,
                        deterministic=deterministic, return_states=True)

        # Convert to numpy, and reshape to the original action shape
        actions = actions.cpu().numpy()
        if observation.dim() == 3:
            actions = actions.reshape(tuple(observation.shape[:2]) + self.action_space.shape)
        elif isinstance(self.action_space, gym.spaces.Discrete):
            # NOTE(review): for non-3D observations this reshapes to the first
            # two observation dims, which only succeeds when the element
            # counts happen to match - confirm this is intended.
            actions = actions.reshape(tuple(observation.shape[:2]))
        else:
            actions = actions.reshape((-1,) + self.action_space.shape)

        if isinstance(self.action_space, gym.spaces.Box):
            if self.squash_output:
                # Rescale to proper domain when using squashing
                actions = self.unscale_action(actions)
            else:
                # Actions could be on arbitrary scale, so clip the actions to avoid
                # out of bound error (e.g. if sampling from a Gaussian distribution)
                actions = np.clip(actions, self.action_space.low, self.action_space.high)

        if return_states:
            return actions, lstm_states
        return actions

    def _predict(self, observation: th.Tensor, deterministic: bool = False, lstm_states = None,
                 return_states=False):
        """
        Draw (or pick) actions from the policy distribution for an observation.

        :param observation:
        :param deterministic: Whether to use stochastic or deterministic actions
        :param lstm_states: Recurrent state forwarded to ``get_distribution``
        :param return_states: If True, also return the recurrent state
        :return: Taken action according to the policy
            (and the recurrent state when ``return_states`` is True)
        """
        result = self.get_distribution(observation, lstm_states, return_states)
        if not return_states:
            return result.get_actions(deterministic=deterministic)

        # With return_states set, get_distribution yields a (dist, states) pair.
        distribution, new_states = result
        return distribution.get_actions(deterministic=deterministic), new_states

    def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor, masks=None,
                    lstm_states = None,
                    return_states=False):
        """
        Score the given actions under the current policy.

        :param obs:
        :param actions:
        :param masks: Optional index applied to the actor latent, the actions
            and the values before they are scored/returned
        :param lstm_states: Optional initial state handed to ``self.lstm``
        :param return_states: If True, also return the state produced by
            ``self.lstm``
        :return: estimated value, log likelihood of taking those actions
            and entropy of the action distribution
            (plus the recurrent state when ``return_states`` is True).
        """
        extracted = self.extract_features(obs)
        lstm_out, lstm_states = self.lstm(extracted, lstm_states)
        latent_pi, latent_vf = self.mlp_extractor(lstm_out)
        values = self.value_net(latent_vf)

        if masks is not None:
            # Keep only the selected positions everywhere.
            latent_pi = latent_pi[masks]
            actions = actions[masks]
            values = values[masks]

        distribution = self._get_action_dist_from_latent(latent_pi)
        log_prob = distribution.log_prob(actions)
        entropy = distribution.entropy()

        outputs = (values, log_prob, entropy)
        if return_states:
            return outputs + (lstm_states,)
        return outputs

    def get_distribution(self, obs: th.Tensor, lstm_states = None, return_states=False):
        """
        Build the current policy distribution for the given observations.

        The actor latent is flattened to two dimensions (everything but the
        last axis is merged) before the distribution is created.

        :param obs:
        :param lstm_states: Optional initial state handed to ``self.lstm``
        :param return_states: If True, also return the state produced by
            ``self.lstm``
        :return: the action distribution
            (and the recurrent state when ``return_states`` is True).
        """
        extracted = self.extract_features(obs)
        lstm_out, lstm_states = self.lstm(extracted, lstm_states)
        latent_pi = self.mlp_extractor.forward_actor(lstm_out)
        flat_latent = latent_pi.reshape(-1, latent_pi.shape[-1])
        distribution = self._get_action_dist_from_latent(flat_latent)

        if return_states:
            return distribution, lstm_states
        return distribution

    def predict_values(self, obs: th.Tensor, lstm_states = None, return_states=False):
        """
        Get the estimated values according to the current policy given the observations.

        The features are run through ``self.lstm`` first, exactly as in
        ``forward`` / ``evaluate_actions``, so that the critic head receives
        inputs of width ``lstm_hidden_size``.

        NOTE(review): this assumes ``self.lstm(features, state)`` returns an
        ``(output, state)`` pair; subclasses whose sequence module is called
        differently need their own override.

        :param obs:
        :param lstm_states: Optional initial state handed to ``self.lstm``
        :param return_states: If True, also return the state produced by
            ``self.lstm``
        :return: the estimated values
            (and the recurrent state when ``return_states`` is True).
        """
        # extract_features only takes the observation; the recurrent state
        # belongs to the LSTM call, not to the features extractor.
        features = self.extract_features(obs)
        features, lstm_states = self.lstm(features, lstm_states)
        latent_vf = self.mlp_extractor.forward_critic(features)
        values = self.value_net(latent_vf)
        if return_states:
            return values, lstm_states
        return values


class GRUActorCriticPolicy(RecurrentActorCriticPolicy):
    """
    Variant of ``RecurrentActorCriticPolicy`` whose recurrent module is a GRU.

    The module is still stored as ``self.lstm`` and configured through the
    ``lstm_*`` arguments so that the inherited methods keep working.
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule = None,
        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = [32, 32],
        lstm_hidden_size = 256,
        lstm_num_layers = 2,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init:bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = IdentityExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        :param observation_space: Observation space
        :param action_space: Action space
        :param lr_schedule: Accepted for interface compatibility but ignored:
            a constant schedule returning the largest finite float32 value is
            always forwarded instead.
        :param net_arch: Architecture of the MLP extractor placed after the
            GRU. A deep copy is forwarded so the shared default list can
            never be mutated through the policy.
        :param lstm_hidden_size: Hidden size of the GRU
        :param lstm_num_layers: Number of stacked GRU layers

        All remaining arguments are forwarded unchanged to the parent class.
        """
        super().__init__(
            observation_space,
            action_space,
            lr_schedule=lambda _: th.finfo(th.float32).max,
            # Copy to avoid aliasing the mutable default argument.
            net_arch=copy.deepcopy(net_arch),
            lstm_hidden_size=lstm_hidden_size,
            lstm_num_layers=lstm_num_layers,
            activation_fn=activation_fn,
            ortho_init=ortho_init,
            use_sde=use_sde,
            log_std_init=log_std_init,
            full_std=full_std,
            use_expln=use_expln,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            squash_output=squash_output,
        )

    def _build_mlp_extractor(self) -> None:
        """
        Create the batch-first GRU (stored as ``self.lstm``) and the MLP
        extractor that consumes its ``lstm_hidden_size``-wide output.
        """
        self.lstm = nn.GRU(self.features_dim, self.lstm_hidden_size,
                             num_layers=self.lstm_num_layers,
                             batch_first=True).to(self.device)

        self.mlp_extractor = MlpExtractor(
            self.lstm_hidden_size,
            net_arch=self.net_arch,
            activation_fn=self.activation_fn,
            device=self.device,
        )



class TransformerActorCriticPolicy(RecurrentActorCriticPolicy):
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule = None,
        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = [32, 32],
        transformer_emb_size = 256,
        transformer_num_layers = 2,
        max_history_len = 10,
        activation_fn: Type[nn.Module] = nn.Tanh,
        ortho_init:bool = True,
        use_sde: bool = False,
        log_std_init: float = 0.0,
        full_std: bool = True,
        use_expln: bool = False,
        squash_output: bool = False,
        features_extractor_class: Type[BaseFeaturesExtractor] = IdentityExtractor,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Actor-critic policy whose sequence module is a ``TransformerEnc``
        (stored as ``self.lstm``).

        :param observation_space: Observation space
        :param action_space: Action space
        :param lr_schedule: Accepted for interface compatibility but ignored:
            a constant schedule returning the largest finite float32 value is
            always forwarded instead.
        :param net_arch: Architecture of the MLP extractor placed after the
            encoder. A deep copy is forwarded so the shared default list can
            never be mutated through the policy.
        :param transformer_emb_size: Passed as ``d_model`` to the encoder;
            also the input width of the MLP extractor.
        :param transformer_num_layers: Passed as ``nlayers`` to the encoder.
        :param max_history_len: Kept on the instance for inspection only; it
            is not used when building the networks.

        All remaining arguments are forwarded unchanged to the parent class.
        """
        # These must exist before the parent constructor runs, because the
        # network-building hook (_build_mlp_extractor) reads them.
        self.transformer_emb_size = transformer_emb_size
        self.transformer_num_layers = transformer_num_layers
        # Previously dropped silently; keep it so the setting stays visible.
        self.max_history_len = max_history_len

        super().__init__(
            observation_space,
            action_space,
            lr_schedule=lambda _: th.finfo(th.float32).max,
            # Copy to avoid aliasing the mutable default argument.
            net_arch=copy.deepcopy(net_arch),
            activation_fn=activation_fn,
            ortho_init=ortho_init,
            use_sde=use_sde,
            log_std_init=log_std_init,
            full_std=full_std,
            use_expln=use_expln,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=normalize_images,
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
            squash_output=squash_output,
        )

    def _get_constructor_parameters(self):
        """
        Swap the inherited recurrent settings for the transformer ones.

        :return: Dict of constructor arguments without the ``lstm_*`` keys and
            with ``transformer_emb_size`` / ``transformer_num_layers`` added.
        """
        params = super()._get_constructor_parameters()
        # This class' constructor does not accept the lstm_* arguments.
        for stale_key in ("lstm_hidden_size", "lstm_num_layers"):
            del params[stale_key]
        params["transformer_emb_size"] = self.transformer_emb_size
        params["transformer_num_layers"] = self.transformer_num_layers
        return params

    def _build_mlp_extractor(self) -> None:
        """
        Create the transformer encoder (stored as ``self.lstm``) and the MLP
        extractor fed with ``transformer_emb_size``-wide inputs.
        """
        encoder = TransformerEnc(
            self.features_dim,
            d_model=self.transformer_emb_size,
            nlayers=self.transformer_num_layers,
        )
        self.lstm = encoder.to(self.device)

        extractor_kwargs = dict(
            net_arch=self.net_arch,
            activation_fn=self.activation_fn,
            device=self.device,
        )
        self.mlp_extractor = MlpExtractor(self.transformer_emb_size, **extractor_kwargs)


    def forward(self, obs: th.Tensor, deterministic: bool = False,
            mask = None):
        """
        Run actor and critic on the encoded observations.

        :param obs: Observation
        :param deterministic: Whether to sample or use deterministic actions
        :param mask: Optional mask, given to the encoder as ``pad_mask`` and
            then used to index the encoder output
        :return: action, value and log probability of the action
        """
        encoded = self.lstm(self.extract_features(obs), pad_mask=mask)
        selected = encoded if mask is None else encoded[mask]

        latent_pi, latent_vf = self.mlp_extractor(selected)
        values = self.value_net(latent_vf)

        distribution = self._get_action_dist_from_latent(latent_pi)
        sampled = distribution.get_actions(deterministic=deterministic)
        # Score first, reshape afterwards.
        log_prob = distribution.log_prob(sampled)

        # Leading dims of the (possibly masked) features + per-action shape.
        out_shape = selected.shape[:-1] + self.action_space.shape
        return sampled.reshape(out_shape), values, log_prob


    def predict(
        self,
        observation: Union[np.ndarray, Dict[str, np.ndarray]],
        masks=None,
        deterministic: bool = False    
        ):
        """
        Get the policy action for the given observation(s).

        :param observation: Observation(s); converted to a float32 tensor on
            the policy's device
        :param masks: Optional mask forwarded to ``_predict``
        :param deterministic: Whether to use stochastic or deterministic actions
        :return: ``(actions, None)``; the second element is always ``None``.
            Actions are a numpy array, clipped or unscaled for Box action
            spaces. For 3-dimensional observations they are shaped
            ``observation.shape[:2] + action_space.shape`` whenever the number
            of produced actions allows it; otherwise (e.g. when ``masks``
            removed some positions) they are shaped
            ``(-1,) + action_space.shape``.
        """
        self.set_training_mode(False)

        # as_tensor accepts arrays as well as tensors already on the target device
        observation = th.as_tensor(observation, dtype=th.float32, device=self.device)

        with th.no_grad():
            actions = self._predict(observation, masks=masks,
                        deterministic=deterministic)

        # Convert to numpy, and reshape to the original action shape
        actions = actions.cpu().numpy()
        action_shape = tuple(self.action_space.shape)
        batch_shape = tuple(observation.shape[:2]) + action_shape
        # A mask can drop positions, in which case the (batch, seq) layout no
        # longer fits and a flat list of actions is returned instead of
        # raising on the reshape.
        if observation.dim() == 3 and actions.size == int(np.prod(batch_shape)):
            actions = actions.reshape(batch_shape)
        else:
            actions = actions.reshape((-1,) + action_shape)

        if isinstance(self.action_space, gym.spaces.Box):
            if self.squash_output:
                # Rescale to proper domain when using squashing
                actions = self.unscale_action(actions)
            else:
                # Actions could be on arbitrary scale, so clip the actions to avoid
                # out of bound error (e.g. if sampling from a Gaussian distribution)
                actions = np.clip(actions, self.action_space.low, self.action_space.high)

        return actions, None

    def _predict(self, observation: th.Tensor, deterministic: bool = False, 
                 masks=None):
        """
        Draw (or pick) actions from the policy distribution for an observation.

        :param observation:
        :param deterministic: Whether to use stochastic or deterministic actions
        :param masks: Optional mask forwarded to ``get_distribution``
        :return: Taken action according to the policy
        """
        distribution = self.get_distribution(observation, masks=masks)
        chosen = distribution.get_actions(deterministic=deterministic)
        return chosen

    def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor, masks=None):
        """
        Score the given actions under the current policy.

        :param obs:
        :param actions:
        :param masks: Optional mask, given to the encoder as ``pad_mask`` and
            then used to index both the encoder output and the actions
        :return: estimated value, log likelihood of taking those actions
            and entropy of the action distribution.
        """
        encoded = self.lstm(self.extract_features(obs), pad_mask=masks)

        if masks is not None:
            # Keep only the selected positions.
            actions, encoded = actions[masks], encoded[masks]

        latent_pi, latent_vf = self.mlp_extractor(encoded)
        distribution = self._get_action_dist_from_latent(latent_pi)

        log_prob = distribution.log_prob(actions)
        values = self.value_net(latent_vf)
        entropy = distribution.entropy()
        return values, log_prob, entropy

    def get_distribution(self, obs: th.Tensor, masks = None):
        """
        Build the current policy distribution for the given observations.

        :param obs:
        :param masks: Optional mask, given to the encoder as ``pad_mask`` and
            then used to index the encoder output
        :return: the action distribution.
        """
        encoded = self.lstm(self.extract_features(obs), pad_mask=masks)
        selected = encoded if masks is None else encoded[masks]
        actor_latent = self.mlp_extractor.forward_actor(selected)
        return self._get_action_dist_from_latent(actor_latent)