import math
from copy import deepcopy
from functools import partial
from inspect import isfunction
from typing import Callable, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
from beartype import beartype
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
from jaxtyping import Float, Int, jaxtyped
from torch import einsum, nn


def exists(x):
    """Return True when *x* is anything other than ``None`` (falsy values count as existing)."""
    return not (x is None)


# region Network helpers
def default(val, d):
    """Return *val* unless it is ``None``; otherwise fall back to *d*.

    If *d* is a plain Python function it is called lazily to build the fallback,
    otherwise *d* itself is returned.
    """
    if val is None:
        return d() if isfunction(d) else d
    return val


class Residual(nn.Module):
    """Skip connection: add the wrapped callable's output back onto its first input."""

    def __init__(self, fn: Callable):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        identity = x
        transformed = self.fn(x, *args, **kwargs)
        return transformed + identity


def Upsample(dim: int, dim_out: int = None):
    """2x nearest-neighbour upsampling followed by a 3x3 conv.

    The conv maps ``dim`` channels to ``dim_out`` channels (``dim`` when omitted).
    """
    out_channels = default(dim_out, dim)
    resize = nn.Upsample(scale_factor=2, mode='nearest')
    smooth = nn.Conv2d(dim, out_channels, 3, padding=1)
    return nn.Sequential(resize, smooth)


def Downsample(dim: int, dim_out: int = None):
    """Halve the spatial resolution without strided convolutions or pooling.

    Each 2x2 spatial patch is folded into the channel axis (giving ``4 * dim``
    channels), then a 1x1 conv projects to ``dim_out`` (``dim`` when omitted).
    Height and width must therefore be even.
    """
    out_channels = default(dim_out, dim)
    fold_patches = Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2)
    project = nn.Conv2d(4 * dim, out_channels, 1)
    return nn.Sequential(fold_patches, project)


def UpsampleT(dim):
    """Learned 2x upsampling: transposed conv, kernel 4, stride 2, padding 1, ``dim`` -> ``dim``."""
    return nn.ConvTranspose2d(
        in_channels=dim,
        out_channels=dim,
        kernel_size=4,
        stride=2,
        padding=1,
    )


def DownsampleC(dim):
    """Learned 2x downsampling: strided conv, kernel 4, stride 2, padding 1, ``dim`` -> ``dim``."""
    return nn.Conv2d(
        in_channels=dim,
        out_channels=dim,
        kernel_size=4,
        stride=2,
        padding=1,
    )
# endregion


# region Position embeddings
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding of integer positions / timesteps.

    For every input value ``x`` the last output axis holds ``sin(x * f_k)``
    followed by ``cos(x * f_k)`` for ``half_dim = dim // 2`` frequencies
    ``f_k = 10000 ** (-k / max(half_dim - 1, 1))``. When ``dim`` is odd, a trailing
    zero feature is appended so the last axis always has exactly ``dim`` entries.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    @jaxtyped
    @beartype
    def forward(self, x: Int[torch.Tensor, 'b *other']) -> Float[torch.Tensor, 'b *other dim']:
        half_dim = self.dim // 2
        # max(..., 1) avoids a ZeroDivisionError when half_dim == 1 (dim in {2, 3}).
        exponent_step = math.log(10_000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -exponent_step)
        angles = x[..., None] * freqs[None, ...]
        emb = torch.cat((angles.sin(), angles.cos()), dim=-1)
        if self.dim % 2 == 1:
            # sin/cos pairs only fill 2 * (dim // 2) features; pad one zero so width == dim.
            emb = F.pad(emb, (0, 1))
        return emb
# endregion


# region ResNet block / ConvNeXT block
class WeightStandardizedConv2d(nn.Conv2d):
    """
    Conv2d whose kernel is standardized (zero mean, unit variance over each output
    channel's kernel) on every forward pass.

    https://arxiv.org/abs/1903.10520
    weight standardization purportedly works synergistically with group normalization
    """
    @jaxtyped
    @beartype
    def forward(self, x: Float[torch.Tensor, 'b c h w']) -> Float[torch.Tensor, 'b c1 h1 w1']:
        # 1e-5 for float32 inputs, a larger epsilon for any other dtype (e.g. half precision).
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3

        weight = self.weight
        mean = reduce(weight, "o ... -> o 1 1 1", "mean")
        # Population variance (unbiased=False) per output channel.
        # noinspection PyTypeChecker
        var = reduce(weight, "o ... -> o 1 1 1", partial(torch.var, unbiased=False))
        normalized_weight = (weight - mean) * (var + eps).rsqrt()

        # _conv_forward applies stride/padding/dilation/groups like F.conv2d, but also
        # honours a non-'zeros' padding_mode (reflect/replicate/circular), which a bare
        # F.conv2d(..., self.padding, ...) call silently ignores.
        return self._conv_forward(x, normalized_weight, self.bias)


