from diffusers import UNet2DModel
import torch
import torch.nn as nn


class DiffuserUNet2DModelforMNIST(nn.Module):
    """Thin wrapper around diffusers' ``UNet2DModel`` for single-channel images.

    The wrapped network is always built with the same fixed configuration:
    1 input channel, 1 output channel, ``sample_size=32``, four resolution
    levels with ``block_out_channels=(32, 64, 128, 256)`` and
    ``norm_num_groups=8``. Presumably intended for MNIST digits padded to
    32x32 -- TODO confirm against the data pipeline.

    NOTE(review): this class was previously described as a "simple unet
    design without attention", but ``down_block_types``, ``up_block_types``
    and ``add_attention`` are not overridden below, so the diffusers
    defaults apply. In current diffusers releases those defaults include
    self-attention (``AttnDownBlock2D`` / ``AttnUpBlock2D`` blocks and an
    attention layer in the mid block). Verify against the installed
    diffusers version before relying on either description.

    Args:
        sample_size: Accepted for interface compatibility but currently
            ignored; the wrapped model always uses ``sample_size=32``.
        n_channels: Accepted for interface compatibility but currently
            ignored; input and output channels are fixed at 1.
    """

    def __init__(self, sample_size, n_channels):
        super().__init__()

        # sample_size / n_channels are intentionally NOT forwarded here:
        # the values below fix the architecture (and therefore the layout of
        # any saved state_dict). With four block_out_channels entries there
        # are presumably three downsampling stages, so inputs should have
        # spatial sizes divisible by 8 (e.g. 32, not raw 28) -- confirm.
        self.model = UNet2DModel(
            in_channels=1,
            out_channels=1,
            sample_size=32,
            block_out_channels=(32, 64, 128, 256),
            norm_num_groups=8,
        )

    def forward(self, x: torch.Tensor, t) -> torch.Tensor:
        """Run the wrapped UNet on ``x`` at diffusion timestep(s) ``t``.

        Args:
            x: Input batch; assumed (batch, 1, H, W) -- TODO confirm.
            t: Timestep(s), passed straight through to ``UNet2DModel``.

        Returns:
            The ``sample`` field of the UNet output, i.e. the network's
            prediction tensor.
        """
        # UNet2DModel returns a UNet2DOutput; ``.sample`` is the documented
        # accessor (equivalent to the dict-style ``["sample"]``).
        return self.model(x, t).sample
