import torch
import torch.nn as nn


class Encoder(nn.Module):
    """
    Convolutional encoder that maps a 3-channel image to a latent vector.

    Three unpadded 3x3 convolutions, each followed by ReLU and 2x2
    max-pooling, shrink the image; a final padded 3x3 convolution keeps the
    spatial size. The feature map is then flattened and projected to
    ``latent_dim`` by a linear layer.

    The width of the linear layer is derived from ``input_height`` and
    ``input_width``, so any image size that survives the three pooling stages
    is supported. For the default 28x28 input the feature map is 128x1x1 and
    the linear layer therefore has 128 input features.

    Input shape: (batch_size, 3, input_height, input_width)
    Output shape: (batch_size, latent_dim)

    Raises:
        ValueError: If the configured input size is too small for the
            convolution/pooling stack.
    """
    def __init__(self, input_height: int = 28, input_width: int = 28, latent_dim: int = 2):
        super().__init__()
        self.input_height = input_height
        self.input_width = input_width
        self.latent_dim = latent_dim

        # Convolutional feature extractor. Spatial sizes in the comments are
        # for the default 28x28 input.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1),     # 28 -> 26
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),                                            # 26 -> 13

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1),   # 13 -> 11
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),                                            # 11 -> 5

            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1),  # 5 -> 3
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2),                                            # 3 -> 1

            # Padded convolution keeps the spatial size; no activation follows,
            # so the linear projection sees the raw convolution output.
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
        )

        # Fully connected projection, sized from the actual feature map rather
        # than a hard-coded 128 so that non-default input sizes work.
        self.fc = nn.Linear(in_features=self._flattened_features(), out_features=latent_dim)

    def _flattened_features(self) -> int:
        """Return the per-sample size of the flattened convolutional output.

        A single all-zero image of the configured size is pushed through
        ``conv_layers`` without gradient tracking. Creating the zero tensor
        draws no random numbers, so parameter initialisation is unaffected.

        Raises:
            ValueError: If the configured size cannot pass through the
                convolution/pooling stack.
        """
        try:
            with torch.no_grad():
                probe = torch.zeros(1, 3, self.input_height, self.input_width)
                features = self.conv_layers(probe)
        except RuntimeError as err:
            # torch reports kernel-larger-than-input and empty pooling output
            # as RuntimeError; surface it as a configuration error instead.
            raise ValueError(
                f"Input size {self.input_height}x{self.input_width} is too small "
                "for the encoder's convolution/pooling stack."
            ) from err

        flattened = features[0].numel()
        if flattened == 0:
            raise ValueError(
                f"Input size {self.input_height}x{self.input_width} yields an "
                "empty feature map."
            )
        return flattened

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the encoder.

        Args:
            x (torch.Tensor): Input tensor with shape (batch_size, 3, input_height, input_width).

        Returns:
            torch.Tensor: Latent representation tensor with shape (batch_size, latent_dim).
        """
        x = self.conv_layers(x)
        # Collapse channel and spatial axes into one feature axis per sample.
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)

class Decoder(nn.Module):
    """
    Convolutional decoder that maps a latent vector to a 3-channel image.

    The latent vector is projected to 128 features, viewed as a 128x1x1
    feature map, and grown by padded convolutions and nearest-neighbour
    upsampling to ``output_height`` x ``output_width``. A final sigmoid keeps
    every output value in [0, 1].

    Input shape: (batch_size, latent_dim)
    Output shape: (batch_size, 3, output_height, output_width)

    Raises:
        ValueError: If ``output_height`` or ``output_width`` is not positive.
    """
    def __init__(self, output_height: int = 28, output_width: int = 28, latent_dim: int = 2):
        super().__init__()
        if output_height < 1 or output_width < 1:
            raise ValueError(
                "Decoder output dimensions must be positive, got "
                f"{output_height}x{output_width}."
            )
        self.output_height = output_height
        self.output_width = output_width
        self.latent_dim = latent_dim

        # Initial Linear layer: latent vector -> 128 features (one per channel).
        self.fc = nn.Linear(latent_dim, 128)

        # Convolutional/upsampling blocks; input is (B, 128, 1, 1).
        self.decoder_blocks = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=2),  # 1 -> 3
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='nearest'),                                      # 3 -> 6

            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=1, padding=2),   # 6 -> 7
            nn.ReLU(inplace=True),
            # Resize to the requested output size instead of a fixed x4 factor,
            # so output_height/output_width are honoured. For the default
            # 28x28 target this is the same x4 enlargement of the 7x7 map.
            nn.Upsample(size=(output_height, output_width), mode='nearest'),

            # Padded 3x3 convolution preserves the spatial size.
            nn.Conv2d(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the decoder.

        Args:
            x (torch.Tensor): Latent representation tensor with shape (batch_size, latent_dim).

        Returns:
            torch.Tensor: Reconstructed output tensor with shape (batch_size, 3, output_height, output_width).
        """
        x = self.fc(x)
        # Reshape to spatial dimensions for convolutional layers
        # Shape: (batch_size, 128, 1, 1)
        x = x.view(x.size(0), 128, 1, 1)
        return self.decoder_blocks(x)

# --- Example Usage  ---
def main() -> None:
    """Build an encoder/decoder pair and check one round trip on random data.

    Prints both architectures and the tensor shapes at each stage, then
    asserts that the reconstruction has the same shape as the input batch.
    Dimension problems are reported with a hint instead of a traceback.
    """
    img_height = 28
    img_width = 28
    latent_dimensions = 2
    batch_size = 4

    # Create dummy input data
    dummy_input = torch.randn(batch_size, 3, img_height, img_width)
    expected_shape = torch.Size([batch_size, 3, img_height, img_width])

    try:
        # Instantiate networks
        encoder = Encoder(input_height=img_height, input_width=img_width, latent_dim=latent_dimensions)
        decoder = Decoder(output_height=img_height, output_width=img_width, latent_dim=latent_dimensions)

        print("--- Encoder Architecture ---")
        print(encoder)
        print("\n--- Decoder Architecture ---")
        print(decoder)

        # Test forward pass; only shapes are inspected, so skip autograd.
        with torch.no_grad():
            latent_vec = encoder(dummy_input)
            print(f"\nInput shape: {dummy_input.shape}")
            print(f"Latent vector shape: {latent_vec.shape}")

            reconstructed_output = decoder(latent_vec)
            print(f"Reconstructed output shape: {reconstructed_output.shape}")

        # Verify output size matches expected
        assert reconstructed_output.shape == expected_shape
        print("\nForward pass successful, output shape matches expected output.")

    except (ValueError, RuntimeError) as e:
        # torch reports shape/kernel-size problems as RuntimeError, so catching
        # ValueError alone would never show the hint for them.
        print(f"\nError initializing networks: {e}")
        print("Please ensure input/output dimensions are large enough for the pooling/upsampling.")


if __name__ == '__main__':
    main()