from typing import Tuple

import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.modeling_utils import ModelMixin
from utils.module_utils import zero_module
from ..resnet import InflatedConv3d


'''
VKpsGuider is designed as a keypoint encoder for motion generation where keypoints can guide the generation process.
Below are key aspects and reasons behind its construction:

    conditioning_embedding_channels: the number of channels of the output embedding, chosen to match the dimensionality required by the downstream task.

    Sequential Processing with Convolutional Blocks: 
        a series of convolutional blocks (InflatedConv3d): incrementally increase the depth of features while reducing spatial dimensions. 
        InflatedConv3d: a convolutional layer that inflates 2D kernels into 3D, used when adapting 2D models to process spatiotemporal data (e.g., video frames with keypoints).
            It enables the model to capture both spatial and temporal dependencies in the keypoint sequences.
        Activation Function: SiLU (Sigmoid-weighted Linear Unit).
        Zero-Initialized Final Layer (Optional Regularization): The conv_out layer is wrapped with zero_module, where initializing certain weights to zero can help stabilize training.
        Sequential Downscaling: the blocks progressively downsample the input through strided convolutions, reducing spatial dimensions while increasing the receptive field of the network, which enables it to learn more abstract, high-level representations from the keypoint data.
'''

class VKpsGuider(ModelMixin):
    """Encode a keypoint conditioning input into a guidance embedding.

    The network is a plain feed-forward stack of ``InflatedConv3d`` layers:

    1. ``conv_in`` maps ``conditioning_channels`` to ``block_out_channels[0]``.
    2. For every consecutive pair ``(c_in, c_out)`` in ``block_out_channels``
       two layers are appended to ``blocks``: a channel-preserving layer
       (``c_in -> c_in``) followed by a ``stride=2`` layer (``c_in -> c_out``).
    3. ``conv_out`` maps ``block_out_channels[-1]`` to
       ``conditioning_embedding_channels`` and is passed through
       ``zero_module``.

    A SiLU activation follows ``conv_in`` and every layer in ``blocks``;
    the output of ``conv_out`` is returned without an activation.

    Args:
        conditioning_embedding_channels: Number of channels of the returned
            embedding.
        conditioning_channels: Number of channels of the conditioning input.
        block_out_channels: Channel widths of the successive stages. Must
            contain at least one entry; ``len(block_out_channels) - 1``
            stride-2 layers are created.

    NOTE(review): ``zero_module`` is defined outside this file; presumably it
    zero-initializes the wrapped layer's parameters so the guider initially
    contributes nothing -- confirm against ``utils.module_utils``.
    """

    def __init__(
            self,
            conditioning_embedding_channels: int,
            conditioning_channels: int = 3,
            block_out_channels: Tuple[int, ...] = (16, 32, 64, 128),
    ):
        super().__init__()
        self.conv_in = InflatedConv3d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)

        self.blocks = nn.ModuleList([])

        # Walk consecutive (input, output) widths; each stage is a
        # same-width layer followed by a stride-2 layer that changes width.
        for channel_in, channel_out in zip(block_out_channels[:-1], block_out_channels[1:]):
            self.blocks.append(InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1))
            self.blocks.append(InflatedConv3d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))

        self.conv_out = zero_module(InflatedConv3d(
            block_out_channels[-1],
            conditioning_embedding_channels,
            kernel_size=3,
            padding=1,
        ))

    def forward(self, conditioning):
        """Run the conditioning input through the encoder stack.

        Args:
            conditioning: Input accepted by ``InflatedConv3d`` with
                ``conditioning_channels`` channels (presumably a
                ``(batch, channels, frames, height, width)`` tensor --
                TODO confirm against ``InflatedConv3d``).

        Returns:
            The output of ``conv_out``, with
            ``conditioning_embedding_channels`` channels.
        """
        embedding = F.silu(self.conv_in(conditioning))

        for block in self.blocks:
            embedding = F.silu(block(embedding))

        # No activation after the final projection.
        return self.conv_out(embedding)
