import torch
from torch import nn
from torch.nn import init


class DenseBlock(nn.Module):
    """A single densely connected MLP stage conditioned on a time embedding.

    The incoming features and the time embedding are joined along dim 1,
    layer-normalised, projected to ``growth_rate`` features, passed through
    ReLU and finally through dropout.

    Args:
        in_dim: Width of the feature input ``x``.
        growth_rate: Number of new features produced by this block.
        time_dim: Width of the time embedding ``t_emb``.
        dropout: Dropout probability applied to the block output.
    """

    def __init__(self, in_dim, growth_rate, time_dim, dropout=0.1):
        super().__init__()
        fused_dim = in_dim + time_dim  # width once x and t_emb are joined

        # The attribute name and the layer order are kept stable so that
        # state_dict keys (layer.0 = LayerNorm, layer.1 = Linear) never move.
        self.layer = nn.Sequential(
            nn.LayerNorm(fused_dim),
            nn.Linear(fused_dim, growth_rate),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

    def forward(self, x, t_emb):
        """Compute ``growth_rate`` new features from ``x`` and ``t_emb``.

        Both inputs are expected to be 2-D ``(batch, features)`` tensors
        (assumed from the LayerNorm width; TODO confirm with callers). They
        are concatenated along dim 1 before the MLP is applied.
        """
        return self.layer(torch.cat((x, t_emb), dim=1))


class MLPDenoiser(nn.Module):
    """Densely connected MLP that maps a noisy image and a timestep to an image.

    The image is flattened, concatenated with a learned timestep embedding and
    projected to ``growth_rate`` features. Each ``DenseBlock`` then consumes
    the concatenation of *all* previously produced feature chunks (plus the
    time embedding) and contributes ``growth_rate`` new features. A final
    LayerNorm + Linear maps the full feature stack back to pixel space, and
    the result is reshaped to ``(batch, *image_shape)``.

    Args:
        num_blocks: Number of dense blocks.
        growth_rate: Width of the initial projection and of every block output.
        dropout: Dropout probability used inside each dense block.
        time_dim: Width of the timestep embedding (default 256).
        image_shape: ``(channels, height, width)`` of the images handled
            (default ``(3, 32, 32)``).
        num_timesteps: Value the raw timestep is divided by before it is
            embedded (default 1000).
    """

    def __init__(
        self,
        num_blocks=6,
        growth_rate=512,
        dropout=0.1,
        time_dim=256,
        image_shape=(3, 32, 32),
        num_timesteps=1000,
    ):
        super().__init__()

        self.image_shape = tuple(image_shape)
        # Flattened pixel count, e.g. 3*32*32 = 3072 for the defaults.
        self.input_dim = torch.Size(self.image_shape).numel()
        self.time_dim = time_dim
        self.num_timesteps = num_timesteps

        # Time embedding: scalar (normalised) timestep -> time_dim features.
        self.time_embed = nn.Sequential(
            nn.Linear(1, time_dim),
            nn.ReLU(),
            nn.Linear(time_dim, time_dim),
            nn.ReLU(),
            nn.Linear(time_dim, time_dim),
        )

        # Initial projection of [flattened image, time embedding].
        self.init_proj = nn.Sequential(
            nn.Linear(self.input_dim + self.time_dim, growth_rate),
            nn.LayerNorm(growth_rate),
            nn.ReLU(),
        )

        # Dense blocks: block i sees (i + 1) * growth_rate input features.
        self.blocks = nn.ModuleList()
        current_dim = growth_rate
        for _ in range(num_blocks):
            self.blocks.append(
                DenseBlock(
                    in_dim=current_dim,
                    growth_rate=growth_rate,
                    time_dim=self.time_dim,
                    dropout=dropout,
                )
            )
            current_dim += growth_rate

        # Final layer maps the whole feature stack back to pixel space.
        self.final_layer = nn.Sequential(
            nn.LayerNorm(current_dim), nn.Linear(current_dim, self.input_dim)
        )

        self.initialize()

    def initialize(self):
        """Xavier-uniform init for every Linear weight; zero its bias if present."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    init.zeros_(module.bias)

    def forward(self, x, t):
        """Predict an image of shape ``(batch, *image_shape)``.

        Args:
            x: Image batch whose trailing dimensions flatten to ``input_dim``.
                Non-contiguous tensors (e.g. permuted / channels_last) are
                accepted.
            t: Timestep(s): a Python number, a 0-d tensor, or a tensor with
                one entry per sample. A single timestep is broadcast across
                the batch. It is divided by ``num_timesteps`` before embedding.
        """
        batch_size = x.shape[0]
        # reshape() rather than view(): view() raises on non-contiguous input.
        x = x.reshape(batch_size, -1)

        # Normalise t to a float column on x's device, broadcasting a single
        # timestep to every sample so the concatenation below lines up.
        t = torch.as_tensor(t, device=x.device).float().reshape(-1, 1)
        if t.shape[0] == 1 and batch_size != 1:
            t = t.expand(batch_size, 1)
        t_emb = self.time_embed(t / self.num_timesteps)

        # Initial projection of the image concatenated with its time embedding.
        features = [self.init_proj(torch.cat([x, t_emb], dim=1))]

        # Each block sees every feature chunk produced so far (DenseNet-style).
        for block in self.blocks:
            features.append(block(torch.cat(features, dim=1), t_emb))

        out = self.final_layer(torch.cat(features, dim=1))

        # Linear output is contiguous, so view() is safe here.
        return out.view(batch_size, *self.image_shape)


if __name__ == "__main__":
    # Smoke test: build a model, push one random batch through it and report
    # the tensor shapes and the parameter count. Configuration is meant to
    # mirror the training setup.
    # NOTE(review): growth_rate=8192 with 25 dense blocks gives well over 20
    # billion parameters (~90 GB of float32 weights), which will not fit in
    # memory on typical hardware -- shrink these values for a quick check.
    config = {
        "T": 1000,
        "growth_rate": 512 * 16,
        "dropout": 0.15,
        "num_blocks": 25,
    }
    batch_size = 8

    print("Testing original MLPDenoiser:")
    model1 = MLPDenoiser(
        num_blocks=config["num_blocks"],
        growth_rate=config["growth_rate"],
        dropout=config["dropout"],
    )
    # eval() disables dropout and no_grad() skips building an autograd graph;
    # neither is needed to check shapes, and the graph would cost extra memory.
    model1.eval()
    x = torch.randn(batch_size, 3, 32, 32)
    t = torch.randint(config["T"], (batch_size,))
    with torch.no_grad():
        y = model1(x, t)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {y.shape}")
    total_params = sum(p.numel() for p in model1.parameters())
    print(f"Total parameters: {total_params:,}")
