from collections.abc import Sequence
from typing import Iterable

import torch
from jaxtyping import jaxtyped, Float, Int
from torch import nn
from beartype import beartype

from src.Backbones.utils_backbone_unet import (
    SinusoidalPosEmb,
    exists,
    Encoder, Middle, Decoder,
)


class Unet(nn.Module):
    """U-Net backbone built from an encoder, a middle stage and a decoder.

    The encoder returns its output together with intermediate activations
    ``h``, which are handed to the decoder (presumably skip connections; see
    ``Encoder``/``Decoder`` in ``utils_backbone_unet``). An optional
    sinusoidal time-embedding MLP produces a conditioning vector ``t`` that is
    passed to all three stages.

    Args:
        dim: Base channel width. The middle stage is built with
            ``dim * dim_mults[-1]`` channels.
        time_dim: Output width of the time-embedding MLP.
        init_dim: Forwarded to ``Encoder`` and ``Decoder``.
        out_dim: Forwarded to ``Decoder``; presumably the number of output
            channels (TODO confirm against ``Decoder``).
        dim_mults: Per-level channel multipliers. Any iterable is accepted;
            non-sequences (e.g. generators) are materialized to a tuple so they
            can be indexed and shared by the encoder and decoder.
        channels: Forwarded to ``Encoder`` and ``Decoder``; presumably the
            number of input channels.
        resnet_block_groups: Forwarded to all three stages.
        with_time_emb: If False, no time MLP is built and ``t=None`` is passed
            to every stage.
        residual: Forwarded to ``Decoder``.
        use_convnext: Forwarded to all three stages.
        convnext_mult: Forwarded to all three stages.

    Raises:
        ValueError: If ``dim_mults`` is empty.
    """

    def __init__(
        self,
        dim: int,
        time_dim: int,
        init_dim: int,
        out_dim: int,
        dim_mults: Iterable[int],
        channels: int,
        resnet_block_groups: int,
        with_time_emb: bool,
        residual: bool,
        use_convnext: bool,
        convnext_mult: int,
    ):
        super().__init__()

        # ``dim_mults`` is indexed below and handed to both the encoder and the
        # decoder. A one-shot iterable would fail on indexing and could be
        # exhausted by the encoder before the decoder sees it, so anything
        # that is not already a sequence is materialized. Lists/tuples pass
        # through untouched to keep existing callers' behavior identical.
        if not isinstance(dim_mults, Sequence):
            dim_mults = tuple(dim_mults)
        if not dim_mults:
            raise ValueError("dim_mults must contain at least one multiplier")

        # Channel width of the bottleneck: the deepest level's multiplier.
        mid_dim = dim * dim_mults[-1]

        # Sinusoidal embedding of the timestep followed by a 2-layer MLP;
        # omitted entirely when time conditioning is disabled.
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        ) if with_time_emb else None

        self.encoder = Encoder(
            dim=dim,
            init_dim=init_dim,
            time_dim=time_dim,
            dim_mults=dim_mults,
            channels=channels,
            resnet_block_groups=resnet_block_groups,
            use_convnext=use_convnext,
            convnext_mult=convnext_mult,
        )

        self.middle = Middle(
            mid_dim=mid_dim,
            time_dim=time_dim,
            resnet_block_groups=resnet_block_groups,
            use_convnext=use_convnext,
            convnext_mult=convnext_mult,
        )

        self.decoder = Decoder(
            dim=dim,
            init_dim=init_dim,
            out_dim=out_dim,
            time_dim=time_dim,
            dim_mults=dim_mults,
            channels=channels,
            resnet_block_groups=resnet_block_groups,
            residual=residual,
            use_convnext=use_convnext,
            convnext_mult=convnext_mult,
        )

    @jaxtyped
    @beartype
    def forward(
        self,
        x: Float[torch.Tensor, 'b c h w'],
        time: Int[torch.Tensor, 'b'],
    ) -> Float[torch.Tensor, 'b c_out h w']:
        """Run the U-Net on a batch.

        Args:
            x: Input batch.
            time: One integer timestep per batch element; only used when the
                model was built with ``with_time_emb=True``.

        Returns:
            The decoder output. Batch and spatial dimensions are checked to
            match the input. The channel axis gets its own name (``c_out``)
            because the decoder is built from ``out_dim``, which need not
            equal the input channel count.
        """
        t = self.time_mlp(time) if exists(self.time_mlp) else None

        # The encoder also returns intermediate activations ``h`` for the decoder.
        x, h = self.encoder(x, t)

        x = self.middle(x, t)

        return self.decoder(x, h=h, t=t)