class Block(nn.Module):
    """Weight-standardized 3x3 conv -> GroupNorm -> optional (scale, shift) modulation -> SiLU."""

    def __init__(self, dim: int, dim_out: int, groups: int = 8):
        super().__init__()
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    @jaxtyped
    @beartype
    def forward(
        self,
        x: Float[torch.Tensor, 'b c h w'],
        scale_shift: Optional[Tuple[
            Float[torch.Tensor, 'b dim 1 1'],
            Float[torch.Tensor, 'b dim 1 1']
        ]] = None,
    ) -> Float[torch.Tensor, 'b c1 h1 w1']:
        features = self.norm(self.proj(x))

        if scale_shift is not None:
            scale, shift = scale_shift
            # The scale is predicted as an offset around 1 so a zero prediction is the identity.
            features = features * (scale + 1) + shift

        return self.act(features)


class ResnetBlock(nn.Module):
    """https://arxiv.org/abs/1512.03385

    Two conv ``Block``s with a residual connection (1x1 conv when the widths differ).
    When ``time_emb_dim`` is given, the time embedding goes through SiLU + Linear to
    a per-channel (scale, shift) pair that modulates the first block.

    Args:
        dim: input channels.
        dim_out: output channels.
        time_emb_dim: width of the time embedding, or ``None`` for no conditioning.
        groups: number of GroupNorm groups in each ``Block``.
        multi_times: when True, the projection input is ``time_emb_dim * n_dom`` wide.
        n_dom: multiplier used with ``multi_times``.

    Raises:
        ValueError: if ``multi_times`` is True with a ``time_emb_dim`` but no ``n_dom``.
    """

    def __init__(
        self,
        dim: int,
        dim_out: int,
        *,
        time_emb_dim: Optional[int] = None,
        groups: int = 8,
        multi_times: bool = False,
        n_dom: Optional[int] = None,
    ):
        super().__init__()
        if exists(time_emb_dim) and multi_times:
            if not exists(n_dom):
                raise ValueError("n_dom must be provided when multi_times=True")
            time_emb_dim = time_emb_dim * n_dom
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) if exists(time_emb_dim) else None
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    @jaxtyped
    @beartype
    def forward(
        self,
        x: Float[torch.Tensor, 'b c h w'],
        time_emb: Optional[Float[torch.Tensor, 'b emb_dim']] = None,
    ) -> Float[torch.Tensor, 'b c1 h1 w1']:
        scale_shift = None

        # Conditioning is skipped when either the block or the caller has no time embedding.
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, "b c -> b c 1 1")
            scale_shift = time_emb.chunk(2, dim=1)

        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)


class ConvNextBlock(nn.Module):
    """https://arxiv.org/abs/2201.03545

    Depthwise 7x7 conv, optional additive time conditioning, then a
    (GroupNorm) -> 3x3 conv -> GELU -> GroupNorm -> 3x3 conv stack, plus a residual
    connection (1x1 conv when the widths differ).

    Args:
        dim: input channels.
        dim_out: output channels.
        time_emb_dim: width of the time embedding, or ``None`` for no conditioning.
        mult: channel expansion factor of the hidden conv layer.
        norm: whether to GroupNorm the input of the conv stack.
    """

    def __init__(self, dim: int, dim_out: int, *, time_emb_dim: Optional[int] = None, mult: int = 2, norm: bool = True):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )

        # Depthwise (groups=dim) spatial mixing.
        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        hidden_dim = dim_out * mult
        self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, hidden_dim, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, hidden_dim),
            nn.Conv2d(hidden_dim, dim_out, 3, padding=1),
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    @jaxtyped
    @beartype
    def forward(
        self,
        x: Float[torch.Tensor, 'b c h w'],
        time_emb: Optional[Float[torch.Tensor, 'b dim']] = None,
    ) -> Float[torch.Tensor, 'b c1 h1 w1']:
        h = self.ds_conv(x)

        # Conditioning is skipped when either the block or the caller has no time embedding.
        if exists(self.mlp) and exists(time_emb):
            condition = self.mlp(time_emb)
            h = h + rearrange(condition, "b c -> b c 1 1")

        h = self.net(h)
        return h + self.res_conv(x)
