import torch.nn as nn


class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) axis into one.

    An input of shape ``(N, d1, d2, ...)`` becomes ``(N, d1 * d2 * ...)``.
    """

    def forward(self, x):
        # ``reshape`` returns a view when possible and copies otherwise, so
        # non-contiguous inputs are handled too.
        return x.reshape(x.size(0), -1)


def ConvLayer(in_channels, out_channels):
    """Build one downsampling stage: an unpadded 4x4, stride-2 conv + LeakyReLU.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.

    Returns:
        ``nn.Sequential`` holding the ``Conv2d`` at index 0 and the
        ``LeakyReLU`` at index 1.
    """
    conv = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2)
    activation = nn.LeakyReLU()
    return nn.Sequential(conv, activation)


class StateEncoder(nn.Module):
    """Encode a stack of image frames into a ``z_dim``-dimensional vector.

    Architecture: four ``ConvLayer`` stages (optionally with ``BatchNorm2d``
    after the first two), a ``Flatten``, then a three-layer MLP
    (flat -> 256 -> 128 -> ``z_dim``) with LeakyReLU between the layers.

    The input is a 4-D tensor with ``channel_dim * n_frames`` channels —
    presumably ``n_frames`` frames stacked along the channel axis (confirm
    with callers).

    Args:
        z_dim: Size of the output vector.
        channel_dim: Channels per frame.
        n_frames: Number of frames; the first conv takes
            ``channel_dim * n_frames`` input channels.
        use_batch_norm: If True, insert ``BatchNorm2d(32)`` after the first
            two conv stages.
        image_size: Spatial size of the input, either an int (square image)
            or a ``(height, width)`` pair. It is only used to size the first
            dense layer. The default of 64 yields the original 256-wide
            flattened feature vector.

    Raises:
        ValueError: If ``image_size`` is too small to survive the four
            downsampling convolutions (each side must be at least 46).
    """

    # Geometry of ``ConvLayer``'s unpadded Conv2d and of the conv stack built
    # in __init__. Keep in sync with ``ConvLayer`` and the layer list below.
    _CONV_KERNEL = 4
    _CONV_STRIDE = 2
    _N_CONV = 4
    _CONV_OUT_CHANNELS = 64

    def __init__(self, z_dim=32, channel_dim=3, n_frames=3, use_batch_norm=True,
                 image_size=64):
        super().__init__()

        # Validate the image geometry before building any layers.
        out_h, out_w = self._conv_output_hw(image_size)
        flat_dim = self._CONV_OUT_CHANNELS * out_h * out_w

        # Encoder. Modules are appended in the original order, so Sequential
        # indices (and therefore state_dict keys) and the order of seeded
        # weight initialisation are unchanged in both configurations.
        layers = [ConvLayer(channel_dim * n_frames, 32)]
        if use_batch_norm:
            layers.append(nn.BatchNorm2d(32))
        layers.append(ConvLayer(32, 32))
        if use_batch_norm:
            layers.append(nn.BatchNorm2d(32))
        layers.append(ConvLayer(32, 64))
        layers.append(ConvLayer(64, self._CONV_OUT_CHANNELS))
        layers.append(Flatten())
        # NOTE: "conv_layes" is a misspelling. It is kept because renaming the
        # attribute would change state_dict keys and break saved checkpoints.
        self.conv_layes = nn.Sequential(*layers)

        self.dense_layers = nn.Sequential(nn.Linear(flat_dim, 256),
                                          nn.LeakyReLU(),
                                          nn.Linear(256, 128),
                                          nn.LeakyReLU(),
                                          nn.Linear(128, z_dim))

    @classmethod
    def _conv_output_hw(cls, image_size):
        """Return the ``(height, width)`` of the feature map after the conv stack.

        Raises:
            ValueError: If either side shrinks below 1 at any stage.
        """
        if isinstance(image_size, (tuple, list)):
            height, width = image_size
        else:
            height = width = image_size
        for _ in range(cls._N_CONV):
            # Conv2d output size for padding=0, dilation=1.
            height = (height - cls._CONV_KERNEL) // cls._CONV_STRIDE + 1
            width = (width - cls._CONV_KERNEL) // cls._CONV_STRIDE + 1
            if height < 1 or width < 1:
                raise ValueError(
                    f"image_size={image_size!r} is too small for "
                    f"{cls._N_CONV} unpadded {cls._CONV_KERNEL}x"
                    f"{cls._CONV_KERNEL} stride-{cls._CONV_STRIDE} convolutions"
                )
        return height, width

    def forward(self, x):
        """Map a batch of stacked frames to a ``(batch, z_dim)`` tensor."""
        features = self.conv_layes(x)
        return self.dense_layers(features)

    def set_eval(self):
        """Set ``track_running_stats = False`` on every ``BatchNorm2d`` in the model.

        It does not call ``.eval()`` and does not change ``self.training``.

        NOTE(review): in PyTorch, flipping this flag after construction leaves
        the existing running-stat buffers in place. In train mode the running
        stats stop being updated, while batch statistics are still used. In
        eval mode the stored running stats are still used. Confirm this is the
        intended effect of a method named ``set_eval``.
        """
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.track_running_stats = False
