from typing import List, Sequence, Optional
import torch
from torch import nn

from jaxtyping import Float, Int, jaxtyped
from beartype import beartype

from src.Backbones.utils_backbone_unet import (
    SinusoidalPosEmb,
    exists,
    Encoder, Middle, Decoder, SplitEncoder, ToMiddle, SkipConnectionAdaptation, SplitDecoder, ToMiddleIdentity, default,
)


class Unet_Cold_Multi_Domain(nn.Module):
    """U-Net backbone operating jointly on several domains stacked along the channel axis.

    The network is assembled from project building blocks:

    * an encoder -- either a single ``Encoder`` over all ``sum(dim_per_dom)``
      input channels, or a ``SplitEncoder`` configured with ``dim_per_dom``;
    * a bridge to the bottleneck -- ``ToMiddle`` when the encoder is split,
      otherwise ``ToMiddleIdentity``;
    * a ``Middle`` bottleneck block;
    * a ``SkipConnectionAdaptation`` module applied to the encoder skip
      features before they reach the decoder;
    * a decoder -- either a single ``Decoder`` or a ``SplitDecoder``.

    Each domain gets its own integer time step (``time`` is ``(b, n_dom)``).
    Split components (``SplitEncoder``/``SplitDecoder``) are built with a time
    width of ``time_dim``; the other components use ``time_dim * n_dom``.
    """

    def __init__(
        self,
        dim_per_dom: List[int],
        encoder_split: bool,
        encoder_attention_per_block: Optional[Sequence[bool]],
        encoder_time_embedding_per_block: Optional[Sequence[bool]],

        pz_strat: str,
        z_mid_strat: str,

        middle_linear_attention: bool,
        middle_attention: bool,
        middle_time_embedding: bool,
        middle_nb_block_following: int,

        decoder_split: bool,
        decoder_attention_per_block: Optional[Sequence[bool]],
        decoder_time_embedding_per_block: Optional[Sequence[bool]],

        use_double_skip: bool,

        dim: int,
        time_dim: int,
        init_dim: int,
        out_dim: int,
        dim_mults: Sequence[int],
        channels: int,
        resnet_block_groups: int,
        with_time_emb: bool,
        use_convnext: bool,
        convnext_mult: int,
        residual: bool,
    ):
        """Build the encoder / bottleneck / decoder stack.

        Args:
            dim_per_dom: Channel count of each domain. ``len(dim_per_dom)`` is
                the number of domains; the non-split encoder sees
                ``sum(dim_per_dom)`` input channels.
            encoder_split: If True, use a ``SplitEncoder`` followed by
                ``ToMiddle``; otherwise a single ``Encoder`` followed by
                ``ToMiddleIdentity``.
            encoder_attention_per_block: Per-block attention flags for the
                encoder; ``None`` means ``[True] * len(dim_mults)``.
            encoder_time_embedding_per_block: Per-block time-embedding flags for
                the encoder; ``None`` means ``[True] * len(dim_mults)``.
            pz_strat: Forwarded to ``ToMiddle`` (only used when
                ``encoder_split`` is True).
            z_mid_strat: Forwarded to ``ToMiddle`` (only used when
                ``encoder_split`` is True).
            middle_linear_attention: Forwarded to ``Middle``.
            middle_attention: Forwarded to ``Middle`` as ``attention``.
            middle_time_embedding: Forwarded to ``Middle`` and ``ToMiddle`` as
                ``time_embedding``.
            middle_nb_block_following: Forwarded to ``ToMiddle`` as
                ``nb_block_following``.
            decoder_split: If True, use a ``SplitDecoder``; otherwise a single
                ``Decoder``.
            decoder_attention_per_block: Per-block attention flags for the
                decoder; ``None`` means ``[True] * len(dim_mults)``.
            decoder_time_embedding_per_block: Per-block time-embedding flags for
                the decoder; ``None`` means ``[True] * len(dim_mults)``.
            use_double_skip: Forwarded to the encoder and the decoder.
            dim: Base width. It is also the size of the sinusoidal time
                embedding, and the bottleneck width is ``dim * dim_mults[-1]``.
            time_dim: Per-domain output width of the time MLP.
            init_dim: Forwarded to the encoder and the decoder.
            out_dim: Forwarded to the non-split ``Decoder`` only.
            dim_mults: Width multipliers per resolution level.
            channels: Forwarded to the decoder (both variants).
            resnet_block_groups: Forwarded to all sub-networks.
            with_time_emb: If False, no time MLP is built and the sub-networks
                receive ``t=None`` in ``forward``.
            use_convnext: Forwarded to all sub-networks.
            convnext_mult: Forwarded to all sub-networks.
            residual: Forwarded to the decoder (both variants).
        """
        super().__init__()

        n_dom = len(dim_per_dom)
        full_dim = sum(dim_per_dom)
        mid_dim = dim * dim_mults[-1]

        # Missing per-block flag lists default to "enabled" for every level.
        encoder_attention_per_block = default(encoder_attention_per_block, [True] * len(dim_mults))
        encoder_time_embedding_per_block = default(encoder_time_embedding_per_block, [True] * len(dim_mults))
        decoder_attention_per_block = default(decoder_attention_per_block, [True] * len(dim_mults))
        decoder_time_embedding_per_block = default(decoder_time_embedding_per_block, [True] * len(dim_mults))

        # Split (per-domain) components see one domain's time embedding;
        # shared components see all domains' embeddings, hence the n_dom factor.
        single_time_dim = time_dim
        full_time_dim = single_time_dim * n_dom if exists(single_time_dim) else None
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, single_time_dim),
            nn.GELU(),
            nn.Linear(single_time_dim, single_time_dim),
        ) if with_time_emb else None

        if encoder_split:
            self.encoders = SplitEncoder(
                dim_per_dom=dim_per_dom,

                dim=dim,
                init_dim=init_dim,
                time_dim=single_time_dim,
                dim_mults=dim_mults,
                resnet_block_groups=resnet_block_groups,
                use_convnext=use_convnext,
                convnext_mult=convnext_mult,

                attention_per_block=encoder_attention_per_block,
                time_embedding_per_block=encoder_time_embedding_per_block,
                use_double_skip=use_double_skip,
            )
        else:
            self.encoders = Encoder(
                dim=dim,
                init_dim=init_dim,
                time_dim=full_time_dim,
                dim_mults=dim_mults,
                channels=full_dim,
                resnet_block_groups=resnet_block_groups,
                use_convnext=use_convnext,
                convnext_mult=convnext_mult,

                attention_per_block=encoder_attention_per_block,
                time_embedding_per_block=encoder_time_embedding_per_block,
                use_double_skip=use_double_skip,
            )

        if encoder_split:
            self.to_middle = ToMiddle(
                n_dom=n_dom,
                bottleneck_dim=mid_dim,
                pz_strat=pz_strat,
                z_mid_strat=z_mid_strat,
                time_dim=full_time_dim,
                resnet_block_groups=resnet_block_groups,
                use_convnext=use_convnext,
                convnext_mult=convnext_mult,
                time_embedding=middle_time_embedding,
                nb_block_following=middle_nb_block_following,
            )
            z_mid_dim = self.to_middle.z_mid_dim  # output dim of to_middle network
        else:
            self.to_middle = ToMiddleIdentity()
            z_mid_dim = mid_dim

        self.middle = Middle(
            mid_dim_in=z_mid_dim,
            mid_dim=mid_dim,
            time_dim=full_time_dim,
            resnet_block_groups=resnet_block_groups,

            use_convnext=use_convnext,
            convnext_mult=convnext_mult,
            attention=middle_attention,
            time_embedding=middle_time_embedding,
            middle_linear_attention=middle_linear_attention,
        )

        # Adapts the skip features to whichever encoder/decoder split
        # combination was chosen above.
        self.skip_connection_adaptation = SkipConnectionAdaptation(
            nb_dom=n_dom,
            multiple_encoders=encoder_split,
            multiple_decoders=decoder_split,
        )

        if decoder_split:
            self.decoders = SplitDecoder(
                dim_per_dom=dim_per_dom,
                attention_per_block=decoder_attention_per_block,
                time_embedding_per_block=decoder_time_embedding_per_block,
                dim=dim,
                init_dim=init_dim,
                time_dim=single_time_dim,
                dim_mults=dim_mults,
                channels=channels,
                resnet_block_groups=resnet_block_groups,
                residual=residual,
                use_convnext=use_convnext,
                convnext_mult=convnext_mult,
                use_double_skip=use_double_skip,
            )
        else:
            self.decoders = Decoder(
                dim=dim,
                out_dim=out_dim,
                init_dim=init_dim,
                time_dim=full_time_dim,
                dim_mults=dim_mults,
                channels=channels,
                resnet_block_groups=resnet_block_groups,
                residual=residual,
                use_convnext=use_convnext,
                convnext_mult=convnext_mult,
                attention_per_block=decoder_attention_per_block,
                time_embedding_per_block=decoder_time_embedding_per_block,
                use_double_skip=use_double_skip,
            )

    @jaxtyped
    @beartype
    def forward(
        self,
        x: Float[torch.Tensor, 'b c_n_dom h w'],
        time: Int[torch.Tensor, 'b n_dom'],
        control: Optional[list] = None,  # ControlNet features
    ):
        """Run the multi-domain U-Net.

        Args:
            x: Input with all domains concatenated along the channel axis.
            time: One integer time step per sample and per domain.
            control: Optional list of ControlNet features. The last element is
                added to the bottleneck output and the remaining elements are
                forwarded to the decoder. The caller's list is NOT modified.

        Returns:
            The output of the configured decoder.
        """
        # None when the model was built with with_time_emb=False.
        t = self.time_mlp(time) if exists(self.time_mlp) else None
        # t shape is [batch n_dom time_dim]

        zs, hs = self.encoders(x, t)
        z_middle = self.to_middle(zs, t)
        z = self.middle(z_middle, t)

        if control is not None:  # ControlNet
            # Pop from a shallow copy: popping from the caller's list would
            # consume it as a side effect and break reuse of the same control
            # features across several forward calls. The decoder still receives
            # exactly the remaining elements, as before.
            control = list(control)
            # Out-of-place add so the tensor produced by self.middle is not
            # modified in place (autograd rejects in-place edits of tensors it
            # saved for backward).
            z = z + control.pop()

        h_adapted = self.skip_connection_adaptation(hs)

        return self.decoders(z, h=h_adapted, t=t, control=control)