# endregion


# region Attention module
class Attention(nn.Module):
    """Full multi-head softmax self-attention over the spatial positions of a feature map."""

    def __init__(self, dim: int, heads: int = 4, dim_head: int = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = heads * dim_head
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    @jaxtyped
    @beartype
    def forward(self, x: Float[torch.Tensor, 'b c h w']) -> Float[torch.Tensor, 'b c1 h1 w1']:
        height, width = x.shape[-2:]
        q, k, v = (
            rearrange(part, "b (h c) x y -> b h c (x y)", h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )
        q = q * self.scale

        logits = einsum("b h d i, b h d j -> b h i j", q, k)
        # Subtract the (detached) per-row max before softmax for numerical stability.
        logits = logits - logits.amax(dim=-1, keepdim=True).detach()
        weights = logits.softmax(dim=-1)

        mixed = einsum("b h i j, b h d j -> b h i d", weights, v)
        mixed = rearrange(mixed, "b h (x y) d -> b (h d) x y", x=height, y=width)
        return self.to_out(mixed)


class LinearAttention(nn.Module):
    """Multi-head linear attention over spatial positions, followed by a 1x1 conv and GroupNorm.

    q is softmax-normalised over its feature axis and k over positions, so the
    (d x e) context is built once and its cost grows linearly with positions.
    """

    def __init__(self, dim: int, heads: int = 4, dim_head: int = 32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = heads * dim_head
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(
            nn.Conv2d(inner_dim, dim, 1),
            nn.GroupNorm(1, dim),
        )

    @jaxtyped
    @beartype
    def forward(self, x: Float[torch.Tensor, 'b c h w']) -> Float[torch.Tensor, 'b c1 h1 w1']:
        height, width = x.shape[-2:]
        q, k, v = (
            rearrange(part, 'b (h c) x y -> b h c (x y)', h=self.heads)
            for part in self.to_qkv(x).chunk(3, dim=1)
        )

        q = q.softmax(dim=-2) * self.scale
        k = k.softmax(dim=-1)

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
        attended = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        return self.to_out(
            rearrange(attended, 'b h c (x y) -> b (h c) x y', h=self.heads, x=height, y=width)
        )
# endregion


# region Group normalization
class PreNorm(nn.Module):
    """Normalise with a single-group GroupNorm (over all channels and positions), then call ``fn``."""

    def __init__(self, dim: int, fn: Callable):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    @jaxtyped
    @beartype
    def forward(self, x):
        normalized = self.norm(x)
        return self.fn(normalized)
# endregion


class Encoder(nn.Module):
    """U-Net style encoder (contracting path) for one input stream.

    A 1x1 ``init_conv`` maps ``channels`` to ``init_dim``. Then, for each entry of
    ``dim_mults``, a level with two conv blocks (``ResnetBlock`` or ``ConvNextBlock``),
    a pre-norm residual ``LinearAttention`` (``nn.Identity`` when disabled for that
    level) and a 2x ``Downsample``. On the last level the downsample is replaced by a
    resolution-preserving 3x3 conv.

    Args:
        dim: base width; level ``i`` outputs ``dim * dim_mults[i]`` channels.
        init_dim: channels produced by ``init_conv``.
        time_dim: width of the time embedding given to the blocks.
        dim_mults: per-level channel multipliers.
        channels: input channels.
        resnet_block_groups: GroupNorm groups for ``ResnetBlock``.
        use_convnext: use ``ConvNextBlock`` instead of ``ResnetBlock``.
        convnext_mult: expansion factor for ``ConvNextBlock``.
        attention_per_block: per-level attention flags (default: all True).
        time_embedding_per_block: per-level time-conditioning flags (default: all True).
        use_double_skip: also keep each level's first-block output as a skip tensor.

    Raises:
        ValueError: if a per-level flag sequence does not have ``len(dim_mults)`` entries.
    """
    def __init__(
        self,
        dim: int,
        init_dim: int,
        time_dim: int,
        dim_mults: Sequence[int],
        channels: int,
        resnet_block_groups: int,
        use_convnext: bool,
        convnext_mult: int,

        attention_per_block: Sequence[bool] = None,
        time_embedding_per_block: Sequence[bool] = None,
        use_double_skip: bool = True,
    ):
        super().__init__()
        attention_per_block = default(attention_per_block, [True] * len(dim_mults))
        time_embedding_per_block = default(time_embedding_per_block, [True] * len(dim_mults))
        self.use_double_skip = use_double_skip

        self.init_conv = nn.Conv2d(channels, init_dim, 1, padding=0)

        dims = [init_dim] + [dim * m for m in dim_mults]
        in_out = list(zip(dims[:-1], dims[1:]))

        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        self.downs = nn.ModuleList([])
        num_resolutions = len(in_out)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not (num_resolutions == len(attention_per_block) == len(time_embedding_per_block)):
            raise ValueError(
                f"attention_per_block ({len(attention_per_block)}) and time_embedding_per_block "
                f"({len(time_embedding_per_block)}) need one entry per level ({num_resolutions})"
            )
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            use_attention = attention_per_block[ind]
            use_time_emb = time_embedding_per_block[ind]
            attention_layer = LinearAttention(dim_out) if use_attention else nn.Identity()
            time_dim_layer = time_dim if use_time_emb else None

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim_layer),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim_layer),
                        Residual(PreNorm(dim_out, attention_layer)),
                        Downsample(dim=dim_out, dim_out=dim_out) if not is_last else nn.Conv2d(dim_out, dim_out, 3, padding=1),
                    ]
                )
            )

    @jaxtyped
    @beartype
    def forward(
        self,
        x: Float[torch.Tensor, 'b c h w'],
        t: Float[torch.Tensor, 'b *n_dom dim'],
    ) -> Tuple[
        Float[torch.Tensor, 'b ci hi wi'],        # embedding
        List[Float[torch.Tensor, 'b cj hj wj']],  # residual + skip connections list
    ]:
        """Encode ``x``; return the deepest features and the skip-tensor list.

        The list starts with (a clone of) the ``init_conv`` output. Then, per level in
        order, it holds the first block's output (only with ``use_double_skip``) and
        the output after attention (before downsampling).
        A 3-D ``t`` of shape (b, n, dim) is flattened to (b, n * dim).
        """
        if t is not None and t.dim() == 3:
            b, ndom, dim = t.shape
            t = t.reshape(b, ndom * dim)

        x = self.init_conv(x)
        h = [x.clone()]

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            if self.use_double_skip:
                h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        return x, h


