import torch
import torch.nn as nn
import torch.nn.functional as F

class SELayer(nn.Module):
    """Squeeze-and-excitation style channel gate for 2-D or 3-D inputs.

    A bottleneck MLP (Linear -> ReLU -> Linear -> Sigmoid, all bias-free)
    produces one multiplicative gate per channel, which is applied to the
    input.

    Args:
        channel: Size of dimension 1 of the input (the gated dimension).
        reduction: Divisor used to derive the bottleneck width from
            ``channel``. The width is clamped to at least 1.

    Note:
        Without the clamp, ``channel < reduction`` gives a bottleneck width
        of 0: the MLP then has no parameters and its output is the constant
        ``sigmoid(0) == 0.5``. Clamping changes the ``fc`` weight shapes for
        such narrow layers, so checkpoints saved with the zero-width layout
        will not load into them.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        # Floor division yields 0 for channel < reduction; keep at least one
        # hidden unit so the gate remains a learnable function of the input.
        bottleneck = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Scale ``x`` by its per-channel gate.

        Args:
            x: Tensor with 2 dimensions ``(b, c)`` or 3 dimensions
                ``(b, c, length)``; ``c`` must equal ``channel``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is neither 2-D nor 3-D.
        """
        if x.dim() == 2:
            # Average pooling over a length-1 axis returns its input, so the
            # gate is computed directly from x.
            gate = self.fc(x)
            return x * gate
        if x.dim() == 3:
            # Squeeze: average over the last axis -> (b, c).
            squeezed = self.avg_pool(x).squeeze(-1)
            gate = self.fc(squeezed)
            # Broadcast the (b, c) gate across the last axis.
            return x * gate.unsqueeze(-1)
        raise ValueError(f"Unsupported input dimension: {x.dim()}")

class ResidualBlock(nn.Module):
    """Two-layer MLP with a skip connection, followed by an SE gate.

    ``net`` is Linear -> LayerNorm -> GELU -> Dropout -> Linear -> LayerNorm,
    all of width ``dim``. The block returns ``se(x + net(x))``.

    Args:
        dim: Feature width of the input and output.
        dropout_rate: Probability passed to the inner ``nn.Dropout``.
    """

    def __init__(self, dim, dropout_rate=0.3):
        super().__init__()
        stages = [
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
        ]
        self.net = nn.Sequential(*stages)
        self.se = SELayer(dim)

    def forward(self, x):
        """Return the SE-gated sum of ``x`` and ``net(x)``."""
        skip_sum = x + self.net(x)
        return self.se(skip_sum)

