import torch.nn as nn

from .commons import Conv, ResidualConv, NormalisePixels, ChannelFirst
from ..simulators import SIMULATOR, Navigation, Procgen


def architecture(d_hidden):
    """Build the convolutional encoder trunk for the active simulator.

    Both variants finish with a 1x1 ``nn.Conv2d`` projecting to ``d_hidden``
    channels followed by ``nn.Tanh``, so every output value lies in (-1, 1).

    Args:
        d_hidden: Number of output channels of the final 1x1 projection.

    Returns:
        An ``nn.Sequential`` matching the simulator selected by ``SIMULATOR``.

    Raises:
        NotImplementedError: If ``SIMULATOR`` is neither ``Navigation`` nor
            ``Procgen``.
    """
    if SIMULATOR == Navigation:
        # No NormalisePixels/ChannelFirst here, unlike Procgen.
        # NOTE(review): presumably Navigation observations already arrive
        # normalised and channel-first with 3 channels; confirm against the
        # simulator.
        # Arguments are evaluated left to right, so the modules (and their
        # random parameter initialisation) are created in the same order as a
        # fully written-out list.
        return nn.Sequential(
            Conv(3, 64, 3, 1, 1),
            *(ResidualConv(64) for _ in range(4)),
            nn.Conv2d(64, d_hidden, 1),
            nn.Tanh(),
        )

    if SIMULATOR == Procgen:
        # NOTE(review): ChannelFirst before the first conv suggests raw Procgen
        # frames are channel-last pixels; confirm.
        stack = [NormalisePixels(), ChannelFirst()]
        # Three stages: conv -> 2x max-pool (rounding up odd sizes via
        # ceil_mode) -> two residual blocks at the stage's width.
        for in_ch, out_ch in ((3, 16), (16, 32), (32, 32)):
            stack.append(Conv(in_ch, out_ch, 3, 1, 1))
            stack.append(nn.MaxPool2d(2, ceil_mode=True))
            stack.extend(ResidualConv(out_ch) for _ in range(2))
        stack.append(nn.Conv2d(32, d_hidden, 1))
        stack.append(nn.Tanh())
        return nn.Sequential(*stack)

    raise NotImplementedError("Encoder network not implemented for the current SIMULATOR!")


class Encoder(nn.Module):
    """``nn.Module`` wrapper that owns and applies the simulator-specific encoder.

    The actual network is produced by ``architecture`` for the active
    ``SIMULATOR`` and stored as ``self.layers``.
    """

    def __init__(self, d_hidden):
        """Create the encoder.

        Args:
            d_hidden: Output channel count, forwarded to ``architecture``.
        """
        super().__init__()
        # The attribute name ``layers`` is part of the state_dict key prefix,
        # so it must stay stable for checkpoint compatibility.
        self.layers = architecture(d_hidden)

    def forward(self, states):
        """Apply the encoder network to ``states``.

        Args:
            states: Observation tensor passed unchanged to ``self.layers``.
                NOTE(review): expected layout depends on SIMULATOR; confirm
                with the caller.

        Returns:
            Whatever ``self.layers`` returns for ``states``.
        """
        # Call the submodule (not ``.forward``) so its hooks still run.
        return self.layers(states)