class Decoder(nn.Module):
    """U-Net style decoder (expanding path) paired with :class:`Encoder`.

    There is one level per pair in ``in_out[1:]``, visited from deepest to shallowest.
    Each level:
    1. concatenates a skip tensor and applies a conv block;
    2. with ``use_double_skip``, concatenates a second skip and applies a second block
       (otherwise just the second block);
    3. applies a pre-norm residual ``LinearAttention`` (``nn.Identity`` when disabled);
    4. upsamples 2x.
    A final block, optionally fed the first skip tensor when ``residual=True``, and a
    1x1 conv then produce ``out_dim`` channels.

    Args:
        dim: base width; level ``i`` has ``dim * dim_mults[i]`` channels.
        out_dim: output channels (defaults to ``channels``).
        init_dim: width of the first entry of ``dims`` (defaults to ``dim``).
        time_dim: width of the time embedding given to the blocks.
        dim_mults: per-level channel multipliers.
        channels: fallback for ``out_dim``.
        resnet_block_groups: GroupNorm groups for ``ResnetBlock``.
        residual: concatenate the first skip tensor before the final block.
        use_convnext: use ``ConvNextBlock`` instead of ``ResnetBlock``.
        convnext_mult: expansion factor for ``ConvNextBlock``.
        attention_per_block: per-level attention flags (default: all True).
        time_embedding_per_block: per-level time-conditioning flags (default: all True).
        use_double_skip: consume two skip tensors per level instead of one.

    Raises:
        ValueError: if a per-level flag sequence does not have ``len(dim_mults)`` entries.
    """
    def __init__(
        self,
        dim: int,
        out_dim: int,
        init_dim: int,
        time_dim: int,
        dim_mults: Sequence[int],
        channels: int,
        resnet_block_groups: int,
        residual: bool,
        use_convnext: bool,
        convnext_mult: int,

        attention_per_block: Sequence[bool] = None,
        time_embedding_per_block: Sequence[bool] = None,
        use_double_skip: bool = True,
    ):
        super().__init__()
        self.use_double_skip = use_double_skip
        attention_per_block = default(attention_per_block, [True] * len(dim_mults))
        time_embedding_per_block = default(time_embedding_per_block, [True] * len(dim_mults))

        self.residual = residual

        init_dim = default(init_dim, dim)
        out_dim = default(out_dim, channels)
        dims = [init_dim] + [dim * m for m in dim_mults]
        in_out = list(zip(dims[:-1], dims[1:]))

        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not (num_resolutions == len(attention_per_block) == len(time_embedding_per_block)):
            raise ValueError(
                f"attention_per_block ({len(attention_per_block)}) and time_embedding_per_block "
                f"({len(time_embedding_per_block)}) need one entry per level ({num_resolutions})"
            )

        # The loop covers in_out[1:], so `ind` never reaches num_resolutions - 1: every
        # level upsamples, and only the first num_resolutions - 1 per-level flags are read
        # (index 0 = deepest decoder level).
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            use_attention = attention_per_block[ind]
            use_time_emb = time_embedding_per_block[ind]
            attention_layer = LinearAttention(dim_in) if use_attention else nn.Identity()
            time_dim_layer = time_dim if use_time_emb else None

            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim=dim_out + dim_out, dim_out=dim_in, time_emb_dim=time_dim_layer),
                        block_klass(dim=dim_in + (dim_out if self.use_double_skip else 0), dim_out=dim_in, time_emb_dim=time_dim_layer),
                        Residual(PreNorm(dim_in, attention_layer)),
                        Upsample(dim=dim_in, dim_out=dim_in) if not is_last else nn.Conv2d(dim_in, dim_in, 3, padding=1),
                    ]
                )
            )

        self.final_res_block = block_klass(dim * 2 if self.residual else dim, dim, time_emb_dim=time_dim)
        self.final_conv = nn.Conv2d(dim, out_dim, 1)

    @jaxtyped
    @beartype
    def forward(
        self,
        x: Float[torch.Tensor, 'b c h w'],
        h: List[Float[torch.Tensor, 'b cl hl wl']],
        t: Float[torch.Tensor, 'b *n_dom dim'],
        control: Optional[list] = None,
    ) -> Float[torch.Tensor, 'b ci hi wi']:
        """Decode ``x`` using the skip tensors in ``h``.

        The first entry of ``h`` is set aside for the optional final residual; the
        remaining entries are consumed from the end of the list. ``control``, when
        given, holds tensors added element-wise to those skips, consumed from its end
        in lockstep. Neither input list is modified.
        A 3-D ``t`` of shape (b, n, dim) is flattened to (b, n * dim).
        """
        if t is not None and t.dim() == 3:
            b, ndom, dim = t.shape
            t = t.reshape(b, ndom * dim)

        # Shallow copies: popping must not empty the caller's lists (e.g. when the same
        # skips are decoded more than once).
        h = list(h)
        control = list(control) if control is not None else None

        def next_skip():
            # Last remaining skip, plus its matching control tensor when provided.
            skip = h.pop()
            return skip if control is None else skip + control.pop()

        r = h.pop(0)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat([x, next_skip()], dim=1)
            x = block1(x, t)

            if self.use_double_skip:
                x = torch.cat([x, next_skip()], dim=1)

            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        if self.residual:
            x = torch.cat([x, r], dim=1)

        x = self.final_res_block(x, t)
        x = self.final_conv(x)

        return x


