import torch
from torch import nn
import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from einops import rearrange, repeat, einsum
import math
import random
from . import global_utils

class LatentRectify(nn.Module):
    """Score how well each video latent in a batch matches each text embedding.

    A cross-frame correlation volume (first frame vs. every later frame) and
    the timestep-conditioned text tokens are passed through two single-head
    cross-attention branches. Each branch is mean-pooled to one vector per
    sample; the scaled dot products between the pooled vectors form a
    pairwise score matrix that is trained with cross-entropy against
    identity-style targets.

    Args:
        text_embed_dim: Size of the last dimension of ``text_embed``.
        visual_embed_channel_dim: Channel count ``c`` of ``visual_embed``.
        visual_embed_spatial_dim: Height/width of ``visual_embed``
            (a square ``h == w`` layout is required by the linear layers).
        visual_embed_temporal_dim: Number of frames ``f`` of ``visual_embed``.
        inner_dim: Width of the attention projections and pooled vectors.
    """

    def __init__(
            self,
            text_embed_dim = 768,
            visual_embed_channel_dim = 4,
            visual_embed_spatial_dim = 32,
            visual_embed_temporal_dim = 16,
            inner_dim = 256,
        ):
        super().__init__()
        self.time_proj = Timesteps(256, flip_sin_to_cos=True, downscale_freq_shift=0.0)
        self.time_embedding = TimestepEmbedding(256, text_embed_dim)
        # A plain function reference (not a lambda) keeps the module picklable,
        # e.g. for ``torch.save(model)`` or spawn-based multiprocessing.
        self.nonlinearity = F.silu
        self.norm1 = nn.LayerNorm(text_embed_dim)

        # Size of one row of the correlation volume: (f - 1) * (h * w).
        corr_dim = (visual_embed_temporal_dim - 1) * (visual_embed_spatial_dim ** 2)

        # Visual branch: correlation rows query the text tokens.
        self.visual_to_q = nn.Linear(corr_dim, inner_dim)
        self.visual_to_k = nn.Linear(text_embed_dim, inner_dim)
        self.visual_to_v = nn.Linear(text_embed_dim, inner_dim)
        self.visual_to_out = nn.Linear(inner_dim, inner_dim)

        # Text branch: text tokens query first-frame positions and read
        # the correlation rows as values.
        self.text_to_q = nn.Linear(text_embed_dim, inner_dim)
        self.text_to_k = nn.Linear(visual_embed_channel_dim, inner_dim)
        self.text_to_v = nn.Linear(corr_dim, inner_dim)
        self.text_to_out = nn.Linear(inner_dim, inner_dim)

        self.qk_scale = inner_dim**-0.5
        self.final_scale = 1 / 0.07

    def extra_repr(self):
        return "(Module Info) LatentRectify"

    def forward(self, visual_embed, text_embed, timesteps, include_static=False, include_shuffle=False):
        """Compute the matching loss and the pairwise score matrix.

        Args:
            visual_embed: Tensor laid out as ``(b, c, f, h, w)``.
            text_embed: Tensor laid out as ``(b, tokens, text_embed_dim)``.
            timesteps: Per-sample timesteps fed to the timestep embedding
                (presumably shape ``(b,)`` - confirm against the caller).
            include_static: If True, append one extra sample per video whose
                frames are all copies of that video's first frame.
            include_shuffle: If True, append one extra sample per video whose
                frames are reordered by a single random permutation.

        Returns:
            ``(loss, score)``. Without extra samples ``score`` is ``(b, b)``
            and the loss is the video-to-text cross-entropy against the
            identity matrix. With extra samples ``score`` keeps only the
            first ``b`` text columns and the loss averages the video-to-text
            and text-to-video cross-entropies.
        """
        batch_size = visual_embed.shape[0]
        video_length = visual_embed.shape[2]
        if include_static:
            static_visual_embed = repeat(visual_embed[:batch_size,:,0,:,:], "b c h w -> b c f h w", f=video_length)
            visual_embed = torch.cat([visual_embed, static_visual_embed], dim=0)
            text_embed = torch.cat([text_embed, text_embed[:batch_size]], dim=0)
            timesteps = torch.cat([timesteps, timesteps[:batch_size]], dim=0)
        if include_shuffle:
            # One permutation shared by every video in the batch.
            shuffle_idx = torch.randperm(visual_embed.shape[2])
            shuffle_visual_embed = visual_embed[:batch_size,:,shuffle_idx]
            visual_embed = torch.cat([visual_embed, shuffle_visual_embed], dim=0)
            text_embed = torch.cat([text_embed, text_embed[:batch_size]], dim=0)
            timesteps = torch.cat([timesteps, timesteps[:batch_size]], dim=0)
        ### first frame embed
        first_visual_embed = rearrange(visual_embed[:,:,0,:,:], "b c h w -> b (h w) c")
        ### visual cross frame corr
        # Cosine similarity between every first-frame position and every
        # position of each later frame, flattened to (b, h*w, (f-1)*h*w).
        visual_embed = rearrange(visual_embed, "b c f h w -> b f (h w) c")
        visual_embed_ref = F.normalize(repeat(visual_embed[:,0,:,:], "b n c -> (b f) n c", f=video_length-1), dim=-1)
        visual_embed_post = F.normalize(rearrange(visual_embed[:,1:,:,:], "b f n c -> (b f) n c"), dim=-1)
        visual_corr = rearrange(torch.bmm(visual_embed_ref, visual_embed_post.transpose(-1,-2)), "(b f) m n -> b m (f n)", f=video_length-1)
        ### time embed
        time_embed = self.nonlinearity(self.time_embedding(self.time_proj(timesteps)))
        ### mix text and time embed
        text_embed = self.norm1(text_embed + time_embed.unsqueeze(1))
        ### transformer for visual
        visual_q = self.visual_to_q(visual_corr)
        visual_k = self.visual_to_k(text_embed)
        visual_attn_prob = (self.qk_scale * torch.bmm(visual_q, visual_k.transpose(-1,-2))).softmax(dim=-1)
        visual_v = self.visual_to_v(text_embed)
        visual_final = torch.bmm(visual_attn_prob, visual_v)
        visual_final = self.visual_to_out(visual_final)
        visual_final = visual_final.mean(1)
        ### transformer for text
        text_q = self.text_to_q(text_embed)
        text_k = self.text_to_k(first_visual_embed)
        text_attn_prob = (self.qk_scale * torch.bmm(text_q, text_k.transpose(-1,-2))).softmax(dim=-1)
        text_v = self.text_to_v(visual_corr)
        text_final = torch.bmm(text_attn_prob, text_v)
        text_final = self.text_to_out(text_final)
        text_final = text_final.mean(1)
        ### score
        score = self.final_scale * (visual_final @ text_final.t())
        # Build the targets directly on the score's device (no host-side
        # allocation followed by a copy).
        onehot_matrix = torch.eye(batch_size, device=score.device)
        if not include_static and not include_shuffle:
            loss = F.cross_entropy(score, onehot_matrix)
        else:
            # Only the original texts are used as columns; the appended
            # static/shuffled samples contribute extra video rows.
            score = score[:,:batch_size]
            random_matrix = torch.ones_like(onehot_matrix) / batch_size
            zero_matrix = torch.zeros_like(onehot_matrix)
            visual_ref = onehot_matrix
            text_ref = onehot_matrix
            if include_static:
                # Static rows get identity targets on the video side and
                # zero targets on the text side.
                visual_ref = torch.cat([visual_ref, onehot_matrix], dim=0)
                text_ref = torch.cat([text_ref, zero_matrix], dim=1)
            if include_shuffle:
                # Shuffled rows get uniform targets on the video side and
                # zero targets on the text side.
                visual_ref = torch.cat([visual_ref, random_matrix], dim=0)
                text_ref = torch.cat([text_ref, zero_matrix], dim=1)
            loss = (F.cross_entropy(score, visual_ref) + F.cross_entropy(score.T, text_ref)) / 2
        return loss, score
        