class ResidualBlock_NoSE(nn.Module):
    """Two-layer MLP with a skip connection and no SE gate.

    ``net`` is Linear -> LayerNorm -> GELU -> Dropout -> Linear -> LayerNorm,
    all of width ``dim``. The block returns ``x + net(x)``.

    Args:
        dim: Feature width of the input and output.
        dropout_rate: Probability passed to the inner ``nn.Dropout``.
    """

    def __init__(self, dim, dropout_rate=0.3):
        super().__init__()
        stages = [
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the MLP branch applied to ``x``."""
        branch = self.net(x)
        return x + branch

class ResidualBlock_NoResidual(nn.Module):
    """Two-layer MLP followed by an SE gate, without a skip connection.

    ``net`` is Linear -> LayerNorm -> GELU -> Dropout -> Linear -> LayerNorm,
    all of width ``dim``. The block returns ``se(net(x))``.

    Args:
        dim: Feature width of the input and output.
        dropout_rate: Probability passed to the inner ``nn.Dropout``.
    """

    def __init__(self, dim, dropout_rate=0.3):
        super().__init__()
        stages = [
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
        ]
        self.net = nn.Sequential(*stages)
        self.se = SELayer(dim)

    def forward(self, x):
        """Return the SE-gated output of the MLP branch."""
        branch = self.net(x)
        return self.se(branch)

class ResidualBlock_NoResidualNoSE(nn.Module):
    """Plain two-layer MLP: no skip connection and no SE gate.

    ``net`` is Linear -> LayerNorm -> GELU -> Dropout -> Linear -> LayerNorm,
    all of width ``dim``. The block returns ``net(x)``.

    Args:
        dim: Feature width of the input and output.
        dropout_rate: Probability passed to the inner ``nn.Dropout``.
    """

    def __init__(self, dim, dropout_rate=0.3):
        super().__init__()
        stages = [
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return the MLP branch applied to ``x``."""
        branch = self.net(x)
        return branch

class EnhancedCVAE(nn.Module):
    """Conditional encoder/decoder with a reparameterized Gaussian latent.

    The condition is embedded by ``condition_mapper`` and concatenated
    (along dim 1) with the data in ``encode`` and with the latent in
    ``decode``. Hidden stages use ``ResidualBlock``.

    Args:
        input_dim: Width of the data vector ``x`` and of the reconstruction.
        cond_dim: Width of the raw condition vector.
        latent_dim: Width of the latent vector ``z``.
        hidden_dim: Width of the hidden layers and the condition embedding.
    """

    def __init__(self, input_dim, cond_dim, latent_dim, hidden_dim):
        super().__init__()

        self.input_dim, self.cond_dim = input_dim, cond_dim
        self.latent_dim, self.hidden_dim = latent_dim, hidden_dim

        feature_dim = hidden_dim * 2

        # Raw condition -> hidden_dim embedding.
        mapper_stages = [
            nn.Linear(cond_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            ResidualBlock(hidden_dim),
            nn.Dropout(0.1),
        ]
        self.condition_mapper = nn.Sequential(*mapper_stages)

        # [x, condition embedding] -> feature vector of width 2 * hidden_dim.
        encoder_stages = [
            nn.Linear(input_dim + hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            ResidualBlock(hidden_dim),
            nn.Linear(hidden_dim, feature_dim),
        ]
        self.encoder = nn.Sequential(*encoder_stages)

        # Separate heads for the latent mean and log-variance.
        self.mu_proj = nn.Linear(feature_dim, latent_dim)
        self.logvar_proj = nn.Linear(feature_dim, latent_dim)

        # [z, condition embedding] -> reconstruction of width input_dim.
        decoder_stages = [
            nn.Linear(latent_dim + hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            ResidualBlock(hidden_dim),
            nn.Linear(hidden_dim, input_dim),
        ]
        self.decoder = nn.Sequential(*decoder_stages)

        self._init_weights()

    def _init_weights(self):
        """Xavier-normal init for every Linear weight; zero every Linear bias."""
        linear_layers = (
            mod for mod in self.modules() if isinstance(mod, nn.Linear)
        )
        for layer in linear_layers:
            nn.init.xavier_normal_(layer.weight)
            if layer.bias is None:
                continue
            nn.init.zeros_(layer.bias)

    def reparameterize(self, mu, log_var):
        """Return ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, 1)``."""
        sigma = (0.5 * log_var).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def encode(self, x, condition):
        """Return ``(mu, log_var)`` for ``x`` given ``condition``."""
        cond_embedding = self.condition_mapper(condition)
        hidden = self.encoder(torch.cat((x, cond_embedding), dim=1))
        return self.mu_proj(hidden), self.logvar_proj(hidden)

    def decode(self, z, condition):
        """Decode latent ``z`` given ``condition``.

        The condition is re-embedded here, independently of ``encode``; in
        training mode the dropout inside ``condition_mapper`` is therefore
        drawn separately for the two calls.
        """
        cond_embedding = self.condition_mapper(condition)
        return self.decoder(torch.cat((z, cond_embedding), dim=1))

    def forward(self, x, condition, epoch=0):
        """Return ``(recon_x, mu, log_var)``; ``epoch`` is accepted but unused."""
        mu, log_var = self.encode(x, condition)
        latent = self.reparameterize(mu, log_var)
        return self.decode(latent, condition), mu, log_var

    def sample(self, condition, device):
        """Decode standard-normal latents, one per row of ``condition``."""
        batch_size = condition.size(0)
        latent = torch.randn(batch_size, self.latent_dim).to(device)
        return self.decode(latent, condition)

class EnhancedCVAE_NoSE(nn.Module):
    """Conditional encoder/decoder variant built on ``ResidualBlock_NoSE``.

    The condition is embedded by ``condition_mapper`` and concatenated
    (along dim 1) with the data in ``encode`` and with the latent in
    ``decode``.

    Args:
        input_dim: Width of the data vector ``x`` and of the reconstruction.
        cond_dim: Width of the raw condition vector.
        latent_dim: Width of the latent vector ``z``.
        hidden_dim: Width of the hidden layers and the condition embedding.
    """

    def __init__(self, input_dim, cond_dim, latent_dim, hidden_dim):
        super().__init__()

        self.input_dim, self.cond_dim = input_dim, cond_dim
        self.latent_dim, self.hidden_dim = latent_dim, hidden_dim

        feature_dim = hidden_dim * 2

        # Raw condition -> hidden_dim embedding.
        mapper_stages = [
            nn.Linear(cond_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            ResidualBlock_NoSE(hidden_dim),
            nn.Dropout(0.1),
        ]
        self.condition_mapper = nn.Sequential(*mapper_stages)

        # [x, condition embedding] -> feature vector of width 2 * hidden_dim.
        encoder_stages = [
            nn.Linear(input_dim + hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            ResidualBlock_NoSE(hidden_dim),
            nn.Linear(hidden_dim, feature_dim),
        ]
        self.encoder = nn.Sequential(*encoder_stages)

        # Separate heads for the latent mean and log-variance.
        self.mu_proj = nn.Linear(feature_dim, latent_dim)
        self.logvar_proj = nn.Linear(feature_dim, latent_dim)

        # [z, condition embedding] -> reconstruction of width input_dim.
        decoder_stages = [
            nn.Linear(latent_dim + hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            ResidualBlock_NoSE(hidden_dim),
            nn.Linear(hidden_dim, input_dim),
        ]
        self.decoder = nn.Sequential(*decoder_stages)

        self._init_weights()

    def _init_weights(self):
        """Xavier-normal init for every Linear weight; zero every Linear bias."""
        linear_layers = (
            mod for mod in self.modules() if isinstance(mod, nn.Linear)
        )
        for layer in linear_layers:
            nn.init.xavier_normal_(layer.weight)
            if layer.bias is None:
                continue
            nn.init.zeros_(layer.bias)

    def reparameterize(self, mu, log_var):
        """Return ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, 1)``."""
        sigma = (0.5 * log_var).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def encode(self, x, condition):
        """Return ``(mu, log_var)`` for ``x`` given ``condition``."""
        cond_embedding = self.condition_mapper(condition)
        hidden = self.encoder(torch.cat((x, cond_embedding), dim=1))
        return self.mu_proj(hidden), self.logvar_proj(hidden)

    def decode(self, z, condition):
        """Decode latent ``z`` given ``condition``.

        The condition is re-embedded here, independently of ``encode``; in
        training mode the dropout inside ``condition_mapper`` is therefore
        drawn separately for the two calls.
        """
        cond_embedding = self.condition_mapper(condition)
        return self.decoder(torch.cat((z, cond_embedding), dim=1))

    def forward(self, x, condition, epoch=0):
        """Return ``(recon_x, mu, log_var)``; ``epoch`` is accepted but unused."""
        mu, log_var = self.encode(x, condition)
        latent = self.reparameterize(mu, log_var)
        return self.decode(latent, condition), mu, log_var

    def sample(self, condition, device):
        """Decode standard-normal latents, one per row of ``condition``."""
        batch_size = condition.size(0)
        latent = torch.randn(batch_size, self.latent_dim).to(device)
        return self.decode(latent, condition)

class EnhancedCVAE_NoResidual(nn.Module):
    """Conditional encoder/decoder variant built on ``ResidualBlock_NoResidual``.

    The condition is embedded by ``condition_mapper`` and concatenated
    (along dim 1) with the data in ``encode`` and with the latent in
    ``decode``.

    Args:
        input_dim: Width of the data vector ``x`` and of the reconstruction.
        cond_dim: Width of the raw condition vector.
        latent_dim: Width of the latent vector ``z``.
        hidden_dim: Width of the hidden layers and the condition embedding.
    """

    def __init__(self, input_dim, cond_dim, latent_dim, hidden_dim):
        super().__init__()

        self.input_dim, self.cond_dim = input_dim, cond_dim
        self.latent_dim, self.hidden_dim = latent_dim, hidden_dim

        feature_dim = hidden_dim * 2

        # Raw condition -> hidden_dim embedding.
        mapper_stages = [
            nn.Linear(cond_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            ResidualBlock_NoResidual(hidden_dim),
            nn.Dropout(0.1),
        ]
        self.condition_mapper = nn.Sequential(*mapper_stages)

        # [x, condition embedding] -> feature vector of width 2 * hidden_dim.
        encoder_stages = [
            nn.Linear(input_dim + hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            ResidualBlock_NoResidual(hidden_dim),
            nn.Linear(hidden_dim, feature_dim),
        ]
        self.encoder = nn.Sequential(*encoder_stages)

        # Separate heads for the latent mean and log-variance.
        self.mu_proj = nn.Linear(feature_dim, latent_dim)
        self.logvar_proj = nn.Linear(feature_dim, latent_dim)

        # [z, condition embedding] -> reconstruction of width input_dim.
        decoder_stages = [
            nn.Linear(latent_dim + hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            ResidualBlock_NoResidual(hidden_dim),
            nn.Linear(hidden_dim, input_dim),
        ]
        self.decoder = nn.Sequential(*decoder_stages)

        self._init_weights()

    def _init_weights(self):
        """Xavier-normal init for every Linear weight; zero every Linear bias."""
        linear_layers = (
            mod for mod in self.modules() if isinstance(mod, nn.Linear)
        )
        for layer in linear_layers:
            nn.init.xavier_normal_(layer.weight)
            if layer.bias is None:
                continue
            nn.init.zeros_(layer.bias)

    def reparameterize(self, mu, log_var):
        """Return ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, 1)``."""
        sigma = (0.5 * log_var).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def encode(self, x, condition):
        """Return ``(mu, log_var)`` for ``x`` given ``condition``."""
        cond_embedding = self.condition_mapper(condition)
        hidden = self.encoder(torch.cat((x, cond_embedding), dim=1))
        return self.mu_proj(hidden), self.logvar_proj(hidden)

    def decode(self, z, condition):
        """Decode latent ``z`` given ``condition``.

        The condition is re-embedded here, independently of ``encode``; in
        training mode the dropout inside ``condition_mapper`` is therefore
        drawn separately for the two calls.
        """
        cond_embedding = self.condition_mapper(condition)
        return self.decoder(torch.cat((z, cond_embedding), dim=1))

    def forward(self, x, condition, epoch=0):
        """Return ``(recon_x, mu, log_var)``; ``epoch`` is accepted but unused."""
        mu, log_var = self.encode(x, condition)
        latent = self.reparameterize(mu, log_var)
        return self.decode(latent, condition), mu, log_var

    def sample(self, condition, device):
        """Decode standard-normal latents, one per row of ``condition``."""
        batch_size = condition.size(0)
        latent = torch.randn(batch_size, self.latent_dim).to(device)
        return self.decode(latent, condition)

class EnhancedCVAE_NoResidualNoSE(nn.Module):
    """Conditional encoder/decoder variant built on ``ResidualBlock_NoResidualNoSE``.

    The condition is embedded by ``condition_mapper`` and concatenated
    (along dim 1) with the data in ``encode`` and with the latent in
    ``decode``.

    Args:
        input_dim: Width of the data vector ``x`` and of the reconstruction.
        cond_dim: Width of the raw condition vector.
        latent_dim: Width of the latent vector ``z``.
        hidden_dim: Width of the hidden layers and the condition embedding.
    """

    def __init__(self, input_dim, cond_dim, latent_dim, hidden_dim):
        super().__init__()

        self.input_dim, self.cond_dim = input_dim, cond_dim
        self.latent_dim, self.hidden_dim = latent_dim, hidden_dim

        feature_dim = hidden_dim * 2

        # Raw condition -> hidden_dim embedding.
        mapper_stages = [
            nn.Linear(cond_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            ResidualBlock_NoResidualNoSE(hidden_dim),
            nn.Dropout(0.1),
        ]
        self.condition_mapper = nn.Sequential(*mapper_stages)

        # [x, condition embedding] -> feature vector of width 2 * hidden_dim.
        encoder_stages = [
            nn.Linear(input_dim + hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            ResidualBlock_NoResidualNoSE(hidden_dim),
            nn.Linear(hidden_dim, feature_dim),
        ]
        self.encoder = nn.Sequential(*encoder_stages)

        # Separate heads for the latent mean and log-variance.
        self.mu_proj = nn.Linear(feature_dim, latent_dim)
        self.logvar_proj = nn.Linear(feature_dim, latent_dim)

        # [z, condition embedding] -> reconstruction of width input_dim.
        decoder_stages = [
            nn.Linear(latent_dim + hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            ResidualBlock_NoResidualNoSE(hidden_dim),
            nn.Linear(hidden_dim, input_dim),
        ]
        self.decoder = nn.Sequential(*decoder_stages)

        self._init_weights()

    def _init_weights(self):
        """Xavier-normal init for every Linear weight; zero every Linear bias."""
        linear_layers = (
            mod for mod in self.modules() if isinstance(mod, nn.Linear)
        )
        for layer in linear_layers:
            nn.init.xavier_normal_(layer.weight)
            if layer.bias is None:
                continue
            nn.init.zeros_(layer.bias)

    def reparameterize(self, mu, log_var):
        """Return ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, 1)``."""
        sigma = (0.5 * log_var).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def encode(self, x, condition):
        """Return ``(mu, log_var)`` for ``x`` given ``condition``."""
        cond_embedding = self.condition_mapper(condition)
        hidden = self.encoder(torch.cat((x, cond_embedding), dim=1))
        return self.mu_proj(hidden), self.logvar_proj(hidden)

    def decode(self, z, condition):
        """Decode latent ``z`` given ``condition``.

        The condition is re-embedded here, independently of ``encode``; in
        training mode the dropout inside ``condition_mapper`` is therefore
        drawn separately for the two calls.
        """
        cond_embedding = self.condition_mapper(condition)
        return self.decoder(torch.cat((z, cond_embedding), dim=1))

    def forward(self, x, condition, epoch=0):
        """Return ``(recon_x, mu, log_var)``; ``epoch`` is accepted but unused."""
        mu, log_var = self.encode(x, condition)
        latent = self.reparameterize(mu, log_var)
        return self.decode(latent, condition), mu, log_var

    def sample(self, condition, device):
        """Decode standard-normal latents, one per row of ``condition``."""
        batch_size = condition.size(0)
        latent = torch.randn(batch_size, self.latent_dim).to(device)
        return self.decode(latent, condition)
