import torch
import torch.nn as nn
from torch.nn import functional as F

class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a residual shortcut.

    Both convolutions use padding 1, so the spatial size is preserved. The
    shortcut is a 1x1 convolution when the channel count changes and an
    identity otherwise, so it always yields ``out_channels`` channels.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        activation: callable applied after the first convolution and after
            the residual sum (defaults to SiLU, a.k.a. Swish).
    """

    def __init__(self, in_channels, out_channels, activation=F.silu):  # Swish = SiLU
        super().__init__()
        self.activation = activation
        same_size_3x3 = {"kernel_size": 3, "padding": 1}
        self.conv1 = nn.Conv2d(in_channels, out_channels, **same_size_3x3)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **same_size_3x3)
        if in_channels == out_channels:
            self.skip = nn.Identity()
        else:
            # 1x1 projection so the shortcut matches the main path's channels.
            self.skip = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Return ``activation(conv2(activation(conv1(x))) + skip(x))``."""
        shortcut = self.skip(x)
        hidden = self.conv2(self.activation(self.conv1(x)))
        return self.activation(hidden + shortcut)
    
class ResidualBlock1Conv(nn.Module):
    """Single 3x3 convolution with an identity shortcut.

    Padding 1 keeps the spatial size, and input/output channels are equal,
    so the input can be added directly to the convolution output.

    Args:
        channels: channels of both the input and output feature maps.
        activation: callable applied to the residual sum (defaults to SiLU).
    """

    def __init__(self, channels, activation=F.silu):
        super().__init__()
        self.activation = activation
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Return ``activation(conv(x) + x)``."""
        summed = self.conv(x)
        summed = summed + x
        return self.activation(summed)

class baseline(nn.Module):
    """Convolutional encoder/decoder building blocks shared by derived models.

    This class defines the encoder stack (``encoder`` + ``ffn``) and a
    ``decode`` method; how the encoder output is mapped to a latent vector is
    not defined here (presumably in a subclass — confirm against callers).

    Args:
        n_dim: stored as ``self.n_dim``; not used in this class.
        m_dim: input size of ``self.fc``, i.e. the size of the last dimension
            of the tensor passed to ``decode``.
        nu: stored as ``self.nu``; not used in this class.
        recon_sigma: stored as ``self.recon_sigma``; not used in this class.
        reg_weight: stored as ``self.reg_weight``; not used in this class.
        num_layers: stored as ``self.num_layers``; not used here (the
            encoder/decoder depths below are fixed).
        device: stored as ``self.device``; modules are not moved to it here.
    """

    # Indices of decoder blocks after which the feature map is upsampled 2x.
    # Three upsamples take the 4x4 seed map to 32x32.
    _UPSAMPLE_AFTER = frozenset({2, 5, 8})

    def __init__(self, n_dim=1, m_dim=1, nu=3, recon_sigma=1, reg_weight=1, num_layers=64, device='cpu'):
        super().__init__()
        self.model_name = None

        self.n_dim = n_dim
        self.m_dim = m_dim
        self.nu = nu
        self.recon_sigma = recon_sigma
        self.reg_weight = reg_weight
        self.num_layers = num_layers
        self.device = device

        # Decoder input projection: m_dim -> 32 * 4 * 4 features, reshaped to a
        # 32x4x4 map in ``decode``. Defined once (it was previously built twice)
        # and registered first so module/parameter order matches the original
        # layout, which optimizer state_dicts index by position.
        self.fc = nn.Linear(self.m_dim, 32 * 4 * 4)

        # Encoder: 7 residual blocks (1 -> 32 channels on the first), then
        # pooling to a fixed 4x4 spatial size regardless of input resolution.
        self.encoder = nn.Sequential(
            *[ResidualBlock(1 if i == 0 else 32, 32) for i in range(7)],
            nn.AdaptiveAvgPool2d((4, 4))
        )
        # Flattened 32x4x4 encoder output -> 256 features.
        self.ffn = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32 * 4 * 4, 256),
            nn.LayerNorm(256),
            nn.SiLU(),
        )

        # Decoder (CNN-based)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.final = nn.Conv2d(32, 1, kernel_size=3, padding=1)
        self.decoder_block = nn.ModuleList([
            ResidualBlock1Conv(32) for _ in range(10)
        ])

    def decode(self, z):
        """Decode vectors into single-channel 32x32 maps.

        Args:
            z: tensor whose last dimension is ``m_dim`` — presumably
                (batch, m_dim); TODO confirm with callers.

        Returns:
            Tensor of shape (N, 1, 32, 32), where N is the number of
            ``m_dim``-sized rows in ``z``.
        """
        x = F.silu(self.fc(z)).view(-1, 32, 4, 4)
        for i, block in enumerate(self.decoder_block):
            x = block(x)
            if i in self._UPSAMPLE_AFTER:
                x = self.upsample(x)
        return self.final(x)
    

