import gym
import torch
from einops.layers.torch import Rearrange

from net.encoder.base import BaseEncoder


class DenseEncoder(BaseEncoder):
    """
    pixel observation -> latent

    Flattens every 2-D ``(h, w)`` observation and maps it through a single
    linear layer followed by a ReLU.
    """

    def __init__(self, observation_space: gym.spaces.Box, latent_size: int):
        """
        Build the encoder network.

        Args:
            observation_space: Space whose ``shape`` must be ``(height, width)``.
            latent_size: Forwarded to ``BaseEncoder.__init__``; the output
                width of the linear layer is read back from
                ``self.latent_size``.

        Raises:
            ValueError: If ``observation_space.shape`` is not two-dimensional.
        """
        super().__init__(observation_space, latent_size)

        # Explicit check instead of ``assert`` so the validation still runs
        # when Python is started with ``-O``.
        obs_shape = observation_space.shape
        if len(obs_shape) != 2:
            raise ValueError(
                "DenseEncoder expects a 2-D observation space (h, w), "
                f"got shape {obs_shape}"
            )
        obs_h, obs_w = obs_shape

        projection = torch.nn.Linear(obs_h * obs_w, self.latent_size)
        # Every weight starts at 1.0; the bias keeps the default
        # ``torch.nn.Linear`` initialisation.
        # NOTE(review): with identical weights all latent units differ only
        # by their bias at init -- presumably intentional, confirm.
        torch.nn.init.constant_(projection.weight, 1.0)

        self.net = torch.nn.Sequential(
            # Merge the two trailing spatial axes into one feature axis; the
            # fixed h/w make einops reject inputs of any other size.
            Rearrange("b t h w -> b t (h w)", h=obs_h, w=obs_w),
            projection,
            torch.nn.ReLU(),
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """
        Encode observations.

        Args:
            obs: Tensor laid out as ``(b, t, h, w)`` where ``h`` and ``w``
                match the observation space given at construction
                (``b``/``t`` are presumably batch and time -- confirm with
                the caller).

        Returns:
            Tensor of shape ``(b, t, latent_size)``.
        """
        return self.net(obs)
