import torch
from torch import nn
from torch.nn import functional as F

from vaetc.models.vae import VAE
from vaetc.network.blocks import SigmoidInverse
from vaetc.data.utils import IMAGE_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH

class Encoder(nn.Module):
    """Convolutional encoder producing the mean and log-variance of the latent posterior.

    The network applies ``SigmoidInverse`` to the input, then a stack of
    stride-2 convolution blocks (each halves height and width). It flattens the
    result and projects it to ``hidden_elements`` features. Two linear heads,
    ``fc_mean`` and ``fc_logvar``, then map those features to ``z_dim`` outputs each.

    Args:
        hidden_channels: Output channel count of each downsampling block, in
            order. Its length sets the number of halvings. An empty sequence
            means no convolutional stage: the raw image is flattened directly.
        hidden_elements: Width of the fully-connected hidden layer.
        z_dim: Dimensionality of the latent space.

    Raises:
        ValueError: If ``IMAGE_HEIGHT`` or ``IMAGE_WIDTH`` is not divisible by
            ``2 ** len(hidden_channels)``.
    """

    def __init__(self, hidden_channels: list[int], hidden_elements: int, z_dim: int) -> None:
        super().__init__()

        # Accept any sequence (e.g. a tuple); the concatenation below needs a list.
        hidden_channels = list(hidden_channels)

        # Each block halves H and W, so both must divide evenly by the total factor.
        # Raise instead of `assert` so the check survives `python -O`.
        scale = 2 ** len(hidden_channels)
        if IMAGE_HEIGHT % scale != 0 or IMAGE_WIDTH % scale != 0:
            raise ValueError(
                f"image size {IMAGE_HEIGHT}x{IMAGE_WIDTH} is not divisible by "
                f"2 ** len(hidden_channels) = {scale}"
            )

        Activation = nn.GELU  # alias

        # Conv2d(kernel=4, stride=2, padding=1) maps an even size S to exactly S // 2.
        downsampling_blocks = [
            nn.Sequential(
                nn.Conv2d(cin, cout, 4, 2, 1),
                Activation(),
                nn.BatchNorm2d(cout),
            )
            for cin, cout in zip([IMAGE_CHANNELS] + hidden_channels[:-1], hidden_channels)
        ]
        downsampled_height = IMAGE_HEIGHT // scale
        downsampled_width = IMAGE_WIDTH // scale
        # With no conv blocks the image channels themselves are flattened.
        downsampled_channel = hidden_channels[-1] if hidden_channels else IMAGE_CHANNELS
        downsampled_nelm = downsampled_channel * downsampled_height * downsampled_width

        # NOTE(review): SigmoidInverse is a project block; presumably a logit
        # transform of pixel values in (0, 1) -- confirm.
        self.net = nn.Sequential(
            SigmoidInverse(),
            *downsampling_blocks,
            nn.Flatten(),
            nn.Linear(downsampled_nelm, hidden_elements),
            Activation(),
        )

        self.fc_mean = nn.Linear(hidden_elements, z_dim)
        self.fc_logvar = nn.Linear(hidden_elements, z_dim)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of images.

        Args:
            x: Image batch, expected as (N, IMAGE_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH).

        Returns:
            ``(mean, logvar)``, each of shape (N, z_dim).
        """
        h = self.net(x)
        return self.fc_mean(h), self.fc_logvar(h)

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent codes to images in (0, 1).

    A latent vector passes through two linear layers, and the result is
    reshaped into a (hidden_channels[0], H / 2**k, W / 2**k) feature map,
    where k = len(hidden_channels). A stack of stride-2 transposed-convolution
    blocks then doubles H and W k times. The last block outputs
    ``IMAGE_CHANNELS`` channels, and a final ``nn.Sigmoid`` squashes the
    result into (0, 1).

    Args:
        hidden_channels: Channel counts from the deepest feature map outward.
            ``hidden_channels[i]`` is the input of upsampling block ``i``. The
            last block outputs ``IMAGE_CHANNELS``. An empty sequence means no
            upsampling: the linear output is reshaped straight to image size.
        hidden_elements: Width of the fully-connected hidden layer.
        z_dim: Dimensionality of the latent space.

    Raises:
        ValueError: If ``IMAGE_HEIGHT`` or ``IMAGE_WIDTH`` is not divisible by
            ``2 ** len(hidden_channels)``.
    """

    def __init__(self, hidden_channels: list[int], hidden_elements: int, z_dim: int) -> None:
        super().__init__()

        # Accept any sequence (e.g. a tuple); the concatenation below needs a list.
        hidden_channels = list(hidden_channels)

        # Each block doubles H and W, so the image must divide evenly by the total factor.
        # Raise instead of `assert` so the check survives `python -O`.
        scale = 2 ** len(hidden_channels)
        if IMAGE_HEIGHT % scale != 0 or IMAGE_WIDTH % scale != 0:
            raise ValueError(
                f"image size {IMAGE_HEIGHT}x{IMAGE_WIDTH} is not divisible by "
                f"2 ** len(hidden_channels) = {scale}"
            )

        Activation = nn.GELU  # alias

        # ConvTranspose2d(kernel=4, stride=2, padding=1) maps size S to exactly 2 * S.
        # NOTE(review): block 0 uses Identity instead of an activation, so the
        # preceding Linear reaches the first ConvTranspose2d through BatchNorm
        # only (no nonlinearity). Kept as-is to preserve behavior; confirm intent.
        upsampling_blocks = [
            nn.Sequential(
                Activation() if i > 0 else nn.Identity(),
                nn.BatchNorm2d(cin),
                nn.ConvTranspose2d(cin, cout, 4, 2, 1),
            )
            for i, (cin, cout) in enumerate(zip(hidden_channels, hidden_channels[1:] + [IMAGE_CHANNELS]))
        ]
        downsampled_height = IMAGE_HEIGHT // scale
        downsampled_width = IMAGE_WIDTH // scale
        # With no upsampling blocks the reshaped tensor already has image channels.
        downsampled_channel = hidden_channels[0] if hidden_channels else IMAGE_CHANNELS
        downsampled_nelm = downsampled_channel * downsampled_height * downsampled_width

        self.net = nn.Sequential(
            nn.Linear(z_dim, hidden_elements),
            Activation(),
            nn.Linear(hidden_elements, downsampled_nelm),
            nn.Unflatten(dim=1, unflattened_size=(downsampled_channel, downsampled_height, downsampled_width)),
            *upsampling_blocks,
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a batch of latent codes.

        Args:
            x: Latent batch of shape (N, z_dim).

        Returns:
            A single tensor of shape (N, IMAGE_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH)
            with values in (0, 1).
        """
        return self.net(x)

class DeepVAE(VAE):
    """VAE variant whose encoder and decoder are the convolutional stacks defined above.

    Recognised hyperparameter keys (read in addition to whatever ``VAE`` reads):
        hidden_channels: Encoder channel counts per downsampling block,
            default ``[32, 64, 128, 256]``. The decoder uses them in reverse.
        hidden_elements: Width of the fully-connected hidden layers, default 256.
    """

    def __init__(self, hyperparameters: dict):
        super().__init__(hyperparameters)

        # NOTE(review): self.z_dim is read below; presumably VAE.__init__ sets it -- confirm.
        self.hidden_channels = list(
            hyperparameters.get("hidden_channels", [32, 64, 128, 256])
        )
        self.hidden_elements = int(
            hyperparameters.get("hidden_elements", 256)
        )

        # Keyword arguments shared by both halves of the autoencoder.
        common_kwargs = {
            "hidden_elements": self.hidden_elements,
            "z_dim": self.z_dim,
        }

        # Build the encoder first, then the decoder (mirrored channel order).
        self.enc_block = Encoder(hidden_channels=self.hidden_channels, **common_kwargs)
        self.dec_block = Decoder(hidden_channels=list(reversed(self.hidden_channels)), **common_kwargs)
