from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


def _compute_spatial_dims(initial: int, num_downsamples: int) -> int:
    """Return the spatial size left after repeated stride-2 downsampling.

    Every step applies the Conv2d output-size formula for ``kernel_size=4``,
    ``stride=2``, ``padding=1``, ``dilation=1``; for integers this reduces to
    ``dim // 2``.

    Args:
        initial: Side length before any downsampling.
        num_downsamples: Number of downsampling steps to apply. Zero or a
            negative value leaves ``initial`` untouched.

    Returns:
        The side length after ``num_downsamples`` steps.
    """
    # Hyper-parameters of the downsampling convolution being mirrored.
    kernel_size, stride, padding, dilation = 4, 2, 1, 1

    dim = initial
    for _ in range(num_downsamples):
        dim = (dim + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
    return dim


class ConvEncoder(nn.Module):
    """CNN encoder for VAE (square inputs, 28x28 by default).

    The first two entries of ``channels`` become stride-2 convolutions
    (kernel 4, padding 1) that each halve the spatial size; every further
    entry becomes a size-preserving 3x3 convolution. Each convolution is
    followed by ReLU and BatchNorm2d. The flattened feature map is projected
    by two linear heads to ``mu`` and ``logvar``.

    Args:
        in_channels: Number of channels of the input image.
        channels: Output channel count of each convolution, in order.
        latent_dim: Size of the ``mu`` / ``logvar`` vectors.
        input_size: Side length of the (square) input image. Defaults to 28.

    Attributes:
        feature_channels: Channel count of the final feature map
            (``in_channels`` when ``channels`` is empty).
        feature_size: Side length of the final feature map
            (``input_size`` halved once per stride-2 layer, at most twice).

    Raises:
        ValueError: If the feature map would have a side length below 1.
    """

    def __init__(self, in_channels: int, channels: List[int], latent_dim: int,
                 input_size: int = 28) -> None:
        super().__init__()
        layers: List[nn.Module] = []
        prev_c = in_channels
        downsamples_used = 0
        spatial = input_size

        for out_c in channels:
            # Use at most two downsamples (28 -> 14 -> 7 for the default size);
            # remaining layers keep the spatial size.
            if downsamples_used < 2:
                layers.append(nn.Conv2d(prev_c, out_c, kernel_size=4, stride=2, padding=1))
                # k=4, s=2, p=1 yields floor(size / 2) for any size.
                spatial //= 2
                downsamples_used += 1
            else:
                layers.append(nn.Conv2d(prev_c, out_c, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.BatchNorm2d(out_c))
            prev_c = out_c

        if spatial < 1:
            raise ValueError(
                f"input_size={input_size} is too small for {downsamples_used} downsampling layer(s)"
            )

        self.conv = nn.Sequential(*layers)
        self.feature_channels = prev_c
        self.feature_size = spatial
        flattened_dim = self.feature_channels * self.feature_size * self.feature_size
        self.fc_mu = nn.Linear(flattened_dim, latent_dim)
        self.fc_logvar = nn.Linear(flattened_dim, latent_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode ``x`` and return ``(features, mu, logvar)``.

        ``features`` is the un-flattened output of the convolutional stack;
        ``mu`` and ``logvar`` come from the two linear heads applied to the
        flattened features.
        """
        h = self.conv(x)
        # flatten() copies when needed, so non-contiguous feature maps
        # (e.g. channels_last) work where .view() would raise.
        h_flat = torch.flatten(h, start_dim=1)
        mu = self.fc_mu(h_flat)
        logvar = self.fc_logvar(h_flat)
        return h, mu, logvar


class ConvDecoder(nn.Module):
    """CNN decoder for VAE (outputs 1x28x28 when fed 7x7 or 14x14 features).

    A linear layer expands the latent vector to a
    ``feature_channels x feature_size x feature_size`` map. That map passes
    through one size-preserving 3x3 transposed convolution (+ ReLU +
    BatchNorm2d) per entry of ``channels`` in reversed order, then through
    the upsampling stages, a final 3x3 convolution and a sigmoid.

    Upsampling stages (each a kernel-4, stride-2, padding-1 transposed
    convolution that doubles the spatial size):
        * ``feature_size <= 7``: two stages (7 -> 14 -> 28).
        * ``feature_size == 14``: one stage (14 -> 28).
        * any other value: none, so the output keeps ``feature_size``.

    Args:
        latent_dim: Size of the latent vector.
        channels: Channel counts of the intermediate layers; applied in
            reversed order.
        out_channels: Number of channels of the reconstructed image.
        feature_channels: Channel count of the map produced by the linear layer.
        feature_size: Side length of the map produced by the linear layer.
    """

    def __init__(self, latent_dim: int, channels: List[int], out_channels: int, feature_channels: int, feature_size: int) -> None:
        super().__init__()
        # Remembered so forward() does not need them passed in again.
        self.feature_channels = feature_channels
        self.feature_size = feature_size

        flattened_dim = feature_channels * feature_size * feature_size
        self.fc = nn.Linear(latent_dim, flattened_dim)

        # Size-preserving stack; channels are walked in reverse for the decoding path.
        layers: List[nn.Module] = []
        prev_c = feature_channels
        for out_c in reversed(channels):
            layers.append(nn.ConvTranspose2d(prev_c, out_c, kernel_size=3, stride=1, padding=1))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.BatchNorm2d(out_c))
            prev_c = out_c

        # Two upsampling stages to return to 28 from 7 (or one if starting at 14).
        self.upsample1 = None
        self.upsample2 = None
        if feature_size <= 7:
            # 7 -> 14
            self.upsample1 = nn.ConvTranspose2d(prev_c, prev_c, kernel_size=4, stride=2, padding=1)
            # 14 -> 28
            self.upsample2 = nn.ConvTranspose2d(prev_c, prev_c, kernel_size=4, stride=2, padding=1)
        elif feature_size == 14:
            # 14 -> 28
            self.upsample1 = nn.ConvTranspose2d(prev_c, prev_c, kernel_size=4, stride=2, padding=1)

        self.deconv = nn.Sequential(*layers)
        self.out = nn.Conv2d(prev_c, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, z: torch.Tensor, feature_channels: Optional[int] = None,
                feature_size: Optional[int] = None) -> torch.Tensor:
        """Decode latent vectors ``z`` into images with values in [0, 1].

        Args:
            z: Latent batch of shape ``(batch, latent_dim)``.
            feature_channels: Channel count used to reshape the linear
                output. Defaults to the value given to the constructor.
            feature_size: Side length used to reshape the linear output.
                Defaults to the value given to the constructor.

        Returns:
            The sigmoid-activated reconstruction.
        """
        if feature_channels is None:
            feature_channels = self.feature_channels
        if feature_size is None:
            feature_size = self.feature_size

        h = self.fc(z)
        h = h.view(h.size(0), feature_channels, feature_size, feature_size)
        x = self.deconv(h)
        if self.upsample1 is not None:
            x = self.upsample1(x)
        if self.upsample2 is not None:
            x = self.upsample2(x)
        x = self.out(x)
        return torch.sigmoid(x)


class ConvVAE(nn.Module):
    """Convolutional VAE that pairs a ConvEncoder with a ConvDecoder.

    The decoder is built from the encoder's ``feature_channels`` and
    ``feature_size`` and reconstructs ``in_channels`` output channels.
    """

    def __init__(self, in_channels: int, encoder_channels: List[int], decoder_channels: List[int], latent_dim: int) -> None:
        super().__init__()
        # Keep the construction arguments available on the instance.
        self.in_channels = in_channels
        self.encoder_channels = encoder_channels
        self.decoder_channels = decoder_channels
        self.latent_dim = latent_dim

        encoder = ConvEncoder(in_channels, encoder_channels, latent_dim)
        self.encoder = encoder
        # Shape of the encoder's final feature map; the decoder reshapes to it.
        self.feature_channels = encoder.feature_channels
        self.feature_size = encoder.feature_size
        self.decoder = ConvDecoder(
            latent_dim,
            decoder_channels,
            out_channels=in_channels,
            feature_channels=self.feature_channels,
            feature_size=self.feature_size,
        )

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Draw ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode ``x``, sample a latent and return ``(reconstruction, mu, logvar)``."""
        _features, mu, logvar = self.encoder(x)
        latent = self.reparameterize(mu, logvar)
        reconstruction = self.decoder(latent, self.feature_channels, self.feature_size)
        return reconstruction, mu, logvar

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the encoder's ``(mu, logvar)`` for ``x``; the feature map is discarded."""
        _features, mu, logvar = self.encoder(x)
        return mu, logvar

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Run the decoder on latent batch ``z``."""
        return self.decoder(z, self.feature_channels, self.feature_size)

    def sample(self, num_samples: int, device: torch.device) -> torch.Tensor:
        """Decode ``num_samples`` standard-normal latents created on ``device``."""
        latents = torch.randn(num_samples, self.latent_dim, device=device)
        return self.decode(latents)


