import torch
import torch.nn as nn
from torch import Tensor

from models.KarrasUnet.nets import KarrasUnet

# Public API of this module: only the single-channel ``Unet`` wrapper.
__all__ = ["Unet"]

class Unet(nn.Module):
    """Single-channel adapter around ``KarrasUnet``.

    The wrapped network is built with ``in_channels=1`` and ``out_channels=1``.
    ``forward`` inserts a singleton channel axis before calling it and removes
    that axis from the result, so callers work with channel-less tensors.

    Args:
        input_size: Forwarded to ``KarrasUnet`` as ``image_size``.
        hidden_size: Forwarded to ``KarrasUnet`` as ``dim``.
        **kwargs: Accepted for call-site compatibility but ignored; nothing
            here is passed on to ``KarrasUnet``.
    """

    def __init__(self, input_size: int, hidden_size: int, **kwargs):
        super().__init__()
        # Fixed single-channel in/out; forward() relies on this to add and
        # drop the channel axis.
        unet_config = dict(
            image_size=input_size,
            dim=hidden_size,
            in_channels=1,
            out_channels=1,
        )
        self.cnn = KarrasUnet(**unet_config)

    def forward(self, x: Tensor, t: Tensor) -> Tensor:
        """Apply the wrapped network to a batch of channel-less inputs.

        Args:
            x: Expected shape ``(batch_size, input_size, input_size)``.
            t: Expected shape ``(batch_size,)``. Passed unchanged as the second
                positional argument of the wrapped network (presumably the
                diffusion time / noise level; confirm against ``KarrasUnet``).

        Returns:
            The network output with axis 1 squeezed out. Assuming
            ``KarrasUnet`` returns ``(batch_size, 1, H, W)``, the result is
            ``(batch_size, H, W)``. If axis 1 is not of size 1, squeeze
            leaves the tensor unchanged.
        """
        with_channel = x.unsqueeze(1)  # (B, H, W) -> (B, 1, H, W)
        prediction = self.cnn(with_channel, t)
        # Squeeze only axis 1 so a batch of size 1 keeps its batch axis.
        return prediction.squeeze(1)