class Middle(nn.Module):
    """Bottleneck stage: conv block -> pre-norm residual attention -> conv block.

    ``mid_dim_in`` lets the first block accept a different channel count
    (defaults to ``mid_dim``). ``attention`` and ``time_embedding`` default to True.
    ``middle_linear_attention`` selects ``LinearAttention`` instead of full ``Attention``.
    """

    def __init__(
        self,
        mid_dim: int,
        time_dim: int,
        resnet_block_groups: int,
        use_convnext: bool,
        convnext_mult: int,

        attention: bool = None,
        time_embedding: bool = None,
        mid_dim_in: int = None,  # used for our approach
        middle_linear_attention: bool = False,
    ):
        super().__init__()
        mid_dim_in = default(mid_dim_in, mid_dim)

        with_attention = default(attention, True)
        with_time_embedding = default(time_embedding, True)
        # Build the attention module before the blocks (keeps parameter-init order).
        if not with_attention:
            attention_layer = nn.Identity()
        elif middle_linear_attention:
            attention_layer = LinearAttention(mid_dim)
        else:
            attention_layer = Attention(mid_dim)
        block_time_dim = time_dim if with_time_embedding else None

        block_klass = (
            partial(ConvNextBlock, mult=convnext_mult)
            if use_convnext
            else partial(ResnetBlock, groups=resnet_block_groups)
        )

        self.mid_block1 = block_klass(mid_dim_in, mid_dim, time_emb_dim=block_time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, attention_layer))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=block_time_dim)

    @jaxtyped
    @beartype
    def forward(
        self,
        x: Float[torch.Tensor, 'b c h w'],
        t: Float[torch.Tensor, 'b *n_dom dim'],
    ) -> Float[torch.Tensor, 'b ci hi wi']:
        if t is not None and t.dim() == 3:
            # (b, n_dom, dim) -> (b, n_dom * dim)
            t = t.flatten(start_dim=1)

        hidden = self.mid_block1(x, t)
        hidden = self.mid_attn(hidden)
        return self.mid_block2(hidden, t)


