from __future__ import annotations

import torch
from torch import nn


def CNN(image_size: int,
        feature_size: int,  # output features
        hidden_layers: int,
        hidden_channels: int,
        channel_multiplier: int = 1,
        image_channels: int = 3) -> nn.Sequential:
    """Build a strided-convolution encoder from square images to feature vectors.

    Each hidden layer is ``Conv2d(kernel=5, stride=2, padding=2)`` followed by
    ``BatchNorm2d`` and ``LeakyReLU(0.1)``; it halves the spatial size and
    multiplies the channel count of the *next* layer by ``channel_multiplier``.
    A final ``Conv2d`` whose kernel equals the remaining spatial size collapses
    the map to 1x1, and ``Flatten(-3)`` turns it into a feature vector.

    Args:
        image_size: Height/width of the (square) input images. Must stay even
            at every halving step, i.e. be divisible by ``2 ** hidden_layers``.
        feature_size: Number of output features.
        hidden_layers: Number of stride-2 convolution stages.
        hidden_channels: Output channels of the first hidden stage.
        channel_multiplier: Growth factor of the channel count per stage.
        image_channels: Number of channels of the input images.

    Returns:
        An ``nn.Sequential`` mapping ``(N, image_channels, image_size,
        image_size)`` to ``(N, feature_size)``.

    Raises:
        ValueError: If the spatial size becomes odd before all hidden layers
            have been applied.
    """
    layers = []
    in_channels = image_channels
    out_channels = hidden_channels
    size = image_size
    for depth in range(hidden_layers):
        # An explicit raise (not ``assert``) so the check survives ``python -O``.
        # With an odd size the strided conv would round up while ``size``
        # rounds down, and the final conv would no longer reduce to 1x1.
        if size % 2 != 0:
            raise ValueError(
                f"image_size={image_size} is not divisible by "
                f"2**hidden_layers={2 ** hidden_layers} "
                f"(odd size {size} reached at hidden layer {depth})"
            )

        layers += [nn.Conv2d(in_channels, out_channels, 5, 2, 2),
                   nn.BatchNorm2d(out_channels),
                   nn.LeakyReLU(0.1)]

        in_channels = out_channels
        out_channels *= channel_multiplier
        size //= 2

    return nn.Sequential(*layers,
                         # Kernel spans the whole remaining map -> 1x1 output.
                         nn.Conv2d(in_channels, feature_size, size),
                         nn.Flatten(-3))


def DCNN(image_size: int,
         feature_size: int,  # input features
         hidden_layers: int,
         hidden_channels: int,
         channel_multiplier: int = 1,
         image_channels: int = 3) -> nn.Sequential:
    """Build a transposed-convolution decoder from feature vectors to square images.

    The feature vector is reshaped to ``(feature_size, 1, 1)`` and expanded by a
    ``ConvTranspose2d`` to ``image_size // 2 ** hidden_layers`` pixels per
    side. Each hidden stage is then ``BatchNorm2d`` -> ``LeakyReLU(0.1)`` ->
    ``ConvTranspose2d(kernel=5, stride=2, padding=2, output_padding=1)``, which
    doubles the spatial size. Channel counts shrink by ``channel_multiplier``
    per stage and the last stage emits ``image_channels``. A ``Sigmoid`` is
    applied to the output.

    Args:
        image_size: Height/width of the (square) output images; must be
            divisible by ``2 ** hidden_layers``.
        feature_size: Number of input features.
        hidden_layers: Number of stride-2 transposed-convolution stages.
        hidden_channels: Input channels of the last hidden stage.
        channel_multiplier: Channel growth factor per stage (counted from the
            image side).
        image_channels: Number of channels of the output images.

    Returns:
        An ``nn.Sequential`` mapping ``(N, feature_size)`` to
        ``(N, image_channels, image_size, image_size)``.

    Raises:
        ValueError: If ``hidden_layers`` is negative or ``image_size`` is not
            divisible by ``2 ** hidden_layers``.
    """
    if hidden_layers < 0:
        raise ValueError(f"hidden_layers must be >= 0, got {hidden_layers}")

    multiplier = 2 ** hidden_layers
    # An explicit raise (not ``assert``) so the check survives ``python -O``.
    if image_size % multiplier != 0:
        raise ValueError(
            f"image_size={image_size} is not divisible by "
            f"2**hidden_layers={multiplier}"
        )

    if hidden_layers:
        channels = hidden_channels * channel_multiplier ** (hidden_layers - 1)
    else:
        # No hidden stages: the initial projection is also the output layer,
        # so it must emit image channels directly. This also avoids the
        # negative exponent, which would produce a float channel count.
        channels = image_channels

    layers = [
        nn.Unflatten(-1, (feature_size, 1, 1)),
        # From a 1x1 input, a kernel of size k yields a k x k map.
        nn.ConvTranspose2d(feature_size, channels, image_size // multiplier)
    ]
    for i in reversed(range(hidden_layers)):
        # The final stage (i == 0) projects to the image channels.
        out_channels = channels // channel_multiplier if i else image_channels
        layers += [nn.BatchNorm2d(channels),
                   nn.LeakyReLU(0.1),
                   # output_padding=1 makes the stage exactly double the size.
                   nn.ConvTranspose2d(channels, out_channels, 5, 2, 2, 1)]

        channels = out_channels

    return nn.Sequential(*layers, nn.Sigmoid())


class AutoEncoder(nn.Module):
    """Convolutional autoencoder pairing a ``CNN`` encoder with a ``DCNN`` decoder.

    The constructor arguments are kept as attributes and forwarded to both
    builders, so encoder and decoder share the same depth and channel layout.
    """

    def __init__(self,
                 input_size: int,
                 latent_size: int,
                 hidden_layers: int = 4,
                 hidden_channels: int = 32,
                 channel_multiplier: int = 1,
                 input_channels: int = 3) -> None:
        """Create the encoder and decoder networks.

        Args:
            input_size: Height/width of the square input images.
            latent_size: Number of features in the latent code.
            hidden_layers: Number of stride-2 stages in each network.
            hidden_channels: Base channel count of the hidden stages.
            channel_multiplier: Channel growth factor per stage.
            input_channels: Number of image channels.
        """
        super().__init__()

        # Keep the hyper-parameters around for later inspection.
        self.input_size = input_size
        self.latent_size = latent_size
        self.hidden_layers = hidden_layers
        self.hidden_channels = hidden_channels
        self.channel_multiplier = channel_multiplier
        self.input_channels = input_channels

        # Both builders take the same positional arguments.
        builder_args = (input_size,
                        latent_size,
                        hidden_layers,
                        hidden_channels,
                        channel_multiplier,
                        input_channels)

        # Encoder is registered before the decoder (module order matters for
        # parameter ordering).
        self.encode = CNN(*builder_args)
        self.decode = DCNN(*builder_args)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Encode and then decode ``state``.

        A 3-D input is treated as a single unbatched sample: a leading batch
        dimension is added for the pass and removed from the result.
        """
        unbatched = state.ndim == 3
        batch = state.unsqueeze(0) if unbatched else state

        reconstruction = self.decode(self.encode(batch))

        if unbatched:
            return reconstruction.squeeze(0)
        return reconstruction
