import torch
from torch import nn
from torch.nn import init

from .mlp_model import MLPDenoiser


class SimpleMLP(nn.Module):
    """Plain fully connected denoiser operating on flattened inputs.

    The network stacks ``num_layers - 1`` hidden blocks of
    Linear -> LayerNorm -> ReLU -> Dropout, followed by a final Linear layer
    that projects back to ``input_dim`` features. The output is then reshaped
    to ``(batch, *output_shape)``.

    Args:
        input_dim: Number of features per sample after flattening.
        hidden_dim: Width of every hidden layer.
        num_layers: Total number of Linear layers (hidden blocks + output).
            Values below 1 behave like 1, i.e. a single Linear layer.
        dropout: Dropout probability applied after each hidden activation.
        output_shape: Per-sample shape of the returned tensor. Its element
            count must equal ``input_dim``. The default ``(3, 32, 32)``
            matches the default ``input_dim=3072``.
    """

    def __init__(
        self,
        input_dim=3072,
        hidden_dim=512,
        num_layers=8,
        dropout=0.1,
        output_shape=(3, 32, 32),
    ):
        super().__init__()
        self.output_shape = tuple(output_shape)

        layers = []
        current_dim = input_dim

        # Hidden blocks: num_layers - 1 of them, so that together with the
        # final projection below there are num_layers Linear layers in total.
        for _ in range(num_layers - 1):
            layers.extend(
                [
                    nn.Linear(current_dim, hidden_dim),
                    nn.LayerNorm(hidden_dim),
                    nn.ReLU(),
                    nn.Dropout(dropout),
                ]
            )
            current_dim = hidden_dim

        # Final projection back to the flattened input size.
        layers.append(nn.Linear(current_dim, input_dim))

        self.network = nn.Sequential(*layers)
        self.initialize()

    def initialize(self):
        """Apply Xavier-uniform init to every Linear weight and zero its bias."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.xavier_uniform_(module.weight)
                # Guard so the method stays valid if a bias-free Linear is added.
                if module.bias is not None:
                    init.zeros_(module.bias)

    def forward(self, x):
        """Run the MLP on a batch.

        Args:
            x: Tensor whose first dimension is the batch. Its remaining
                dimensions must flatten to ``input_dim`` features.

        Returns:
            Tensor of shape ``(batch, *output_shape)``.
        """
        batch_size = x.shape[0]
        # reshape (not view) so non-contiguous inputs, e.g. permuted or
        # channels_last tensors, are accepted instead of raising.
        flat = x.reshape(batch_size, -1)
        return self.network(flat).reshape(batch_size, *self.output_shape)


class DiscretizedMLPDenoiser(nn.Module):
    """Piecewise-in-time denoiser with one independent ``MLPDenoiser`` per bucket.

    The timestep range ``[0, max_time]`` is split into buckets of width
    ``time_step``. The bucket starting at ``s`` covers ``[s, s + time_step)``
    and is served by the network stored under the key ``str(s)``. Timesteps
    outside ``[0, max_time]`` are clamped into that range first, so they use
    the first or last bucket.
    """

    def __init__(
        self,
        time_step=50,
        max_time=1000,
        hidden_dim=512,
        num_layers=4,
        dropout=0.1,
        growth_rate=512,
        num_blocks=6,
    ):
        """
        Args:
            time_step: Width of each time bucket. Must be a positive integer.
            max_time: Maximum timestep value (T). One network is created for
                every multiple of ``time_step`` in ``[0, max_time]``.
            hidden_dim: Currently unused; accepted for interface compatibility.
            num_layers: Currently unused; accepted for interface compatibility.
            dropout: Dropout rate forwarded to each ``MLPDenoiser``.
            growth_rate: Forwarded to each ``MLPDenoiser``.
            num_blocks: Forwarded to each ``MLPDenoiser``.

        Raises:
            ValueError: If ``time_step`` is not positive.
        """
        super().__init__()

        if time_step <= 0:
            raise ValueError(f"time_step must be positive, got {time_step}")

        self.time_step = time_step
        self.max_time = max_time
        self.input_dim = 3 * 32 * 32

        # One expert per bucket start. Registered in ascending time order so
        # parameter order is deterministic.
        self.mlps = nn.ModuleDict()
        for start in range(0, max_time + 1, time_step):
            self.mlps[str(start)] = MLPDenoiser(
                num_blocks=num_blocks, growth_rate=growth_rate, dropout=dropout
            )

    def get_interval_key(self, t):
        """Return the ``self.mlps`` key of the bucket containing timestep ``t``.

        Args:
            t: Python number or single-element tensor (int or float).

        Returns:
            The bucket start as a string of an integer, e.g. ``"50"``.
        """
        t = t.item() if torch.is_tensor(t) else t
        # Clamp before flooring so the result is always an existing bucket,
        # even when max_time is not a multiple of time_step.
        t = min(max(t, 0), self.max_time)
        # int() keeps keys like "50" rather than "50.0" for float timesteps.
        interval = int(t // self.time_step) * self.time_step
        return str(interval)

    def forward(self, x, t):
        """Denoise ``x`` using the expert that owns each sample's timestep.

        Args:
            x: Batched input. Row ``i`` is paired with ``t[i]`` when ``t`` is
                batched.
            t: 0-d tensor (one timestep for the whole batch) or a tensor whose
                first dimension indexes samples.

        Returns:
            Concatenated expert outputs, with row ``i`` corresponding to
            sample ``i``.
        """
        if t.dim() == 0:
            # Single timestep shared by the whole batch.
            return self.mlps[self.get_interval_key(t)](x, t)

        # Group sample indices by bucket so each expert runs once on its whole
        # sub-batch. Moving t to the CPU once avoids a device sync per element.
        buckets = {}
        for i, t_i in enumerate(t.detach().cpu()):
            buckets.setdefault(self.get_interval_key(t_i), []).append(i)

        outputs = []
        order = []
        for key, indices in buckets.items():
            idx = torch.tensor(indices, device=x.device)
            outputs.append(
                self.mlps[key](
                    x.index_select(0, idx), t.index_select(0, idx.to(t.device))
                )
            )
            order.append(idx)

        # Invert the grouping permutation so results line up with the input.
        order = torch.cat(order)
        return torch.cat(outputs, dim=0)[torch.argsort(order)]

    @classmethod
    def from_config(cls, config):
        """Create a model from a configuration dictionary.

        ``config["T"]`` is required. ``time_step`` (default 50), ``hidden_dim``,
        ``num_layers``, ``growth_rate``, ``num_blocks`` and ``dropout`` are
        optional.
        """
        return cls(
            time_step=config.get("time_step", 50),
            max_time=config["T"],
            hidden_dim=config.get("hidden_dim", 512),
            num_layers=config.get("num_layers", 6),
            growth_rate=config.get("growth_rate", 512),
            num_blocks=config.get("num_blocks", 6),
            dropout=config.get("dropout", 0.1),
        )
