import torch as th
import torch.nn as nn
import gym
from collections import OrderedDict
from stable_baselines3.common.preprocessing import is_image_space

class ObservationsExtractor(nn.Module):
    """CNN feature extractor for image observations, followed by a linear projection.

    The trunk is ``BatchNorm2d -> Conv2d(32, k=8, s=4) -> ReLU -> Conv2d(64, k=4, s=2)
    -> ReLU -> Conv2d(64, k=3, s=1) -> ReLU -> Flatten``. Its flattened output is
    projected to ``output_dim`` features by ``Linear + ReLU``.

    :param observation_space: Image observation space. Its first shape entry is used
        as the number of input channels, so a channel-first layout is expected.
    :param output_dim: Size of the feature vector produced by :meth:`forward`.
    """

    def __init__(self, observation_space: gym.Space, output_dim: int = 512) -> None:
        super().__init__()

        assert is_image_space(observation_space, check_channels=False), (
            "This `ObservationsExtractor` is meant to be used with `ActorCriticCnnPolicy`"
        )

        self._observation_space = observation_space
        self.output_dim = output_dim

        # Channel-first layout: dimension 0 holds the channel count.
        n_input_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.BatchNorm2d(n_input_channels),
            nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )

        # The head's input size depends on the CNN output size; _build infers it.
        self.linear = None
        self._build()

    def _build(self):
        """Create ``self.linear`` once, sizing its input from a dry run of the CNN."""
        if self.linear is not None:
            return

        # Probe in eval mode. In training mode, BatchNorm2d updates its running
        # mean/var buffers on every forward pass, even under no_grad. That would
        # seed them with statistics from a single random space sample.
        was_training = self.cnn.training
        self.cnn.eval()
        try:
            with th.no_grad():
                probe = th.as_tensor(self._observation_space.sample()[None]).float()
                n_flatten = self.cnn(probe).shape[1]
        finally:
            self.cnn.train(was_training)

        self.linear = nn.Sequential(nn.Linear(n_flatten, self.output_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Encode a batch of observations into ``output_dim`` features.

        Assumes ``observations`` is a float tensor laid out like the space's
        samples, with an extra leading batch dimension.
        """
        return self.linear(self.cnn(observations))

    def load(self, model, load_linear=False):
        """Adopt the CNN trunk (and optionally the linear head) of ``model``.

        Modules are assigned by reference, not copied, so their parameters are
        shared with ``model`` afterwards.

        :param model: Object exposing ``cnn`` (and ``linear`` if ``load_linear``).
        :param load_linear: Also take ``model.linear`` instead of keeping this
            extractor's own head.
        """
        self.cnn = model.cnn

        if load_linear:
            self.linear = model.linear

class ActionsExtractor(nn.Module):
    """Encode actions into a fixed-size feature vector.

    Actions from a ``gym.spaces.Discrete`` space are one-hot encoded; actions from
    any other space are used as-is. The result is flattened (all dimensions after
    the first) and projected to ``output_dim`` features by ``Linear + ReLU``.

    :param action_space: Action space whose samples define the input size.
    :param output_dim: Size of the feature vector produced by :meth:`forward`.
    """

    def __init__(self, action_space: gym.Space, output_dim: int = 32) -> None:
        super().__init__()

        self._action_space = action_space
        self.output_dim = output_dim

        self.flattener = nn.Flatten()
        # The head's input size depends on the transformed action size; _build infers it.
        self.linear = None

        self._build()

    def _build(self):
        """Create ``self.linear`` once, sizing its input from one transformed sample."""
        if self.linear is not None:
            return

        with th.no_grad():
            # [None] adds a batch dimension of 1 so the Flatten keeps dim 0 as batch.
            actions_sample = self._transform(th.as_tensor(self._action_space.sample())[None])
            n_flatten = self.flattener(actions_sample).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, self.output_dim), nn.ReLU())

    def _transform(self, actions: th.Tensor) -> th.Tensor:
        """One-hot encode Discrete actions; return other actions unchanged."""
        if isinstance(self._action_space, gym.spaces.Discrete):
            # one_hot requires an integer tensor; actions may arrive as floats.
            return nn.functional.one_hot(actions.long(), self._action_space.n)
        return actions

    def forward(self, actions: th.Tensor) -> th.Tensor:
        """Encode a batch of actions into ``output_dim`` features.

        Assumes the leading dimension of ``actions`` is the batch dimension. The
        transformed actions are cast to float because one_hot yields an integer
        tensor, which Linear cannot consume.
        """
        return self.linear(self.flattener(self._transform(actions)).float())

    def load(self, model, load_linear=False):
        """Optionally adopt the linear head of ``model``.

        The flattener and the one-hot transform have no parameters, so the linear
        head is the only thing that can be transferred. With the default
        ``load_linear=False`` this is a no-op.

        :param model: Object exposing ``linear`` (read only if ``load_linear``).
        :param load_linear: Take ``model.linear`` by reference, sharing its parameters.
        """
        if load_linear:
            self.linear = model.linear