class SplitEncoder(nn.Module):
    """One :class:`Encoder` per domain; the input's channel axis is split by ``dim_per_dom``.

    ``forward`` returns one embedding and one skip-tensor list per domain, in domain order.
    """

    def __init__(
        self,
        dim_per_dom: List[int],
        attention_per_block: Sequence[bool],
        time_embedding_per_block: Sequence[bool],
        use_double_skip: bool,

        dim: int,
        init_dim: int,
        time_dim: int,
        dim_mults: Sequence[int],
        resnet_block_groups: int,
        use_convnext: bool,
        convnext_mult: int,
    ):
        super().__init__()
        self.dim_per_dom = dim_per_dom

        self.encoders = nn.ModuleList([
            Encoder(
                dim=dim,
                init_dim=init_dim,
                time_dim=time_dim,
                dim_mults=dim_mults,
                channels=dom_channels,
                resnet_block_groups=resnet_block_groups,
                use_convnext=use_convnext,
                convnext_mult=convnext_mult,

                attention_per_block=attention_per_block,
                time_embedding_per_block=time_embedding_per_block,
                use_double_skip=use_double_skip,
            )
            for dom_channels in dim_per_dom
        ])

    @jaxtyped
    @beartype
    def forward(
        self,
        x: Float[torch.Tensor, 'b c_n_dom h w'],
        t: Float[torch.Tensor, 'b *n_dom dim'],
    ) -> Tuple[
        List[Float[torch.Tensor, 'b ci hi wi']],        # embeddings, #List = n_dom
        List[List[Float[torch.Tensor, 'b cj hj wj']]],  # residual + skip connections list, #List = n_dom, #SubList = nb skips
    ]:
        if t.dim() == 2:
            # A single shared embedding: repeat it for every domain -> (b, n_dom, dim).
            t = t.unsqueeze(1).repeat(1, len(self.encoders), 1)

        x_chunks = torch.split(x, self.dim_per_dom, dim=1)
        t_chunks = torch.split(t, 1, dim=1)

        outputs = [
            encoder(x_chunk, t_chunk)
            for encoder, x_chunk, t_chunk in zip(self.encoders, x_chunks, t_chunks)
        ]
        z = [embedding for embedding, _ in outputs]
        hs = [skips for _, skips in outputs]
        return z, hs


