import torch
import torch.nn as nn
import torch.optim as optim

import torch
import torch.nn as nn

import torch
import torch.nn as nn

class Autoencoder(nn.Module):
    """Convolutional autoencoder with a single-channel 16x16 bottleneck.

    For a ``(batch, 1, 64, 64)`` input the encoder produces a
    ``(batch, 1, 16, 16)`` latent map: two stride-2 4x4 convolutions halve
    H and W twice, then two shape-preserving 3x3 convolutions reduce the
    channels to 1. The decoder mirrors this back to ``(batch, 1, 64, 64)``.
    A final Sigmoid keeps reconstructed values in (0, 1).

    The encoder ends with a ReLU, so the latent map is non-negative.
    NOTE(review): with only one latent channel, if every pre-activation of
    the last encoder conv goes negative the code becomes all zeros and no
    gradient reaches the encoder. Worth monitoring during training.
    """

    def __init__(self):
        super().__init__()

        # Each row is (in_channels, out_channels, kernel_size, stride). Every
        # layer uses padding=1. Shapes assume a (batch, 1, 64, 64) input.
        encoder_spec = (
            (1, 64, 4, 2),     # -> (batch, 64, 32, 32)
            (64, 128, 4, 2),   # -> (batch, 128, 16, 16)
            (128, 64, 3, 1),   # -> (batch, 64, 16, 16)
            (64, 1, 3, 1),     # -> (batch, 1, 16, 16)
        )
        decoder_spec = (
            (1, 64, 3, 1),     # -> (batch, 64, 16, 16)
            (64, 128, 4, 2),   # -> (batch, 128, 32, 32)
            (128, 64, 4, 2),   # -> (batch, 64, 64, 64)
            (64, 1, 3, 1),     # -> (batch, 1, 64, 64)
        )

        # Encoder: every convolution, including the last, is followed by ReLU.
        encoder_layers = []
        for c_in, c_out, k, s in encoder_spec:
            encoder_layers.append(
                nn.Conv2d(c_in, c_out, kernel_size=k, stride=s, padding=1)
            )
            encoder_layers.append(nn.ReLU())
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: ReLU between transposed convolutions, and a Sigmoid after
        # the last one so output pixel values lie between 0 and 1.
        decoder_layers = []
        last = len(decoder_spec) - 1
        for i, (c_in, c_out, k, s) in enumerate(decoder_spec):
            decoder_layers.append(
                nn.ConvTranspose2d(c_in, c_out, kernel_size=k, stride=s, padding=1)
            )
            decoder_layers.append(nn.Sigmoid() if i == last else nn.ReLU())
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and decode the latent map back to input resolution."""
        # For a 64x64 input the intermediate latent is (batch, 1, 16, 16).
        return self.decoder(self.encoder(x))

import torch
import torch.nn as nn

class SmallAutoencoder(nn.Module):
    """Lightweight two-layer convolutional autoencoder.

    For a ``(batch, 1, 64, 64)`` input, two stride-2 3x3 convolutions shrink
    the image to a ``(batch, 1, 16, 16)`` latent map. Two stride-2 transposed
    convolutions with ``output_padding=1`` restore ``(batch, 1, 64, 64)``.
    The encoder output passes through a ReLU (non-negative latent). The
    reconstruction passes through a Sigmoid, so values lie in (0, 1).
    """

    def __init__(self):
        super().__init__()

        # Encoder: each stride-2, padding-1, 3x3 conv halves H and W.
        down_1 = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1)  # -> (batch, 32, 32, 32)
        down_2 = nn.Conv2d(32, 1, kernel_size=3, stride=2, padding=1)  # -> (batch, 1, 16, 16)
        self.encoder = nn.Sequential(down_1, nn.ReLU(), down_2, nn.ReLU())

        # Decoder: output_padding=1 makes each stride-2 transposed conv exactly
        # double H and W: (H - 1) * 2 - 2 + 3 + 1 == 2 * H.
        up_1 = nn.ConvTranspose2d(1, 32, kernel_size=3, stride=2, padding=1, output_padding=1)  # -> (batch, 32, 32, 32)
        up_2 = nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1)  # -> (batch, 1, 64, 64)
        self.decoder = nn.Sequential(up_1, nn.ReLU(), up_2, nn.Sigmoid())

    def forward(self, x):
        """Reconstruct ``x`` by encoding it and decoding the latent map."""
        latent = self.encoder(x)  # (batch, 1, 16, 16) for a 64x64 input
        return self.decoder(latent)

import torch
import torch.nn as nn

class Autoencoder32(nn.Module):
    """Convolutional autoencoder with a 32x32 latent map and latent noise.

    For a ``(batch, 1, 64, 64)`` input the encoder (a stride-2 3x3 conv, then
    a stride-1 3x3 conv) yields a ``(batch, 1, 32, 32)`` latent map squashed
    into (0, 1) by a Sigmoid. In training mode, zero-mean Gaussian noise is
    added to that latent before decoding. In eval mode the latent is decoded
    as-is. The decoder maps back to ``(batch, 1, 64, 64)`` and ends with a
    Sigmoid.
    """

    def __init__(self):
        super().__init__()

        encoder_layers = [
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),   # -> (batch, 32, 32, 32)
            nn.ReLU(),
            nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1),   # -> (batch, 1, 32, 32)
            nn.Sigmoid(),  # latent values in (0, 1) before any noise is added
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.ConvTranspose2d(1, 32, kernel_size=3, stride=1, padding=1),  # -> (batch, 32, 32, 32)
            nn.ReLU(),
            # output_padding=1 makes this stride-2 layer exactly double H and W.
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),  # -> (batch, 1, 64, 64)
            nn.Sigmoid(),  # output pixel values between 0 and 1
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x, noise_std=0.5):
        """Reconstruct ``x``, perturbing the latent code while training.

        Args:
            x: Input batch. The layer shape comments assume
                ``(batch, 1, 64, 64)``.
            noise_std: Scale applied to standard-normal noise that is added
                to the latent map in training mode. Ignored when
                ``self.training`` is False.

        Returns:
            The decoded reconstruction, with values in (0, 1) from the final
            Sigmoid.
        """
        latent = self.encoder(x)
        if not self.training:
            return self.decoder(latent)

        # Training only: draw from torch's global RNG and add the scaled
        # noise. The perturbed latent may leave the (0, 1) range.
        latent = latent + noise_std * torch.randn_like(latent)
        return self.decoder(latent)

# Initialize the autoencoder
# autoencoder = Autoencoder()

# Define loss function and optimizer
# criterion = nn.MSELoss()  # Use MSE loss for reconstruction
# optimizer = optim.Adam(autoencoder.parameters(), lr=0.001)

# Example input: a batch of 64x64 grayscale images with pixel values in [0, 1]
# batch_size = 16
# images = torch.rand(batch_size, 1, 64, 64)  # Simulated input batch; rand (not randn) matches the decoder's Sigmoid output range

# # Forward pass
# reconstructed = autoencoder(images)

# # Compute the loss
# loss = criterion(reconstructed, images)

# # Backpropagation and optimization
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
