import torch.nn as nn
    
class Spacial_wise_Conv(nn.Module):
    """Apply a stack of ``ConvResBlock`` layers to every (H, W) frame independently.

    The input is laid out channels-last as ``(B, L, H, W, C)``. The B and L axes
    are folded into one batch axis so that 2-D convolutions run over each
    ``(H, W)`` slice. The result is unfolded back to a channels-last layout.

    Args:
        dim: Channel count passed to every ``ConvResBlock``.
        kernel: Sequence of kernel sizes, one per block; its length sets the
            number of blocks.
        res_flag: Sequence of residual flags indexed in parallel with
            ``kernel`` (must be at least as long as ``kernel``).
    """

    def __init__(self, dim, kernel, res_flag):
        super(Spacial_wise_Conv, self).__init__()
        self.blocks = nn.ModuleList(
            [
                ConvResBlock(dim, kernel=k, res_flag=res_flag[i])
                for i, k in enumerate(kernel)
            ]
        )

    def forward(self, x):
        """Run the block stack over ``x`` of shape ``(B, L, H, W, C)``.

        Returns:
            Tensor of shape ``(B, L, H', W', C')``, where ``(C', H', W')`` is the
            per-frame output shape of the last block (equal to ``(C, H, W)``
            when the blocks preserve shape, or when there are no blocks).
        """
        B, L, H, W, C = x.shape

        # (B, L, H, W, C) -> (B*L, C, H, W). Use reshape, not view: view fails
        # when B and L are not mergeable in memory (e.g. after a transpose).
        # An explicit B * L also avoids -1 inference failing on empty tensors.
        x = x.permute(0, 1, 4, 2, 3).reshape(B * L, C, H, W)
        for block in self.blocks:
            x = block(x)

        # Unfold using the actual output shape instead of assuming (H, W) is
        # preserved, so a shape change can never be silently folded into channels.
        _, c_out, h_out, w_out = x.shape
        return x.reshape(B, L, c_out, h_out, w_out).permute(0, 1, 3, 4, 2)

class ConvResBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> (optional residual add) -> ReLU, preserving (H, W).

    Args:
        dim: Number of input and output channels.
        kernel: Kernel size, either an int or a 2-element tuple/list. Every
            entry must be odd so that symmetric ``k // 2`` padding keeps the
            spatial size unchanged.
        res_flag: If truthy, add the block input to the normalized conv output
            before the activation.

    Raises:
        ValueError: If any kernel dimension is even. Such a kernel grows the
            output by one pixel per dimension, which breaks the residual add.
    """

    def __init__(self, dim, kernel, res_flag):
        super(ConvResBlock, self).__init__()
        self.res_flag = res_flag

        kernel_dims = tuple(kernel) if isinstance(kernel, (tuple, list)) else (kernel, kernel)
        if any(k % 2 == 0 for k in kernel_dims):
            raise ValueError(
                f"ConvResBlock requires odd kernel sizes to preserve spatial size, got {kernel!r}"
            )

        # Per-dimension "same" padding; also supports non-square kernels.
        self.conv = nn.Conv2d(dim, dim, kernel, padding=tuple(k // 2 for k in kernel_dims))
        self.norm = nn.BatchNorm2d(dim)
        self.act = nn.ReLU()

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(N, dim, H, W)``; output has the same shape."""
        out = self.norm(self.conv(x))
        if self.res_flag:
            out = out + x
        return self.act(out)