class SplitDecoder(nn.Module):
    """One :class:`Decoder` per domain, all decoding the same bottleneck ``z``.

    Decoder ``i`` produces ``dim_per_dom[i]`` channels. The per-domain outputs are
    concatenated along the channel axis.
    """
    def __init__(
        self,
        dim_per_dom: List[int],
        attention_per_block: Sequence[bool],
        time_embedding_per_block: Sequence[bool],
        use_double_skip: bool,

        dim: int,
        init_dim: int,
        time_dim: int,
        dim_mults: Sequence[int],
        channels: int,
        resnet_block_groups: int,
        residual: bool,
        use_convnext: bool,
        convnext_mult: int,
    ):
        super().__init__()
        self.dim_per_dom = dim_per_dom

        self.decoders = nn.ModuleList([])
        for dim_dom in dim_per_dom:
            self.decoders.append(Decoder(
                dim=dim,
                out_dim=dim_dom,
                init_dim=init_dim,
                time_dim=time_dim,
                dim_mults=dim_mults,
                channels=channels,
                resnet_block_groups=resnet_block_groups,
                residual=residual,
                use_convnext=use_convnext,
                convnext_mult=convnext_mult,

                attention_per_block=attention_per_block,
                time_embedding_per_block=time_embedding_per_block,
                use_double_skip=use_double_skip,
            ))

    @jaxtyped
    @beartype
    def forward(
        self,
        z: Float[torch.Tensor, 'b c h w'],                  # embedding after bottleneck
        h: List[List[Float[torch.Tensor, 'b cj hj wj']]],  # residual + skip connections list, #List = n_dom, #SubList = nb skips
        t: Float[torch.Tensor, 'b *n_dom dim'],              # time embedding, per domain or shared
        control: Optional[list] = None,                     # per-domain control lists, or None
    ) -> Float[torch.Tensor, 'b c_n_dom hi wi']:
        n_dom = len(self.decoders)
        if t.dim() == 2:
            # A single shared embedding: repeat it for every domain (as SplitEncoder does).
            t = t.unsqueeze(1).repeat(1, n_dom, 1)
        t_per_dom = torch.split(tensor=t, split_size_or_sections=1, dim=1)

        # zip() cannot iterate None, so expand the default into "no control" per domain.
        if control is None:
            control = [None] * n_dom

        x_dom = [
            decoder(x=z, h=h_dom, t=t_dom, control=dom_control)
            for decoder, t_dom, h_dom, dom_control in zip(self.decoders, t_per_dom, h, control)
        ]

        return torch.cat(x_dom, dim=1)


