from utils.module import *


class AdaptiveFeatureModulation(nn.Module):
    """Modulate UNet hidden states with motion (flow) features.

    Two parallel branches turn ``flow`` into a per-pixel, per-frame scale
    (``gamma``) and shift (``beta``). Each branch is a 3x3 spatial ``Conv2d``
    that reduces the channels to ``channels // 4``. It is followed by a
    kernel-3 ``Conv1d`` along the frame axis that expands back to
    ``channels``. The result is applied as a residual scale-and-shift update::

        hidden_state + norm(hidden_state) * gamma + beta

    The temporal convolutions are wrapped in ``zero_module`` (defined
    elsewhere). NOTE(review): presumably that zero-initialises their
    parameters, so the module starts out as an identity mapping. Confirm
    against ``utils.module``.

    The default arguments reproduce the original single-scale configuration
    (320 channels). Other UNet scales can pass their own channel count.
    """

    def __init__(
        self,
        channels=320,
        flow_dim_scale=1,
        is_same_channel=True,
        num_groups=32,
    ):
        """Build the spatial and temporal gamma/beta branches.

        Args:
            channels: Channel count ``C`` of ``hidden_state``. This is also the
                output channel count of ``gamma``/``beta``. It must be
                compatible with ``num_groups`` for the group normalisation.
                The spatial bottleneck width is ``channels // 4``.
            flow_dim_scale: Divisor applied to ``channels`` to obtain the
                channel count expected for ``flow``.
            is_same_channel: When False, the expected ``flow`` channel count
                is additionally halved.
            num_groups: Number of groups passed to ``FloatGroupNorm`` when
                normalising ``hidden_state``.
        """
        super().__init__()

        # Kept for compatibility with code that may read them; they are not
        # used inside this module.
        self.model_channels = 320
        self.out_channels = 4

        hidden_channels = channels // 4
        self.flow_cond_norm = FloatGroupNorm(num_groups, channels)

        # Channel count the spatial convolutions expect from ``flow``.
        flow_in_channel = channels // flow_dim_scale
        if not is_same_channel:
            flow_in_channel //= 2

        self.flow_gamma_spatial = nn.Conv2d(
            flow_in_channel, hidden_channels, 3, padding=1
        )
        self.flow_gamma_temporal = zero_module(
            nn.Conv1d(
                hidden_channels,
                channels,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode='replicate',
            )
        )
        self.flow_beta_spatial = nn.Conv2d(
            flow_in_channel, hidden_channels, 3, padding=1
        )
        self.flow_beta_temporal = zero_module(
            nn.Conv1d(
                hidden_channels,
                channels,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode='replicate',
            )
        )

    def forward(self, hidden_state, flow, num_video_frames=16):
        """Apply the flow-conditioned scale-and-shift to ``hidden_state``.

        Args:
            hidden_state: UNet hidden-state features of shape
                ``(B * F, C, H, W)`` with ``C == channels``. The frames of each
                video are contiguous along the first axis.
            flow: Motion-encoder features of shape ``(B * F, C_flow, H, W)``.
                ``C_flow`` is the flow channel count derived in ``__init__``.
                ``H, W`` must match ``hidden_state``, because the 3x3 convs
                with padding 1 preserve the spatial size.
            num_video_frames: Frames per video ``F``. The leading dimension of
                both tensors must be divisible by it.

        Returns:
            The updated hidden state, with the same shape as ``hidden_state``.

        Raises:
            ValueError: If ``flow`` is None.
        """
        if flow is None:
            # Explicit check rather than ``assert``, so it is not stripped
            # when Python runs with ``-O``.
            raise ValueError(
                f"You must provide the flow to the {type(self).__name__}"
            )

        gamma_flow = self.flow_gamma_spatial(flow)
        beta_flow = self.flow_beta_spatial(flow)
        _, _, height, width = beta_flow.shape

        # Move frames to the last axis so that each Conv1d mixes information
        # across time independently for every spatial location.
        gamma_flow = rearrange(
            gamma_flow, "(b f) c h w -> (b h w) c f", f=num_video_frames
        )
        beta_flow = rearrange(
            beta_flow, "(b f) c h w -> (b h w) c f", f=num_video_frames
        )
        gamma_flow = self.flow_gamma_temporal(gamma_flow)
        beta_flow = self.flow_beta_temporal(beta_flow)

        # Restore the (B * F, C, H, W) layout of hidden_state.
        gamma_flow = rearrange(
            gamma_flow, "(b h w) c f -> (b f) c h w", h=height, w=width
        )
        beta_flow = rearrange(
            beta_flow, "(b h w) c f -> (b f) c h w", h=height, w=width
        )

        # Residual update: with zero gamma/beta this returns hidden_state.
        return hidden_state + self.flow_cond_norm(hidden_state) * gamma_flow + beta_flow
    
