import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.unets.unet_2d import UNet2DModel


class UnetDDPM(nn.Module):
    """``UNet2DModel`` wrapper with input padding and optional conditioning.

    The wrapper

    * optionally concatenates a conditioning tensor ``y`` to the input along
      the channel dim (``low_condition``),
    * zero-pads the two trailing dims so that each is a multiple of
      ``2 ** len(channels)`` before calling the backbone, and crops the
      backbone output back to the original size afterwards,
    * optionally feeds the timestep to the backbone (``timestep_condition``),
    * optionally adds the input ``x`` back onto the output
      (``global_skip_connection``).
    """

    def __init__(
        self,
        in_channels: int,
        channels,
        layers_per_block: int,
        downblock,
        upblock,
        add_attention,
        attention_head_dim,
        low_condition: bool,
        timestep_condition: bool,
        global_skip_connection: bool,
    ):
        """Build the backbone.

        Args:
            in_channels: Number of channels of ``x``; also the number of
                output channels of the backbone.
            channels: Sized sequence forwarded as ``block_out_channels``; its
                length sets the number of down/up blocks and the padding
                multiple.
            layers_per_block: Forwarded to ``UNet2DModel``.
            downblock: Down block type, repeated once per entry of
                ``channels``.
            upblock: Up block type, repeated once per entry of ``channels``.
            add_attention: Forwarded to ``UNet2DModel``.
            attention_head_dim: Forwarded to ``UNet2DModel``.
            low_condition: If true, ``forward`` concatenates ``y`` to ``x``
                and the backbone is built with twice the input channels.
            timestep_condition: If true, ``forward`` passes ``t`` to the
                backbone; otherwise a constant ``0`` is passed.
            global_skip_connection: If true, ``forward`` adds ``x`` to the
                backbone output.
        """
        super().__init__()
        # Keep the *un-doubled* channel count: it is the size of `x` itself.
        self.in_channels = in_channels
        self.low_condition = low_condition
        self.timestep_condition = timestep_condition
        self.global_skip_connection = global_skip_connection
        # Spatial sizes are padded up to a multiple of this before the
        # backbone runs (see `padding`).
        self.divide_factor = 2 ** len(channels)

        num_blocks = len(channels)
        # The conditioning tensor is stacked on the channel dim, so the
        # backbone sees twice as many input channels when it is enabled.
        backbone_in_channels = in_channels * 2 if low_condition else in_channels

        self.backbone = UNet2DModel(
            in_channels=backbone_in_channels,
            out_channels=in_channels,
            block_out_channels=channels,
            layers_per_block=layers_per_block,
            down_block_types=(downblock,) * num_blocks,
            up_block_types=(upblock,) * num_blocks,
            add_attention=add_attention,
            attention_head_dim=attention_head_dim,
        )

    def padding(self, x: torch.Tensor):
        """Zero-pad dims 2 and 3 of a 4-D tensor up to multiples of ``divide_factor``.

        Padding is only appended at the end of each dim, so the original
        content stays at the origin and ``remove_padding`` can undo it with a
        plain slice.

        Args:
            x: 4-D tensor.

        Returns:
            ``(padded, W, H)`` where ``W`` is the original size of dim 2 and
            ``H`` the original size of dim 3 (the order expected by
            ``remove_padding``).
        """
        _, _, size_dim2, size_dim3 = x.shape
        factor = self.divide_factor

        # `-n % factor` is the distance from n up to the next multiple of
        # factor (0 when n is already a multiple).
        pad_dim2 = -size_dim2 % factor
        pad_dim3 = -size_dim3 % factor

        # F.pad takes pads starting from the LAST dim:
        # (dim3_before, dim3_after, dim2_before, dim2_after).
        padded = F.pad(x, (0, pad_dim3, 0, pad_dim2), mode="constant", value=0)
        return padded, size_dim2, size_dim3

    def remove_padding(self, x: torch.Tensor, W, H) -> torch.Tensor:
        """Crop ``x`` back to ``W`` entries on dim 2 and ``H`` entries on dim 3.

        The parameter names are kept for backward compatibility; they simply
        mirror the second and third values returned by ``padding``.
        """
        return x[:, :, :W, :H]

    def forward(self, x: torch.Tensor, t, y=None, **kwargs) -> torch.Tensor:
        """Run the backbone on ``x`` (optionally conditioned on ``y`` and ``t``).

        Args:
            x: 4-D input tensor.
            t: Timestep tensor; flattened and passed to the backbone when
                ``timestep_condition`` is enabled, ignored otherwise.
            y: Conditioning tensor concatenated to ``x`` on dim 1. Required
                when ``low_condition`` is enabled, ignored otherwise.
            **kwargs: Accepted and ignored.

        Returns:
            The backbone output cropped to the size of ``x`` on dims 2 and 3,
            plus ``x`` when ``global_skip_connection`` is enabled.

        Raises:
            TypeError: If ``low_condition`` is enabled and ``y`` is ``None``.
        """
        if self.low_condition:
            if y is None:
                # torch.cat would also raise TypeError here, but with a
                # message that does not point at the missing argument.
                raise TypeError(
                    "forward() requires `y` when `low_condition` is enabled"
                )
            x_in = torch.cat([x, y], dim=1)
        else:
            x_in = x

        # Pad to the nearest sizes divisible by self.divide_factor.
        x_in, W, H = self.padding(x_in)

        timestep = t.flatten() if self.timestep_condition else 0
        model_output = self.backbone(x_in, timestep=timestep).sample

        model_output = self.remove_padding(model_output, W, H)

        if self.global_skip_connection:
            # The backbone emits `self.in_channels` channels, i.e. the same
            # channel count as `x`, so the residual covers the whole output.
            # Do the add out-of-place: writing into a slice of the backbone's
            # output would mutate a view of a tensor that is part of the
            # autograd graph. Casting back keeps the output dtype unchanged.
            model_output = (model_output + x).to(model_output.dtype)

        return model_output