class ToMiddle(nn.Module):
    """
    Class used when we have one encoder per domain.
    We have to prepare the Zs for the middle block
    #Zs = [z_dom1, z_dom2, ...] -> z_mid

    1) pre-process: Zs -> pz, combine the per-domain latents into one tensor (``pz_strat``)
    2) to-middle: pz -> z_mid, network to further process pz (``z_mid_strat``)
    3) ``nb_block_following`` extra (block, attention, block) stages, run at z_mid's width

    Raises:
        NotImplementedError: for an unknown ``pz_strat`` or ``z_mid_strat``.
    """
    def __init__(
        self,
        n_dom: int,
        bottleneck_dim: int,
        pz_strat: str,  # how the latent space should be pre-processed into pz
        z_mid_strat: str,  # how pz should be pre-processed into z_mid

        time_dim: int,
        resnet_block_groups: int,
        use_convnext: bool,
        convnext_mult: int,
        time_embedding: bool,
        nb_block_following: int,
    ):
        super().__init__()
        self.n_dom = n_dom
        self.pz_strat = pz_strat
        self.z_mid_strat = z_mid_strat

        time_dim = time_dim if time_embedding else None
        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # Zs -> pz
        pz_dim = self.get_pz_dim(one_dom_dim=bottleneck_dim)

        # pz -> z_mid
        if z_mid_strat == 'identity':
            z_mid_dim = pz_dim
            self.pz_to_z_mid_network = ToMiddleIdentity()
        elif z_mid_strat == 'reduce_to_single':
            z_mid_dim = bottleneck_dim
            self.pz_to_z_mid_network = block_klass(pz_dim, bottleneck_dim, time_emb_dim=time_dim)
        else:
            raise NotImplementedError
        self.z_mid_dim = z_mid_dim

        # The following stages consume z_mid, so they must be built at z_mid_dim (which is
        # bottleneck_dim for 'reduce_to_single' and pz_dim for 'identity').
        block_following = []
        for _ in range(nb_block_following):
            block_following.append(
                nn.ModuleList([
                    block_klass(z_mid_dim, z_mid_dim, time_emb_dim=time_dim),
                    Residual(PreNorm(z_mid_dim, Attention(z_mid_dim))),
                    block_klass(z_mid_dim, z_mid_dim, time_emb_dim=time_dim),
                ])
            )
        self.block_following = nn.ModuleList(block_following)

    @jaxtyped
    @beartype
    def forward(
        self,
        encoded_x: List[Float[torch.Tensor, 'b ci hi wi']],  # embeddings, #List = n_dom
        t: Float[torch.Tensor, 'b *n_dom dim'],
    ) -> Float[torch.Tensor, 'b c_mid h_mid w_mid']:
        """Combine per-domain latents into z_mid (a 3-D ``t`` is flattened to (b, n * dim))."""
        if t is not None and t.dim() == 3:
            b, ndom, dim = t.shape
            t = t.reshape(b, ndom * dim)

        # Zs -> pz
        pz = self.zs_to_pz(encoded_x)

        # pz -> z_mid
        z_mid = self.pz_to_z_mid_network(pz, t)

        for block1, attn, block2 in self.block_following:
            z_mid = block1(z_mid, t)
            z_mid = attn(z_mid)
            z_mid = block2(z_mid, t)

        return z_mid

    @beartype
    def get_pz_dim(self, one_dom_dim: int) -> int:
        """Channel count of pz given the channel count of one domain's latent."""
        if self.pz_strat == 'cat':
            return one_dom_dim * self.n_dom
        elif self.pz_strat == 'sum':
            return one_dom_dim
        elif self.pz_strat == 'mean':
            return one_dom_dim
        elif self.pz_strat in ['m0', 'm1', 'm2']:  # case when only using first domain for testing purpose
            return one_dom_dim
        elif self.pz_strat in ['m01', 'm02']:
            return one_dom_dim * 2
        else:
            raise NotImplementedError

    @beartype
    def zs_to_pz(
        self,
        encoded_x: List[torch.Tensor],
    ) -> torch.Tensor:
        """Combine the per-domain latents according to ``pz_strat``.

        'cat' / 'm01' / 'm02' concatenate along channels; 'sum' / 'mean' reduce over
        domains; 'm0' / 'm1' / 'm2' select a single domain.
        """
        if self.pz_strat == 'cat':
            x = torch.cat(encoded_x, dim=1)
        elif self.pz_strat == 'sum':
            x = torch.stack(encoded_x, dim=1).sum(dim=1)
        elif self.pz_strat == 'mean':
            x = torch.stack(encoded_x, dim=1).mean(dim=1)
        elif self.pz_strat == 'm0':
            x = encoded_x[0]
        elif self.pz_strat == 'm1':
            x = encoded_x[1]
        elif self.pz_strat == 'm2':
            x = encoded_x[2]
        elif self.pz_strat == 'm01':
            x = torch.cat(encoded_x[:2], dim=1)
        elif self.pz_strat == 'm02':
            x = torch.cat([encoded_x[i] for i in [0, 2]], dim=1)
        else:
            raise NotImplementedError
        return x


class ToMiddleIdentity(nn.Module):
    """No-op ``pz -> z_mid`` stage that returns its input unchanged.

    It accepts (and ignores) the time embedding so it has the same call signature
    as the conv blocks used for the other ``z_mid_strat`` options.
    """

    @jaxtyped
    @beartype
    def forward(
        self,
        encoded_x: Float[torch.Tensor, 'b ci hi wi'],  # combined latent pz (a single tensor)
        t: Float[torch.Tensor, 'b *n_dom dim'],  # unused
    ) -> Float[torch.Tensor, 'b c_mid h_mid w_mid']:
        passthrough = encoded_x
        return passthrough


class SkipConnectionAdaptation(nn.Module):
    """
    Adapt the encoder skip connections to the decoder layout.

    This is needed when the numbers of encoders and decoders differ, so each decoder
    still receives a skip-connection list.

    Behaviour by configuration:
        one encoder,       one decoder       -> skips returned unchanged
        one encoder,       multiple decoders -> one shallow copy of the list per domain
        multiple encoders, multiple decoders -> skips returned unchanged
        multiple encoders, one decoder       -> NotImplementedError
    """
    def __init__(
        self,
        nb_dom: int,
        multiple_encoders: bool,
        multiple_decoders: bool,
    ):
        super().__init__()
        self.nb_dom = nb_dom
        self.multiple_encoders = multiple_encoders
        self.multiple_decoders = multiple_decoders

    @jaxtyped
    @beartype
    def forward(
        self,
        hs,
    ):
        if self.multiple_encoders:
            if self.multiple_decoders:
                return hs
            raise NotImplementedError()

        if self.multiple_decoders:
            # A separate list object per decoder, so decoders never share one list.
            return [list(hs) for _ in range(self.nb_dom)]

        return hs
