import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, List
from typing import Optional


class Encoder(nn.Module):
    """MLP encoder for the VAE.

    Builds a stack of ``Linear -> ReLU -> BatchNorm1d`` stages (one per entry
    of ``hidden_dims``) followed by two independent linear heads that produce
    ``mu`` and ``logvar``. In ``VAE`` these are used as the mean and
    log-variance of the latent distribution.

    Args:
        input_dim: Width of the incoming feature vector.
        hidden_dims: Hidden-layer widths, applied in the given order. May be
            empty, in which case the heads read the raw input.
        latent_dim: Width of each output head.

    Note:
        When ``hidden_dims`` is non-empty, the BatchNorm1d layers require a
        batch of more than one sample while the module is in training mode.
    """

    def __init__(self, input_dim: int, hidden_dims: List[int], latent_dim: int):
        super().__init__()

        widths = [input_dim, *hidden_dims]
        stages: List[nn.Module] = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            # Linear -> ReLU -> BatchNorm1d (BatchNorm is used instead of Dropout).
            stages += [
                nn.Linear(fan_in, fan_out),
                nn.ReLU(),
                nn.BatchNorm1d(fan_out),
            ]
        self.feature_extractor = nn.Sequential(*stages)

        # Both heads read the output of the last hidden stage (or the raw
        # input when there are no hidden stages).
        trunk_width = widths[-1]
        self.fc_mu = nn.Linear(trunk_width, latent_dim)
        self.fc_logvar = nn.Linear(trunk_width, latent_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run ``x`` (expected shape ``(batch, input_dim)``) through the trunk and both heads.

        Returns:
            ``(mu, logvar)``, each of width ``latent_dim``.
        """
        shared = self.feature_extractor(x)
        return self.fc_mu(shared), self.fc_logvar(shared)


class Decoder(nn.Module):
    """MLP decoder for the VAE.

    The hidden widths are consumed last-to-first. Given the same
    ``hidden_dims`` list as ``Encoder``, the decoder therefore mirrors the
    encoder. Each hidden stage is ``Linear -> ReLU -> BatchNorm1d``. The
    output head is a ``Linear`` layer followed by a ``Sigmoid``, so every
    output lies in ``[0, 1]``.

    Args:
        latent_dim: Width of the incoming latent vector ``z``.
        hidden_dims: Hidden-layer widths (applied in reverse order).
        output_dim: Width of the reconstruction.
    """

    def __init__(self, latent_dim: int, hidden_dims: List[int], output_dim: int):
        super().__init__()

        widths = [latent_dim] + list(reversed(hidden_dims))
        stack: List[nn.Module] = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            # Linear -> ReLU -> BatchNorm1d (BatchNorm is used instead of Dropout).
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.BatchNorm1d(fan_out)]

        # Output head: sigmoid keeps values in [0, 1] (pixel range for MNIST).
        stack += [nn.Linear(widths[-1], output_dim), nn.Sigmoid()]

        self.decoder = nn.Sequential(*stack)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Map latent vectors ``z`` to reconstructions of width ``output_dim``."""
        reconstruction = self.decoder(z)
        return reconstruction


class VAE(nn.Module):
    """Variational Autoencoder composed of the MLP ``Encoder`` and ``Decoder``.

    Args:
        input_dim: Width of the (flat) input; the decoder reconstructs a
            vector of the same width.
        hidden_dims: Hidden-layer widths shared by encoder and decoder unless
            overridden below. ``Decoder`` applies its list in reverse order.
        latent_dim: Dimensionality of the latent code ``z``.
        encoder_hidden_dims: Optional encoder-only override of ``hidden_dims``.
        decoder_hidden_dims: Optional decoder-only override of ``hidden_dims``.
            It is still applied in reverse order by ``Decoder``.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        latent_dim: int,
        *,
        encoder_hidden_dims: Optional[List[int]] = None,
        decoder_hidden_dims: Optional[List[int]] = None,
    ):
        super(VAE, self).__init__()

        self.latent_dim = latent_dim
        self.input_dim = input_dim

        # Backward compatibility: if per-part hidden dims are not given, use shared hidden_dims
        enc_dims = encoder_hidden_dims if encoder_hidden_dims is not None else hidden_dims
        dec_dims = decoder_hidden_dims if decoder_hidden_dims is not None else hidden_dims

        self.encoder = Encoder(input_dim, enc_dims, latent_dim)
        self.decoder = Decoder(latent_dim, dec_dims, input_dim)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Reparameterization trick: return ``mu + eps * std`` with ``eps ~ N(0, I)``.

        The randomness lives in ``eps``, so gradients flow to ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)  # logvar = log(sigma^2)  ->  sigma
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode ``x``, sample ``z`` via the reparameterization trick, and decode.

        Returns:
            ``(x_recon, mu, logvar)``: the reconstruction plus the encoder
            outputs that the loss functions need.
        """
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decoder(z)
        return x_recon, mu, logvar

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode input to ``(mu, logvar)``."""
        return self.encoder(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent representation to output."""
        return self.decoder(z)

    def sample(self, num_samples: int, device: Optional[torch.device] = None) -> torch.Tensor:
        """Decode ``num_samples`` latent vectors drawn from a standard normal.

        Args:
            num_samples: Number of samples to generate.
            device: Device on which to draw the noise. Defaults to the device
                of the model's parameters.

        Note:
            The decoder's BatchNorm1d layers follow the module's current mode.
            In training mode they normalise with the statistics of this random
            batch and update their running estimates. Call ``eval()`` first
            to use the running statistics instead.
        """
        if device is None:
            device = next(self.parameters()).device
        # Allocate the noise directly on the target device rather than
        # drawing it on the CPU and copying it across.
        z = torch.randn(num_samples, self.latent_dim, device=device)
        return self.decode(z)


def vae_loss(recon_x: torch.Tensor, x: torch.Tensor, mu: torch.Tensor, 
             logvar: torch.Tensor, beta: float = 1.0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the (beta-)VAE loss as scalar means.

    Args:
        recon_x: Reconstruction, values in [0, 1]; same shape as ``x``.
        x: Target input, values in [0, 1].
        mu: Mean of the latent distribution.
        logvar: Log-variance of the latent distribution.
        beta: Weight on the KL term (beta-VAE).

    Returns:
        ``(total_loss, recon_loss, kl_loss)``, each a 0-dim tensor, where
        ``total_loss = recon_loss + beta * kl_loss``.

    Note:
        Both terms are element-wise means. The BCE is averaged over every
        element of ``x``, and the KL over every entry of ``mu``. Compared with
        the summed ELBO, the KL term is therefore effectively weighted by
        ``beta * num_input_features / latent_dim``; tune ``beta`` accordingly.
    """
    # Binary cross-entropy averaged over batch and feature dims ("mean"
    # reduction chosen for more stable training).
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction='mean')

    # Closed-form KL divergence from N(mu, exp(logvar)) to N(0, I), averaged
    # over both batch and latent dims (not a per-sample sum).
    kl_loss = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).mean()

    return recon_loss + beta * kl_loss, recon_loss, kl_loss


def vae_loss_per_sample(
    recon_x: torch.Tensor,
    x: torch.Tensor,
    mu: torch.Tensor,
    logvar: torch.Tensor,
    beta: float = 1.0,
) -> torch.Tensor:
    """
    Per-sample VAE loss (no reduction across the batch).

    Each sample's reconstruction BCE is averaged over all of its feature
    elements. Its KL term is averaged over all of its latent entries.
    Averaging the result over the batch therefore gives the mean-reduced
    total (BCE mean over all elements plus ``beta`` times the KL mean over
    all latent entries).

    Args:
        recon_x: Reconstruction, values in [0, 1]; same shape as ``x``.
        x: Target input, values in [0, 1]; dim 0 is the batch.
        mu: Mean of the latent distribution; dim 0 is the batch.
        logvar: Log-variance of the latent distribution; same shape as ``mu``.
        beta: Weight on the KL term (beta-VAE).

    Returns:
        Tensor of shape ``[batch_size]``.
    """
    # Element-wise BCE, same shape as the input.
    bce = F.binary_cross_entropy(recon_x, x, reduction='none')
    # binary_cross_entropy preserves the input's memory layout, so the result
    # may be non-contiguous; reshape (unlike view) copies when necessary.
    recon_per_sample = bce.reshape(bce.size(0), -1).mean(dim=1)  # mean over features

    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    # Flatten per sample so the result is always [batch_size]; identical to
    # mean(dim=1) for the usual 2-D (batch, latent_dim) latents.
    kl_per_sample = -0.5 * kl_terms.reshape(kl_terms.size(0), -1).mean(dim=1)

    return recon_per_sample + beta * kl_per_sample