import gym
import torch
from einops.layers.torch import Rearrange

from net.encoder.base import BaseEncoder


class ConvEncoder(BaseEncoder):
    """
    pixel observation -> latent

    Encodes a sequence of single-channel 2-D observations shaped
    ``(batch, time, height, width)`` into per-step latent vectors shaped
    ``(batch, time, latent_size)``. The convolution kernel has extent 1 along
    the time axis, so every time step is convolved independently.
    """

    def __init__(self, observation_space: gym.spaces.Box, latent_size: int):
        """
        Build the convolution + linear projection stack.

        Args:
            observation_space: Space whose ``shape`` is ``(height, width)``.
            latent_size: Size of the latent vector. It is forwarded to the
                base class, which is assumed to expose it as
                ``self.latent_size`` (the base class is not visible here).

        Raises:
            ValueError: If the observation shape is not 2-D, or if either
                spatial dimension is smaller than the convolution kernel.
        """
        super().__init__(observation_space, latent_size)

        # Explicit check instead of ``assert`` so it survives ``python -O``.
        obs_shape = observation_space.shape
        if len(obs_shape) != 2:
            raise ValueError(
                f"ConvEncoder expects a 2-D (height, width) observation space, "
                f"got shape {obs_shape}"
            )
        obs_h, obs_w = obs_shape

        in_channels = 1
        out_channels = 8
        # (time, height, width): extent 1 in time keeps frames independent.
        kernel_size = (1, 3, 3)

        # Unpadded, stride-1 convolution shrinks each spatial dim by (k - 1).
        conv_h = obs_h - (kernel_size[1] - 1)
        conv_w = obs_w - (kernel_size[2] - 1)
        if conv_h < 1 or conv_w < 1:
            # Fail early rather than building a Linear layer with a
            # non-positive input size.
            raise ValueError(
                f"observation shape {obs_shape} is smaller than the "
                f"{kernel_size[1]}x{kernel_size[2]} convolution kernel"
            )
        conv_out_size = out_channels * conv_h * conv_w

        self.net = torch.nn.Sequential(
            # Add the channel axis; h/w are pinned to the observation space.
            Rearrange('b t h w -> b 1 t h w', h=obs_h, w=obs_w),
            # (b, 1, t, h, w) -> (b, 8, t, h - 2, w - 2)
            torch.nn.Conv3d(in_channels, out_channels, kernel_size),
            torch.nn.ReLU(),
            # Flatten channel and spatial axes into one feature vector per step.
            Rearrange('b c t h w -> b t (c h w)'),
            torch.nn.Linear(conv_out_size, self.latent_size),
            torch.nn.ReLU()
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """
        Encode a batch of observation sequences.

        Args:
            obs: Tensor shaped ``(batch, time, height, width)``.

        Returns:
            Tensor shaped ``(batch, time, latent_size)``; values are
            non-negative because the final layer is a ReLU.
        """
        return self.net(obs)
