from copy import deepcopy
from typing import Sequence, Optional
import torch
from torch import nn

from jaxtyping import Float, Int, jaxtyped
from beartype import beartype

from src.Backbones.utils_backbone_unet import (
    SinusoidalPosEmb,
    exists,
    Middle, SplitEncoder, ToMiddle, default, Encoder,
)
from utils.utils import display_tensor


def conv_nd(*args, **kwargs):
    """
    Factory for the convolution layers used in this module.

    Despite the generic name, it always builds an ``nn.Conv2d``; every
    positional and keyword argument is forwarded unchanged.
    """
    layer = nn.Conv2d(*args, **kwargs)
    return layer


def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place.

    The zeroing is not recorded by autograd, and the parameters keep their
    ``requires_grad`` flag. Returns the very same module object, so the call
    can wrap a constructor inline.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module


class Unet_Cold_Multi_Domain_ControlNet(nn.Module):
    """
    ControlNet branch of the multi-domain cold-diffusion U-Net (encoder + middle block).

    It builds one encoder per domain (``SplitEncoder``), a ``ToMiddle`` fusion
    module and a ``Middle`` block. Every activation handed out as a control is
    first passed through a zero-initialised 1x1 convolution, so right after
    construction all emitted controls are exactly zero.

    Only ``encoder_split=True`` is supported. The arguments ``decoder_split``,
    ``decoder_attention_per_block``, ``decoder_time_embedding_per_block``,
    ``out_dim``, ``channels`` and ``residual`` are not used by this class;
    presumably they are accepted so that it can be built from the same config
    as the main U-Net (TODO confirm).
    """

    def __init__(
        self,
        dim_per_dom: list[int],
        encoder_split: bool,
        encoder_attention_per_block: Optional[Sequence[bool]],
        encoder_time_embedding_per_block: Optional[Sequence[bool]],

        pz_strat: str,
        z_mid_strat: str,

        middle_linear_attention: bool,
        middle_attention: bool,
        middle_time_embedding: bool,
        middle_nb_block_following: int,

        decoder_split: bool,
        decoder_attention_per_block: Optional[Sequence[bool]],
        decoder_time_embedding_per_block: Optional[Sequence[bool]],

        use_double_skip: bool,

        dim: int,
        time_dim: int,
        init_dim: int,
        out_dim: int,
        dim_mults: Sequence[int],
        channels: int,
        resnet_block_groups: int,
        with_time_emb: bool,
        use_convnext: bool,
        convnext_mult: int,
        residual: bool,
    ):
        super().__init__()

        # Only the per-domain (split) encoder variant is implemented.
        if not encoder_split:
            raise NotImplementedError(
                "Unet_Cold_Multi_Domain_ControlNet only supports encoder_split=True"
            )

        n_dom = len(dim_per_dom)
        full_dim = sum(dim_per_dom)  # number of channels of all domains stacked together
        mid_dim = dim * dim_mults[-1]  # channel count at the bottleneck

        encoder_attention_per_block = default(encoder_attention_per_block, [True] * len(dim_mults))
        encoder_time_embedding_per_block = default(encoder_time_embedding_per_block, [True] * len(dim_mults))

        # Each encoder is built with `time_dim`; the middle modules are built with
        # n_dom * time_dim (presumably the per-domain embeddings side by side — TODO confirm).
        single_time_dim = time_dim
        full_time_dim = single_time_dim * n_dom if exists(single_time_dim) else None
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, single_time_dim),
            nn.GELU(),
            nn.Linear(single_time_dim, single_time_dim),
        ) if with_time_emb else None

        # Maps the stacked condition (full_dim channels) to `dim` channels; it is
        # zero-initialised so the hint contributes nothing before training.
        self.input_hint_block = nn.Sequential(
            zero_module(conv_nd(full_dim, dim, 3, padding=1)),
        )

        self.encoders = SplitEncoder(
            dim_per_dom=dim_per_dom,

            dim=dim,
            init_dim=init_dim,
            time_dim=single_time_dim,
            dim_mults=dim_mults,
            resnet_block_groups=resnet_block_groups,
            use_convnext=use_convnext,
            convnext_mult=convnext_mult,

            attention_per_block=encoder_attention_per_block,
            time_embedding_per_block=encoder_time_embedding_per_block,
            use_double_skip=use_double_skip,
        )

        self.to_middle = ToMiddle(
            n_dom=n_dom,
            bottleneck_dim=mid_dim,
            pz_strat=pz_strat,
            z_mid_strat=z_mid_strat,
            time_dim=full_time_dim,
            resnet_block_groups=resnet_block_groups,
            use_convnext=use_convnext,
            convnext_mult=convnext_mult,
            time_embedding=middle_time_embedding,
            nb_block_following=middle_nb_block_following,
        )
        z_mid_dim = self.to_middle.z_mid_dim  # output dim of the to_middle network

        self.middle = Middle(
            mid_dim_in=z_mid_dim,
            mid_dim=mid_dim,
            time_dim=full_time_dim,
            resnet_block_groups=resnet_block_groups,

            use_convnext=use_convnext,
            convnext_mult=convnext_mult,
            attention=middle_attention,
            time_embedding=middle_time_embedding,
            middle_linear_attention=middle_linear_attention,
        )

        # Two zero convs per down level (after block1, and after block2 + attn).
        # Channel counts come from the first domain's encoder and are reused for
        # every domain; this assumes `level[0].net[-1]` is the last conv of block1.
        zero_convs_template = nn.ModuleList([])
        for level in self.encoders.encoders[0].downs:
            level_channels = level[0].net[-1].out_channels
            zero_convs_template.append(self.make_zero_conv(level_channels))
            zero_convs_template.append(self.make_zero_conv(level_channels))
        # One independent copy per domain. `[m] * n` would register the *same*
        # module n times and silently share weights across domains.
        self.zero_convs = nn.ModuleList(
            [deepcopy(zero_convs_template) for _ in range(n_dom)]
        )

        # One global zero conv for the output of the middle block.
        self.middle_block_out_zero = self.make_zero_conv(z_mid_dim)

    @jaxtyped(typechecker=beartype)
    def forward(
        self,
        data_cn: Float[torch.Tensor, 'b c_n_dom h w'],
        condition_cn: Float[torch.Tensor, 'b c_n_dom h w'],
        time: Int[torch.Tensor, 'b n_dom'],
    ):
        """
        Forward pass of the ControlNet part: encoders + middle block.

        Each activation goes through its zero conv before being returned.

        Returns:
            A list whose first n_dom entries are the per-domain lists of control
            tensors (two per down level), and whose last entry is the zero-conv'd
            output of the middle block.
        """
        t_emb = self.time_mlp(time) if exists(self.time_mlp) else None

        # The condition is adapted (and zero-gated) before being added to every encoder.
        guided_hint = self.input_hint_block(condition_cn)

        zs, encoder_controls = self.all_encoders_forward(
            x=data_cn,
            t=t_emb,
            guided_hint=guided_hint,
        )

        z_middle = self.to_middle(zs, t_emb)
        z = self.middle(z_middle, t_emb)

        return [*encoder_controls, self.middle_block_out_zero(z)]

    @jaxtyped(typechecker=beartype)
    def all_encoders_forward(
        self,
        x: Float[torch.Tensor, 'b c_n_dom h w'],
        t: Optional[Float[torch.Tensor, 'b *n_dom dim']],
        guided_hint: Float[torch.Tensor, 'b c_context h w'],
    ) -> tuple[
        list[Float[torch.Tensor, 'b ci hi wi']],        # embeddings, #list = n_dom
        list[list[Float[torch.Tensor, 'b cj hj wj']]],  # controls, #list = n_dom, #SubList = nb skips
    ]:
        """
        Run every domain encoder and collect its bottleneck embedding and controls.

        Args:
            x: all domains stacked along the channel axis; split using
                ``self.encoders.dim_per_dom``.
            t: time embedding, either shared ``(b, dim)`` (then repeated for each
                domain) or per domain ``(b, n_dom, dim)``; ``None`` when the model
                was built with ``with_time_emb=False``.
            guided_hint: adapted condition added to each encoder's stem output.

        Returns:
            ``(zs, controls)`` where ``zs[i]`` is the embedding of domain i and
            ``controls[i]`` is the list of zero-conv'd activations of encoder i.
        """
        split_encoder = self.encoders
        n_dom = len(split_encoder.encoders)

        x_per_dom = torch.split(x, split_encoder.dim_per_dom, dim=1)

        if t is None:
            # No time embedding: every encoder receives None.
            t_per_dom = [None] * n_dom
        else:
            if t.dim() == 2:
                # Shared embedding: give every domain the same one.
                t = t.unsqueeze(1).repeat(1, n_dom, 1)
            t_per_dom = torch.split(t, 1, dim=1)  # n_dom tensors of shape (b, 1, dim)

        zs = []
        encoder_controls = []
        for encoder, zero_convs, x_dom, t_dom in zip(split_encoder.encoders, self.zero_convs, x_per_dom, t_per_dom):
            z_dom, _skips, controls_dom = self.unique_encoder_forward(
                enc=encoder,
                zero_convs=zero_convs,
                x=x_dom,
                t=t_dom,
                guided_hint=guided_hint,
            )
            zs.append(z_dom)
            encoder_controls.append(controls_dom)

        return zs, encoder_controls

    @jaxtyped(typechecker=beartype)
    def unique_encoder_forward(
        self,
        enc: Encoder,
        zero_convs: nn.ModuleList,
        x: Float[torch.Tensor, 'b c h w'],
        t: Optional[Float[torch.Tensor, 'b *n_dom dim']],
        guided_hint: Float[torch.Tensor, 'b ck hk wk'],
    ) -> tuple[
        Float[torch.Tensor, 'b ci hi wi'],        # embedding
        list[Float[torch.Tensor, 'b cj hj wj']],  # residual + skip connections list
        list[Float[torch.Tensor, 'b cl hl wl']],  # control data
    ]:
        """
        Full encoding pass for a single domain.

        Args:
            enc: the domain's encoder (uses ``init_conv``, ``downs`` and
                ``use_double_skip``).
            zero_convs: this domain's zero convs, two per down level.
            x: the domain's slice of the input.
            t: time embedding, ``(b, dim)`` or ``(b, k, dim)`` (flattened to
                ``(b, k * dim)``), or ``None``.
            guided_hint: adapted condition added after ``enc.init_conv``.

        Returns:
            ``(z, skips, controls)``: the final downsampled activation, the skip
            list (as the encoder would build it), and the zero-conv'd controls.
        """
        if t is not None and t.dim() == 3:
            b, n, d = t.shape
            t = t.reshape(b, n * d)

        # Out-of-place add: never modify a layer output in place that autograd may need.
        x = enc.init_conv(x) + guided_hint

        skips = [x.clone()]
        controls = []

        for i, (block1, block2, attn, downsample) in enumerate(enc.downs):
            x = block1(x, t)
            if enc.use_double_skip:
                skips.append(x)
            controls.append(zero_convs[2 * i](x))

            x = block2(x, t)
            x = attn(x)
            skips.append(x)
            controls.append(zero_convs[2 * i + 1](x))

            x = downsample(x)

        return x, skips, controls

    def make_zero_conv(self, channels):
        """Return a zero-initialised 1x1 convolution mapping ``channels`` -> ``channels``."""
        return zero_module(conv_nd(channels, channels, 1, padding=0))


# class BaseModel(Unet_Cold_Multi_Domain):
#     raise NotImplementedError
# A separate BaseModel is not needed: that functionality lives directly in Unet_Cold_Multi_Domain.


class MasterControlNet(nn.Module):
    """
    Pair a diffusion model with its ControlNet branch.

    The control model is evaluated first on the data and the condition. Its
    output is then given to the diffusion model through the ``control``
    keyword, and the diffusion model's output is returned.
    """

    def __init__(
        self,
        diffusion_model,
        control_model,
    ):
        super().__init__()
        self.diffusion_model = diffusion_model
        self.control_model = control_model

    def forward(
        self,
        data_cn: Float[torch.Tensor, 'b c_n_dom h w'],
        condition_cn: Float[torch.Tensor, 'b c_n_dom h w'],
        time: Int[torch.Tensor, 'b n_dom'],
    ):
        """Compute the controls, then run the diffusion model conditioned on them."""
        controls = self.control_model(data_cn=data_cn, condition_cn=condition_cn, time=time)
        return self.diffusion_model(x=data_cn, time=time, control=controls)
