from diffusers import UNet2DModel, UNet1DModel
import torch
import torch.nn as nn


class DiffuserUNet2DModelforMNIST(nn.Module):
    """
    Small attention-free UNet sized for MNIST-like images.

    Wraps a three-level ``UNet2DModel`` whose levels use only plain
    ``DownBlock2D`` / ``UpBlock2D`` blocks (no self-attention) with
    32, 64 and 128 channels.

    Args:
        sample_size: image resolution passed to ``UNet2DModel``.
        n_channels: number of input channels; the output has the same count.
    """

    def __init__(self, sample_size, n_channels):
        super().__init__()

        widths = (32, 64, 128)  # channel width per level, kept small for MNIST
        n_levels = len(widths)

        unet_config = dict(
            sample_size=sample_size,
            in_channels=n_channels,
            out_channels=n_channels,  # output channels mirror the input
            block_out_channels=widths,
            down_block_types=("DownBlock2D",) * n_levels,
            up_block_types=("UpBlock2D",) * n_levels,
            norm_num_groups=4,  # normalization groups; 4 divides 32, 64 and 128
            downsample_padding=1,  # padding used by the downsampling convolutions
        )
        self.model = UNet2DModel(**unet_config)

    def forward(self, x, t):
        """Run the wrapped UNet on ``x`` at timestep(s) ``t`` and return its ``"sample"`` tensor."""
        prediction = self.model(x, t)
        return prediction["sample"]

class DiffuserUNet2DModel(nn.Module):
    """
    Configurable wrapper around ``diffusers.UNet2DModel``.

    With the default arguments the network has six levels
    (128, 128, 256, 256, 512, 512 channels), built from plain ResNet blocks.
    The exceptions are the fifth down block and the second up block, which
    add spatial self-attention (``AttnDownBlock2D`` / ``AttnUpBlock2D``).
    Every level uses two ResNet layers.

    Args:
        sample_size: target image resolution passed to ``UNet2DModel``.
        n_channels: number of input channels; the output has the same count.
        block_out_channels: output channel count of each UNet level.
        down_block_types: diffusers block class names for the encoder path.
        up_block_types: diffusers block class names for the decoder path.
    """

    def __init__(
        self,
        sample_size,
        n_channels,
        block_out_channels=(128, 128, 256, 256, 512, 512),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",  # ResNet downsampling + spatial self-attention
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "AttnUpBlock2D",  # ResNet upsampling + spatial self-attention
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    ):
        super().__init__()

        # Input and output channel counts are tied to the same value.
        unet_config = dict(
            in_channels=n_channels,
            out_channels=n_channels,
            sample_size=sample_size,
            layers_per_block=2,
            block_out_channels=block_out_channels,
            down_block_types=down_block_types,
            up_block_types=up_block_types,
        )
        self.model = UNet2DModel(**unet_config)

    def forward(self, x, t):
        """Apply the UNet to ``x`` at timestep(s) ``t`` and return the ``"sample"`` entry of its output."""
        prediction = self.model(x, t)
        return prediction["sample"]

import torch
import torch.nn as nn
from diffusers import UNet2DModel

class DiffuserUNet1DModel(nn.Module):
    """
    Apply a 2D diffusers UNet to flat per-sample feature vectors.

    Each input of shape (batch_size, n_channels) is broadcast into a constant
    (batch_size, n_channels, spatial_size, spatial_size) "image". That image
    is passed through a ``UNet2DModel``, and the UNet output is averaged over
    its two spatial dimensions, giving a tensor of shape
    (batch_size, n_channels).

    Despite the class name, the backbone is ``UNet2DModel``. With the default
    block types the network includes self-attention blocks
    (``AttnDownBlock2D`` / ``AttnUpBlock2D``).

    Args:
        sample_size: target resolution passed to ``UNet2DModel``.
        n_channels: number of input features / channels; the output matches.
        block_out_channels: output channel count of each UNet level.
        down_block_types: diffusers block class names for the encoder path.
        up_block_types: diffusers block class names for the decoder path.
        spatial_size: side length of the square map each feature vector is
            broadcast to. The default of 2 reproduces the original behavior.

    Raises:
        ValueError: if ``spatial_size`` is smaller than 1.
    """

    def __init__(self, sample_size, n_channels,
                 block_out_channels=(128, 128, 256, 256, 512, 512),
                 down_block_types=(
                     "DownBlock2D", "DownBlock2D", "DownBlock2D",
                     "DownBlock2D", "AttnDownBlock2D", "DownBlock2D"),
                 up_block_types=(
                     "UpBlock2D", "AttnUpBlock2D", "UpBlock2D",
                     "UpBlock2D", "UpBlock2D", "UpBlock2D"),
                 *,
                 spatial_size=2):
        super().__init__()

        if spatial_size < 1:
            raise ValueError(f"spatial_size must be >= 1, got {spatial_size}")
        # NOTE(review): the default configuration has six blocks, so the input
        # goes through several downsampling stages. A 2x2 map may be too small
        # for the skip connections to line up. Confirm spatial_size against
        # the chosen block configuration.
        self.spatial_size = spatial_size

        self.model = UNet2DModel(
            sample_size=sample_size,  # Target image resolution
            in_channels=n_channels,  # Number of input channels
            out_channels=n_channels,  # Number of output channels
            layers_per_block=2,  # Number of ResNet layers per block
            block_out_channels=block_out_channels,  # Output channels for each block
            down_block_types=down_block_types,
            up_block_types=up_block_types
        )

    def forward(self, x, t):
        """
        Args:
            x: Tensor of shape (batch_size, n_channels). It must be 2-D for
                the expand below to succeed.
            t: Timestep(s), forwarded unchanged to ``UNet2DModel``.
        Returns:
            Tensor of shape (batch_size, n_channels): the UNet output averaged
            over its two spatial dimensions.
        """
        size = self.spatial_size

        # (B, C) -> (B, C, 1, 1) -> (B, C, size, size). expand() returns a
        # broadcast view, so no data is copied.
        x = x.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, size, size)

        unet2d_output = self.model(x, t)["sample"]  # (B, C, size, size)

        # Average over the spatial dimensions to return to (B, C).
        return unet2d_output.mean(dim=(2, 3))