class MotionPredictorV0(nn.Module):
    """Input-independent motion adjustment stored as a single learned table.

    The table has shape ``(1, length, dim)`` and starts at zero. ``forward``
    ignores the content of its inputs and only uses ``visual_embed.shape[1]``
    to decide how many leading rows of the table to return.
    """

    def __init__(self, length=16, dim=320):
        super().__init__()
        initial_table = torch.zeros(1, length, dim)
        self.adjust_trans = nn.Parameter(initial_table)

    def extra_repr(self):
        return "(Module Info) MotionPredictorV0"

    def forward(self, visual_embed, textual_embed, video_length, spatial_seq_length):
        """Return the first ``visual_embed.shape[1]`` rows of the table.

        ``textual_embed``, ``video_length`` and ``spatial_seq_length`` are
        accepted but unused. If more rows are requested than the table
        holds, the whole table is returned (ordinary slice semantics).
        """
        num_rows = visual_embed.shape[1]
        return self.adjust_trans[:, :num_rows, :]


class MotionPredictorV1(nn.Module):
    """Predict a per-token motion adjustment from visual and textual tokens.

    Each visual token attends over the textual tokens (single-head
    attention on bias-free projections). The attended text features are
    concatenated with the projected visual features and mapped back to
    ``visual_embed_dim``.

    Args:
        textual_embed_dim: Size of the last dimension of ``textual_embed``.
        visual_embed_dim: Size of the last dimension of ``visual_embed``
            and of the returned adjustment.
        inner_dim: Width of the two feature projections.
    """

    def __init__(
            self,
            textual_embed_dim = 768,
            visual_embed_dim = 320,
            inner_dim = 128,
        ):
        super().__init__()
        self.textual_copy_trans = nn.Linear(textual_embed_dim, visual_embed_dim, bias=False)
        self.visual_copy_trans = nn.Linear(visual_embed_dim, visual_embed_dim, bias=False)
        self.textual_trans = nn.Linear(textual_embed_dim, inner_dim)
        self.visual_trans = nn.Linear(visual_embed_dim, inner_dim)
        self.adjust_trans = nn.Linear(2*inner_dim, visual_embed_dim)
        # A plain function reference (not a lambda) keeps the module picklable,
        # e.g. for ``torch.save(model)`` or spawn-based multiprocessing.
        self.nonlinearity = F.silu

        self.qk_scale = visual_embed_dim**-0.5

    def extra_repr(self):
        return "(Module Info) MotionPredictorV1"

    def forward(self, visual_embed, textual_embed, video_length, spatial_seq_length):
        """Return the motion adjustment for every visual token.

        Args:
            visual_embed: ``(b, n_visual, visual_embed_dim)`` tensor.
            textual_embed: ``(b, n_text, textual_embed_dim)`` tensor.
            video_length: Unused; kept for a uniform predictor interface.
            spatial_seq_length: Unused; kept for a uniform predictor interface.

        Returns:
            ``(b, n_visual, visual_embed_dim)`` tensor.
        """
        # Attention logits use the bias-free "copy" projections.
        query = self.visual_copy_trans(visual_embed)
        key = self.textual_copy_trans(textual_embed)
        textual_embed = self.nonlinearity(self.textual_trans(textual_embed))
        visual_embed = self.nonlinearity(self.visual_trans(visual_embed))
        attn_score = torch.bmm(query, key.transpose(-1,-2)) * self.qk_scale
        attn_prob = attn_score.softmax(-1)
        attn_embed = torch.bmm(attn_prob, textual_embed)
        cat_embed = torch.cat([visual_embed, attn_embed], dim=-1)
        motion_adjust = self.adjust_trans(cat_embed)

        return motion_adjust


