import torch.nn as nn
from einops import rearrange
from utils.module import *


class MotionEncoder(nn.Module):
    """Convolutional encoder producing multi-scale features from a motion condition.

    The encoder is a list of ``TimestepEmbedSequential`` stages stored in
    ``self.flow_blocks``:

    * stage 0 (stem): three stride-2 ``Conv2d`` layers (3 -> 80 -> 160 -> 320
      channels), each followed by a same-width ``Conv1d``;
    * stages 1..3: ``FloatGroupNorm`` -> ``SiLU`` -> ``Conv2d`` -> ``Conv1d``
      -> ``Downsample`` with output widths 640, 1280 and 1280.

    ``forward`` returns the output of every stage.

    NOTE(review): the ``Conv1d`` layers receive no explicit reshape here;
    presumably ``TimestepEmbedSequential`` uses ``num_video_frames`` to apply
    them along the frame axis -- confirm in ``utils.module``.
    """

    def __init__(
        self,
    ):
        super().__init__()

        self.model_channels = 320
        self.out_channels = 4
        flow_dim_scale = 1
        channel_mult = [1, 2, 4, 4]

        base_width = self.model_channels // flow_dim_scale
        # Stem channel schedule: quarter -> half -> full base width.
        stem_widths = [base_width // 4, base_width // 2, base_width]

        # Modules are created in exactly the order they appear in each stage so
        # that parameter names (and seeded initialisation) stay stable.
        stem_layers = []
        in_width = 3  # input channels: [flow, motion_mask]
        for position, out_width in enumerate(stem_widths):
            if position > 0:
                # Normalisation + activation sit between consecutive conv pairs.
                stem_layers.append(FloatGroupNorm(8, in_width))
                stem_layers.append(nn.SiLU())
            stem_layers.append(nn.Conv2d(in_width, out_width, 3, stride=2, padding=1))
            stem_layers.append(self._conv1d_same(out_width))
            in_width = out_width

        self.flow_blocks = nn.ModuleList([])
        self.flow_blocks.append(TimestepEmbedSequential(*stem_layers))

        flow_in_channel = base_width
        for ch_f in channel_mult[1:]:
            stage_width = ch_f * self.model_channels // flow_dim_scale
            # Every stage ends with a Downsample. An earlier guard
            # (`i_f != len(channel_mult) - 1`) could never be false because the
            # index ran over `channel_mult[1:]`, so it has been dropped without
            # changing the resulting architecture.
            stage_layers = [
                FloatGroupNorm(8, flow_in_channel),
                nn.SiLU(),
                nn.Conv2d(flow_in_channel, stage_width, 3, padding=1),
                self._conv1d_same(stage_width),
                Downsample(stage_width, True, dims=2, out_channels=stage_width),
            ]
            self.flow_blocks.append(TimestepEmbedSequential(*stage_layers))
            flow_in_channel = stage_width

    @staticmethod
    def _conv1d_same(channels):
        """Build a width-preserving, replicate-padded ``Conv1d`` (kernel size 3)."""
        return nn.Conv1d(
            channels,
            channels,
            kernel_size=3,
            stride=1,
            padding=1,
            padding_mode='replicate',
        )

    def forward(self, flow, num_video_frames):
        """Extract per-stage motion features.

        Args:
            flow: 5-D tensor laid out as ``[B, L, C, H, W]`` (batch, frames,
                channels, height, width); the first two axes are merged before
                the first stage. ``C`` must be 3 to match the stem's first
                convolution. Per the original documentation this is the
                channel-wise concatenation of region-wise trajectories and a
                motion mask.
            num_video_frames: Number of frames ``L``, forwarded unchanged to
                every stage (the original documentation mentions 16 as the
                default).

        Returns:
            A tuple ``(hs_z_flow, hs_z_flow_clone)`` of two lists with one
            tensor per stage, in stage order. The second list holds ``clone()``
            copies of the tensors in the first. Per the original documentation
            the first list is consumed by the UNet downsampler layers and the
            second by the UNet upsampler layers.

        Example shapes noted by the original author, after merging batch and
        frames into ``N``:
        ``[N,3,320,576] -> [N,320,40,72] -> [N,640,20,36] -> [N,1280,10,18]
        -> [N,1280,5,9]``.
        """
        hs_z_flow = []
        hs_z_flow_clone = []

        # Fold the frame axis into the batch axis so 2-D layers see single frames.
        flow = rearrange(flow, "b l c h w -> (b l) c h w")
        for module in self.flow_blocks:
            flow = module(flow, emb=None, num_video_frames=num_video_frames)
            hs_z_flow.append(flow)
            # Independent copy so in-place edits by one consumer cannot leak
            # into the other list.
            hs_z_flow_clone.append(flow.clone())

        return hs_z_flow, hs_z_flow_clone

