import typing
from typing import Dict, Tuple, Any

import gym
import torch
import torch.nn as nn
from gym.spaces.dict import Dict as SpaceDict

from models.basic_models import RNNActorCritic, LinearActorCritic
from rl_base.common import ActorCriticOutput, DistributionType


class RNNActorCriticWithEmbed(RNNActorCritic):
    """Recurrent actor-critic that embeds its input before the parent model.

    The observation stored under ``input_key`` is looked up in an
    ``nn.Embedding`` table, flattened to two dimensions and then handed to
    ``RNNActorCritic.forward`` in place of the raw observation.

    The parent is configured with ``hidden_size = embedding_dim * input_len``
    and with an unbounded ``Box`` observation space of that same flat size.
    """

    def __init__(
        self,
        input_key: str,
        num_embeddings: int,
        embedding_dim: int,
        input_len: int,
        action_space: gym.spaces.Discrete,
        observation_space: SpaceDict,
        num_layers: int = 1,
        rnn_type: str = "GRU",
        head_type=LinearActorCritic,
    ):
        """Build the parent recurrent model and the embedding table.

        Args:
            input_key: Key of the observation to embed; forwarded to the
                parent constructor.
            num_embeddings: Number of rows in the embedding table.
            embedding_dim: Size of each embedding vector.
            input_len: Multiplied by ``embedding_dim`` to obtain the flat
                feature size given to the parent (presumably the number of
                indices per observation -- TODO confirm with callers).
            action_space: Forwarded unchanged to the parent constructor.
            observation_space: Accepted but not used here; the parent
                receives a freshly built space describing the flattened
                embedding instead.
            num_layers: Forwarded unchanged to the parent constructor.
            rnn_type: Forwarded unchanged to the parent constructor.
            head_type: Forwarded unchanged to the parent constructor.
        """
        flat_feature_size = embedding_dim * input_len

        # The parent sees the flattened embedding rather than the raw
        # observation, so describe that as an unbounded Box of matching size.
        embedded_space = SpaceDict(
            {
                input_key: gym.spaces.Box(
                    -float("inf"), float("inf"), shape=(flat_feature_size,)
                )
            }
        )

        super().__init__(
            input_key=input_key,
            action_space=action_space,
            observation_space=embedded_space,
            hidden_size=flat_feature_size,
            num_layers=num_layers,
            rnn_type=rnn_type,
            head_type=head_type,
        )

        # Registered after the parent initialiser has run so that it is
        # tracked as a submodule of this model.
        self.initial_embedding = nn.Embedding(
            num_embeddings=num_embeddings, embedding_dim=embedding_dim
        )

    def forward(  # type: ignore
        self,
        observations: Dict[str, torch.FloatTensor],
        recurrent_hidden_states: torch.FloatTensor,
        prev_actions: torch.LongTensor,
        masks: torch.FloatTensor,
        **kwargs
    ) -> Tuple[ActorCriticOutput[DistributionType], Any]:
        """Embed the keyed observation and delegate to the parent ``forward``.

        Only ``observations[self.key]`` is read; the parent receives a new
        single-entry dict holding the flattened embedding under that key.

        Args:
            observations: Mapping that must contain ``self.key``.
                NOTE(review): annotated as ``FloatTensor`` but the value is
                fed to ``nn.Embedding``, which expects integer indices --
                confirm the dtype callers actually supply.
            recurrent_hidden_states: Forwarded unchanged to the parent.
            prev_actions: Forwarded unchanged to the parent.
            masks: Forwarded unchanged to the parent.
            **kwargs: Accepted but not forwarded to the parent.

        Returns:
            Whatever the parent class's ``forward`` returns.
        """
        # assumes the parent stores ``input_key`` as ``self.key`` -- TODO confirm
        indices = observations[self.key]
        embedded = self.initial_embedding(indices)

        # Keep the leading dimension and collapse everything after it.
        flattened = embedded.view(indices.shape[0], -1)

        obs = typing.cast(
            Dict[str, torch.FloatTensor], {self.key: flattened}
        )
        return super().forward(
            observations=obs,
            recurrent_hidden_states=recurrent_hidden_states,
            prev_actions=prev_actions,
            masks=masks,
        )