class MotionPredictorV2(nn.Module):
    """Predict a motion adjustment as a soft mixture of learned motion vectors.

    The module owns a bank of ``num_motion_vector`` learned sequences, each
    of shape ``(motion_vector_length, visual_embed_dim)``. The mean visual
    token of every sample is projected to one logit per bank entry; the
    softmax of those logits weights the bank entries.

    Args:
        visual_embed_dim: Size of the last dimension of ``visual_embed``.
        num_motion_vector: Number of entries in the bank.
        motion_vector_length: Sequence length stored per bank entry.
        constraint: Training-time rule for sharing mixture weights between
            the two halves of the batch. One of ``''`` (none), ``'first'``,
            ``'last'``, ``'random'`` or ``'swap'``. Ignored in eval mode.
    """

    def __init__(
            self,
            visual_embed_dim = 320,
            num_motion_vector = 16,
            motion_vector_length = 24,
            constraint = '',
        ):
        super().__init__()
        self.motion_vector = nn.Parameter(torch.randn(num_motion_vector, motion_vector_length, visual_embed_dim))
        self.visual_trans = nn.Linear(visual_embed_dim, num_motion_vector)

        self.softmax_temp = 1
        self.constraint = constraint

    def extra_repr(self):
        return "(Module Info) MotionPredictorV2"

    def forward(self, visual_embed, textual_embed, video_length, spatial_seq_length):
        """Return the mixed motion sequence for every sample.

        Args:
            visual_embed: ``(b, length, visual_embed_dim)`` tensor.
            textual_embed: Unused; kept for a uniform predictor interface.
            video_length: Unused; kept for a uniform predictor interface.
            spatial_seq_length: Unused; kept for a uniform predictor interface.

        Returns:
            ``(b, min(length, motion_vector_length), visual_embed_dim)`` tensor.

        Raises:
            NotImplementedError: If ``self.constraint`` is not a known mode
                while the module is in training mode.
        """
        bs, length = visual_embed.shape[0], visual_embed.shape[1]
        if self.constraint == '' or not self.training:
            weight = self.visual_trans(visual_embed.mean(1))
        elif self.constraint == 'first':
            # Both halves of the batch reuse the logits of the first half.
            weight = torch.cat([self.visual_trans(visual_embed[:bs//2].mean(1))]*2,dim=0)
        elif self.constraint == 'last':
            # Both halves of the batch reuse the logits of the second half.
            weight = torch.cat([self.visual_trans(visual_embed[bs//2:].mean(1))]*2,dim=0)
        elif self.constraint == 'random':
            # Pick 'first' or 'last' behaviour with equal probability.
            if random.uniform(0, 1) <= 0.5:
                weight = torch.cat([self.visual_trans(visual_embed[:bs//2].mean(1))]*2,dim=0)
            else:
                weight = torch.cat([self.visual_trans(visual_embed[bs//2:].mean(1))]*2,dim=0)
        elif self.constraint == 'swap':
            # With probability one half, exchange the logits of the two halves.
            if random.uniform(0, 1) <= 0.5:
                weight = self.visual_trans(visual_embed.mean(1))
            else:
                tmp_weight = self.visual_trans(visual_embed.mean(1))
                tmp_weight_1, tmp_weight_2 = torch.chunk(tmp_weight, 2)
                weight = torch.cat([tmp_weight_2, tmp_weight_1], dim=0)
        else:
            # An explicit raise survives ``python -O`` (asserts do not).
            raise NotImplementedError(f"Unsupported constraint: {self.constraint!r}")
        weight = F.softmax(self.softmax_temp * weight, dim=-1)
        # weight: (b, n); bank slice: (n, t, c) -> per-sample mixture (b, t, c).
        motion_adjust = torch.einsum('bn,ntc->btc', weight, self.motion_vector[:,:length,:])
        return motion_adjust


class MotionPredictorV3(nn.Module):
    """Predict motion features as a learned re-weighting of first-frame tokens.

    Every visual token attends over the textual tokens; the fused features
    are mapped to one weight per spatial position of a native
    ``spatial_dim x spatial_dim`` grid. Those weights aggregate the tokens
    of the first frame. When the input grid differs from the native grid,
    the weight map is resized with ``F.interpolate`` (default mode).

    Args:
        textual_embed_dim: Size of the last dimension of ``textual_embed``.
        visual_embed_dim: Size of the last dimension of ``visual_embed``.
        inner_dim: Width of the two feature projections.
        spatial_dim: Side length of the native weight grid.
    """

    def __init__(
            self,
            textual_embed_dim = 768,
            visual_embed_dim = 320,
            inner_dim = 128,
            spatial_dim = 16,
        ):
        super().__init__()
        self.spatial_dim = spatial_dim
        self.textual_copy_trans = nn.Linear(textual_embed_dim, visual_embed_dim, bias=False)
        self.visual_copy_trans = nn.Linear(visual_embed_dim, visual_embed_dim, bias=False)
        self.textual_trans = nn.Linear(textual_embed_dim, inner_dim)
        self.visual_trans = nn.Linear(visual_embed_dim, inner_dim)
        self.adjust_trans = nn.Linear(2*inner_dim, spatial_dim**2)
        # A plain function reference (not a lambda) keeps the module picklable.
        self.nonlinearity = F.silu

        self.qk_scale = visual_embed_dim**-0.5

    def extra_repr(self):
        return "(Module Info) MotionPredictorV3"

    def forward(self, visual_embed, textual_embed, video_length, spatial_seq_length):
        """Aggregate first-frame tokens with predicted spatial weights.

        Args:
            visual_embed: ``(b, spatial_seq_length * video_length, c)`` tensor
                whose token axis is ordered position-major, frame-minor.
            textual_embed: ``(b, n_text, textual_embed_dim)`` tensor.
            video_length: Number of frames per spatial position.
            spatial_seq_length: Number of spatial positions; must be a
                perfect square.

        Returns:
            ``(b, spatial_seq_length * video_length, c)`` tensor.

        Raises:
            AssertionError: If the shape requirements above are violated.
        """
        spatial_dim = math.isqrt(spatial_seq_length)
        assert spatial_seq_length == spatial_dim ** 2, "spatial_seq_length must be a perfect square"
        assert visual_embed.shape[1] == spatial_seq_length * video_length, "token count must equal spatial_seq_length * video_length"
        assert visual_embed.shape[0] == textual_embed.shape[0], "visual and textual batch sizes differ"

        batch, _, channels = visual_embed.shape
        # Tokens of frame 0 at every spatial position: (b, spatial_seq_length, c).
        ref_visual_embed = visual_embed.reshape(batch, spatial_seq_length, video_length, channels)[:,:,0,:]

        query = self.visual_copy_trans(visual_embed)
        key = self.textual_copy_trans(textual_embed)
        textual_embed = self.nonlinearity(self.textual_trans(textual_embed))
        visual_embed = self.nonlinearity(self.visual_trans(visual_embed))
        attn_score = torch.bmm(query, key.transpose(-1,-2)) * self.qk_scale
        attn_prob = attn_score.softmax(-1)
        attn_embed = torch.bmm(attn_prob, textual_embed)
        cat_embed = torch.cat([visual_embed, attn_embed], dim=-1)
        agg_weight = self.adjust_trans(cat_embed)
        if agg_weight.shape[-1] != spatial_seq_length:
            # The linear layer emits a flattened native_dim x native_dim map.
            # It must be unflattened with its OWN side length before resizing;
            # unflattening with the target side would scramble the layout and
            # fail when the native size is not divisible by the target side.
            native_dim = math.isqrt(agg_weight.shape[-1])
            agg_weight = agg_weight.reshape(batch, agg_weight.shape[1], native_dim, native_dim)
            agg_weight = F.interpolate(agg_weight, (spatial_dim, spatial_dim))
            agg_weight = agg_weight.flatten(2)
            assert agg_weight.shape[-1] == spatial_seq_length
        agg_feat = torch.bmm(agg_weight, ref_visual_embed)

        return agg_feat


class MotionPredictorV4(nn.Module):
    """Predict a motion adjustment as a soft mixture of learned motion vectors.

    Same mixture scheme as a learned bank of ``num_motion_vector`` sequences
    of shape ``(motion_vector_length, visual_embed_dim)``, but the mixture
    logits are computed from the first token of every sample
    (``visual_embed[:, 0, :]``) instead of the token mean.

    Args:
        visual_embed_dim: Size of the last dimension of ``visual_embed``.
        num_motion_vector: Number of entries in the bank.
        motion_vector_length: Sequence length stored per bank entry.
        constraint: Training-time rule for sharing mixture weights between
            the two halves of the batch. One of ``''`` (none), ``'first'``,
            ``'last'``, ``'random'`` or ``'swap'``. Ignored in eval mode.
    """

    def __init__(
            self,
            visual_embed_dim = 320,
            num_motion_vector = 16,
            motion_vector_length = 24,
            constraint = '',
        ):
        super().__init__()
        self.motion_vector = nn.Parameter(torch.randn(num_motion_vector, motion_vector_length, visual_embed_dim))
        self.visual_trans = nn.Linear(visual_embed_dim, num_motion_vector)

        self.softmax_temp = 1
        self.constraint = constraint

    def extra_repr(self):
        return "(Module Info) MotionPredictorV4"

    def forward(self, visual_embed, textual_embed, video_length, spatial_seq_length):
        """Return the mixed motion sequence for every sample.

        Args:
            visual_embed: ``(b, length, visual_embed_dim)`` tensor.
            textual_embed: Unused; kept for a uniform predictor interface.
            video_length: Unused; kept for a uniform predictor interface.
            spatial_seq_length: Unused; kept for a uniform predictor interface.

        Returns:
            ``(b, min(length, motion_vector_length), visual_embed_dim)`` tensor.

        Raises:
            NotImplementedError: If ``self.constraint`` is not a known mode
                while the module is in training mode.
        """
        bs, length = visual_embed.shape[0], visual_embed.shape[1]
        if self.constraint == '' or not self.training:
            weight = self.visual_trans(visual_embed[:,0,:])
        elif self.constraint == 'first':
            # Both halves of the batch reuse the logits of the first half.
            weight = torch.cat([self.visual_trans(visual_embed[:bs//2][:,0,:])]*2,dim=0)
        elif self.constraint == 'last':
            # Both halves of the batch reuse the logits of the second half.
            weight = torch.cat([self.visual_trans(visual_embed[bs//2:][:,0,:])]*2,dim=0)
        elif self.constraint == 'random':
            # Pick 'first' or 'last' behaviour with equal probability.
            if random.uniform(0, 1) <= 0.5:
                weight = torch.cat([self.visual_trans(visual_embed[:bs//2][:,0,:])]*2,dim=0)
            else:
                weight = torch.cat([self.visual_trans(visual_embed[bs//2:][:,0,:])]*2,dim=0)
        elif self.constraint == 'swap':
            # With probability one half, exchange the logits of the two halves.
            if random.uniform(0, 1) <= 0.5:
                weight = self.visual_trans(visual_embed[:,0,:])
            else:
                tmp_weight = self.visual_trans(visual_embed[:,0,:])
                tmp_weight_1, tmp_weight_2 = torch.chunk(tmp_weight, 2)
                weight = torch.cat([tmp_weight_2, tmp_weight_1], dim=0)
        else:
            # An explicit raise survives ``python -O`` (asserts do not).
            raise NotImplementedError(f"Unsupported constraint: {self.constraint!r}")
        weight = F.softmax(self.softmax_temp * weight, dim=-1)
        # weight: (b, n); bank slice: (n, t, c) -> per-sample mixture (b, t, c).
        motion_adjust = torch.einsum('bn,ntc->btc', weight, self.motion_vector[:,:length,:])
        return motion_adjust


def random_swap_motion_vectors(hidden_states, prob=0.0):
    """Exchange temporally standardized features between the two batch halves.

    Each sample of ``hidden_states`` (shape ``(bsz, nf, nc)``) is
    standardized along dim 1. The standardized sequences are then
    re-assigned across the two halves of the batch and rescaled with the
    receiving sample's own mean and standard deviation.

    Args:
        hidden_states: 3-D tensor ``(bsz, nf, nc)``.
        prob: Mode selector. ``0 < prob <= 1``: swap the halves with
            probability ``prob``. ``prob > 1``: always copy the first half's
            standardized sequences into both halves. Otherwise: no-op.

    Returns:
        The transformed tensor, or ``hidden_states`` itself when nothing
        is changed.
    """
    # The random draw only happens when prob lies in (0, 1].
    do_swap = 0.0 < prob <= 1.0 and random.uniform(0, 1) <= prob
    keep_first = (not do_swap) and prob > 1.0
    if not (do_swap or keep_first):
        return hidden_states

    _bsz, _nf, _nc = hidden_states.shape  # the input must be exactly 3-D
    mean = torch.mean(hidden_states, dim=1, keepdim=True)
    scale = torch.std(hidden_states, dim=1, keepdim=True) + 1e-6
    standardized = (hidden_states - mean) / scale

    front, back = torch.chunk(standardized, 2)
    if do_swap:
        reassigned = torch.cat([back, front], dim=0)
    else:
        reassigned = torch.cat([front, front], dim=0)
    return reassigned * scale + mean


def random_swap_motion_vectors_v1(hidden_states, prob=0.0, percentile=0.9):
    """Exchange standardized features between batch halves on high-variance channels.

    Each sample of ``hidden_states`` (shape ``(bsz, nf, nc)``) is
    standardized along dim 1. Channels whose batch-averaged standard
    deviation exceeds the ``percentile`` quantile are re-assigned across
    the two halves of the batch; all other channels keep their own
    standardized values. The result is rescaled with the receiving
    sample's mean and standard deviation.

    Args:
        hidden_states: 3-D tensor ``(bsz, nf, nc)``.
        prob: Mode selector. ``0 < prob <= 1``: swap the halves with
            probability ``prob``. ``1 < prob <= 2``: copy the first half
            into both halves. ``prob > 2``: copy the second half into both
            halves. Otherwise: no-op.
        percentile: Quantile (in ``[0, 1]``) of the per-channel standard
            deviation above which a channel is exchanged.

    Returns:
        The transformed tensor (same dtype as the input), or
        ``hidden_states`` itself when nothing is changed.
    """
    # Decide the mode first so that the statistics and the quantile are
    # only computed when the tensor is actually going to change.
    if prob > 0.0 and prob <= 1.0:
        if not random.uniform(0, 1) <= prob:
            return hidden_states
        mode = "swap"
    elif prob > 1.0 and prob <= 2.0:
        mode = "first"
    elif prob > 2.0:
        mode = "second"
    else:
        return hidden_states

    feature_mean = torch.mean(hidden_states, dim=1, keepdim=True)
    feature_std = torch.std(hidden_states, dim=1, keepdim=True)
    feature_scale = feature_std + 1e-6

    # The mask is a selection, not a differentiable quantity.
    channel_std = torch.mean(feature_std, dim=0, keepdim=True).detach()
    quantile_level = torch.tensor([percentile], device=channel_std.device)
    std_thre = torch.quantile(channel_std.float(), quantile_level)
    # Cast to the input dtype: a float32 mask would silently promote
    # half-precision inputs to float32 on this path only.
    channel_mask = (channel_std > std_thre).to(hidden_states.dtype)

    motion_vectors = (hidden_states - feature_mean) / feature_scale
    motion_vectors_1, motion_vectors_2 = torch.chunk(motion_vectors, 2)
    if mode == "swap":
        new_motion_vectors = torch.cat([motion_vectors_2, motion_vectors_1], dim=0)
    elif mode == "first":
        new_motion_vectors = torch.cat([motion_vectors_1] * 2, dim=0)
    else:
        new_motion_vectors = torch.cat([motion_vectors_2] * 2, dim=0)
    new_motion_vectors = channel_mask * new_motion_vectors + (1. - channel_mask) * motion_vectors
    return new_motion_vectors * feature_scale + feature_mean


def random_swap_motion_vectors_v2(hidden_states, prob=0.0):
    """Exchange standardized features between batch halves on high-variance rows.

    Each sample of ``hidden_states`` (shape ``(bsz, nf, nc)``) is
    standardized along dim 1. Row ``i`` of the first half is paired with
    row ``i`` of the second half; pairs whose channel-averaged standard
    deviation (averaged over both members) exceeds the median over all
    pairs are re-assigned across the halves, the remaining rows keep their
    own standardized values. The result is rescaled with the receiving
    sample's mean and standard deviation.

    Args:
        hidden_states: 3-D tensor ``(bsz, nf, nc)``.
        prob: Mode selector. ``0 < prob <= 1``: swap the halves with
            probability ``prob``. ``prob > 1``: copy the first half into
            both halves. Otherwise: no-op.

    Returns:
        The transformed tensor (same dtype as the input), or
        ``hidden_states`` itself when nothing is changed.
    """
    # Decide the mode first so that the statistics and the quantile are
    # only computed when the tensor is actually going to change.
    if prob > 0.0 and prob <= 1.0:
        if not random.uniform(0, 1) <= prob:
            return hidden_states
        swap = True
    elif prob > 1.0:
        swap = False
    else:
        return hidden_states

    feature_mean = torch.mean(hidden_states, dim=1, keepdim=True)
    feature_std = torch.std(hidden_states, dim=1, keepdim=True)
    feature_scale = feature_std + 1e-6

    # The mask is a selection, not a differentiable quantity.
    row_std = torch.chunk(torch.mean(feature_std, dim=-1, keepdim=True).detach(), 2, dim=0)
    spatial_std = (row_std[0] + row_std[1]) / 2
    median_level = torch.tensor([.5], device=spatial_std.device)
    std_thre = torch.quantile(spatial_std.float(), median_level)
    # Cast to the input dtype: a float32 mask would silently promote
    # half-precision inputs to float32 on this path only.
    spatial_mask = (spatial_std > std_thre).to(hidden_states.dtype)
    spatial_mask = torch.cat([spatial_mask, spatial_mask], dim=0)

    motion_vectors = (hidden_states - feature_mean) / feature_scale
    motion_vectors_1, motion_vectors_2 = torch.chunk(motion_vectors, 2)
    if swap:
        new_motion_vectors = torch.cat([motion_vectors_2, motion_vectors_1], dim=0)
    else:
        new_motion_vectors = torch.cat([motion_vectors_1] * 2, dim=0)
    new_motion_vectors = spatial_mask * new_motion_vectors + (1. - spatial_mask) * motion_vectors
    return new_motion_vectors * feature_scale + feature_mean


class cross_frame_corr(nn.Module):
    """Parameter-free helper that builds position-to-position cosine-similarity volumes between video frames.

    The constructor only stores configuration flags:

    * ``spatial_size`` / ``temporal_size``: resolution the input is resized to.
    * ``apply_mean_filtering``: smooth every frame with a 3x3 box filter first.
    * ``apply_local_calculation``: keep only a 7x7 neighbourhood per position.
    * ``apply_intra_frame_ref``: correlate every frame with itself.
    * ``apply_adjacent_frame_ref``: correlate frame ``t`` with frame ``t + 1``.
    * ``apply_first_frame_ref``: correlate the first frame with every frame.
    * ``apply_masking``: additionally publish frame-difference magnitudes to
      ``global_utils.corr_mask`` (only when that attribute exists and is None).

    The class defines no ``forward``; callers use ``calculate_corr`` directly.
    """
    def __init__(self, spatial_size=16, temporal_size=16, apply_mean_filtering=False, apply_local_calculation=False, apply_intra_frame_ref=False, apply_adjacent_frame_ref=True, apply_first_frame_ref=False, apply_masking=False):
        super().__init__()
        self.spatial_size = spatial_size
        self.temporal_size = temporal_size
        self.apply_mean_filtering = apply_mean_filtering
        self.apply_local_calculation = apply_local_calculation
        self.apply_intra_frame_ref = apply_intra_frame_ref
        self.apply_adjacent_frame_ref = apply_adjacent_frame_ref
        self.apply_first_frame_ref = apply_first_frame_ref
        self.apply_masking = apply_masking

    def calculate_corr(self, video):
        """Return the correlation volume of ``video``.

        Args:
            video: Tensor laid out as ``(b, c, f, h, w)``.

        Returns:
            Tensor ``(b, F, h*w, h*w)`` where ``F`` is the total number of
            (reference, target) frame pairs contributed by the enabled
            reference modes, in the order intra, adjacent, first. With
            ``apply_local_calculation`` the last dimension is 49 (a 7x7
            window) instead of ``h*w``. Gradients flow only through the
            target side; the reference side is detached.

        Side effects:
            May assign ``global_utils.corr_mask`` (see the class docstring).

        Note:
            At least one of the three reference flags must be enabled;
            otherwise ``torch.cat`` is called on an empty list.
        """
        # video: b c f h w
        c, f, h, w = video.shape[-4], video.shape[-3], video.shape[-2], video.shape[-1]
        # Resize to the configured resolution when any of f/h/w differs.
        if not (self.temporal_size == f and self.spatial_size == h and self.spatial_size == w):
            video = F.interpolate(video, (self.temporal_size, self.spatial_size, self.spatial_size), mode='trilinear')
        if self.apply_mean_filtering:
            # Depthwise 3x3 box filter applied to every frame independently.
            video = rearrange(video, "b c f h w -> (b f) c h w")
            video = F.conv2d(video, 
                        weight=(torch.ones(c, 1, 3, 3) / (3**2)).to(video.device).to(video.dtype), 
                        padding='same',
                        groups=c,
                    )
            video = rearrange(video, "(b f) c h w -> b c f h w", f=self.temporal_size)
        # Masking is active only while global_utils.corr_mask exists and is
        # still None; the same condition guards every mask-related step below.
        if self.apply_masking and hasattr(global_utils, 'corr_mask') and getattr(global_utils, 'corr_mask') is None:
            # L2 norm over channels of the difference between consecutive
            # frames, with an all-zero entry prepended for the first frame.
            temp_frame_diff = torch.norm(video[:,:,:-1,:,:] - video[:,:,1:,:,:], p=2, dim=1, keepdim=True)
            temp_frame_diff = rearrange(temp_frame_diff, "b c f h w -> b f (h w) c")
            temp_frame_diff = torch.cat([torch.zeros_like(temp_frame_diff[:,0,:,:].unsqueeze(1)), temp_frame_diff], dim=1)
        # Reference side: L2-normalised channel vectors per position, one
        # group of frames per enabled mode.
        video_ref = []
        if self.apply_masking and hasattr(global_utils, 'corr_mask') and getattr(global_utils, 'corr_mask') is None:
            frame_diff = []
        if self.apply_intra_frame_ref:
            video_ref.append(F.normalize(rearrange(video, "b c f h w -> b f (h w) c"), dim=-1))
            if self.apply_masking and hasattr(global_utils, 'corr_mask') and getattr(global_utils, 'corr_mask') is None:
                frame_diff.append(temp_frame_diff)
        if self.apply_adjacent_frame_ref:
            # Frames 0 .. f-2 are the references for frames 1 .. f-1.
            video_ref.append(F.normalize(rearrange(video[:,:,:-1,:,:], "b c f h w -> b f (h w) c"), dim=-1))
            if self.apply_masking and hasattr(global_utils, 'corr_mask') and getattr(global_utils, 'corr_mask') is None:
                frame_diff.append(temp_frame_diff[:,1:,:,:])
        if self.apply_first_frame_ref:
            # The first frame, repeated temporal_size times, is the reference for every frame.
            video_ref.append(F.normalize(rearrange(repeat(video[:,:,0,:,:], "b c h w -> b c f h w", f=self.temporal_size), "b c f h w -> b f (h w) c"), dim=-1))
            if self.apply_masking and hasattr(global_utils, 'corr_mask') and getattr(global_utils, 'corr_mask') is None:
                frame_diff.append(temp_frame_diff)
        video_ref = torch.cat(video_ref, dim=1)
        if self.apply_masking and hasattr(global_utils, 'corr_mask') and getattr(global_utils, 'corr_mask') is None:
            # Publish the mask; later calls see a non-None value and skip all
            # mask-related work until it is reset elsewhere.
            frame_diff = torch.cat(frame_diff, dim=1)
            global_utils.corr_mask = frame_diff
        # Target side, built in the same mode order as the reference side so
        # that frame groups line up along dim 1.
        video_tgt = []
        if self.apply_intra_frame_ref:
            video_tgt.append(F.normalize(rearrange(video, "b c f h w -> b f (h w) c"), dim=-1))
        if self.apply_adjacent_frame_ref:
            video_tgt.append(F.normalize(rearrange(video[:,:,1:,:,:], "b c f h w -> b f (h w) c"), dim=-1))
        if self.apply_first_frame_ref:
            video_tgt.append(F.normalize(rearrange(video, "b c f h w -> b f (h w) c"), dim=-1))
        video_tgt = torch.cat(video_tgt, dim=1)
        # Dot product of unit vectors = cosine similarity between reference
        # position m and target position n; the reference side carries no gradient.
        corr = einsum(video_ref.detach(), video_tgt, "b f m c, b f n c -> b f m n") # 4.7: F.softmax(dim=-1)
        if self.apply_local_calculation:
            # Flat index of every reference position, split into row/column.
            mesh_ind = torch.arange(self.spatial_size*self.spatial_size).unsqueeze(-1).unsqueeze(0).unsqueeze(0).repeat(corr.shape[0], corr.shape[1],1,1)
            w_ind = mesh_ind % self.spatial_size
            h_ind = mesh_ind // self.spatial_size
            local_win_size = 7
            # Row and column offsets enumerating the window around a position.
            local_delta_h=torch.arange(-(local_win_size//2),local_win_size//2+1).reshape(local_win_size,1).repeat(1,local_win_size).view(1,1,1,-1).repeat(corr.shape[0], corr.shape[1],1,1)
            local_delta_w=torch.arange(-(local_win_size//2),local_win_size//2+1).reshape(1,local_win_size).repeat(local_win_size,1).view(1,1,1,-1).repeat(corr.shape[0], corr.shape[1],1,1)
            # Neighbours outside the grid are clamped to the border, so border
            # positions read some target entries more than once.
            gather_ind = (torch.clamp(h_ind+local_delta_h,0,self.spatial_size-1) * self.spatial_size + torch.clamp(w_ind+local_delta_w,0,self.spatial_size-1)).to(corr.device)
            # Keep, for every reference position, only its window of targets.
            corr = torch.gather(corr, 3, gather_ind)
        return corr



class MaskPredictor(nn.Module):
    """Small fully-convolutional network that maps a feature map to a one-channel map.

    Four 3x3 convolutions with padding 1 (so height and width are
    preserved) change the channel count ``in_channels -> 8 -> 16 -> 8 -> 1``,
    with a ReLU after each of the first three. The output is left
    unactivated.

    Args:
        in_channels: Channel count of the input.
        out_channels: Accepted for interface compatibility but not used;
            the final convolution always produces a single channel.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        same_3x3 = dict(kernel_size=3, padding=1)

        # Encoder: widen the channel dimension.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 8, **same_3x3),
            nn.ReLU(inplace=True),
            nn.Conv2d(8, 16, **same_3x3),
            nn.ReLU(inplace=True),
        )

        # Decoder: narrow back down to a single output channel.
        self.decoder = nn.Sequential(
            nn.Conv2d(16, 8, **same_3x3),
            nn.ReLU(inplace=True),
            nn.Conv2d(8, 1, **same_3x3),
        )

    def forward(self, x):
        """Return the one-channel prediction for ``x`` of shape ``(b, in_channels, h, w)``."""
        return self.decoder(self.encoder(x))