import torch
import torch.nn as nn
from torch import Tensor

from models.KarrasUnet.nets import KarrasUnet

# Names exported by `from <this module> import *`.
__all__ = [
    'Unet'
]

class Unet(nn.Module):
    """Two-input wrapper around a ``KarrasUnet`` backbone.

    ``x`` and ``condition`` are stacked into a two-channel image, run through
    the backbone together with ``t``, and the channel axis of the result is
    squeezed away again.

    Args:
        input_size: Forwarded to ``KarrasUnet`` as ``image_size``.
        hidden_size: Forwarded to ``KarrasUnet`` as ``dim``.
        **kwargs: Accepted but not used by this class.
    """

    def __init__(self, input_size: int, hidden_size: int, **kwargs):
        super().__init__()
        # Two input channels (x and condition), one output channel.
        backbone_config = dict(
            image_size=input_size,
            dim=hidden_size,
            in_channels=2,
            out_channels=1,
        )
        self.cnn = KarrasUnet(**backbone_config)

    def forward(self, x: Tensor, t: Tensor, condition: Tensor) -> Tensor:
        """Run the backbone on ``x`` stacked with ``condition``.

        Expected shapes (as noted by the original author):
            x:         (batch_size, input_size, input_size)
            t:         (batch_size,)
            condition: (batch_size, input_size, input_size)

        Returns:
            The backbone output with dimension 1 squeezed out.
        """
        # New axis at dim=1: channel 0 holds x, channel 1 holds condition.
        two_channel = torch.stack((x, condition), dim=1)
        prediction = self.cnn(two_channel, t)
        return prediction.squeeze(dim=1)
