import math
import numpy as np

import torch.nn as nn

    
class Heterogeneity_reflected_Convolution(nn.Module):
    """Stack of dim->dim residual conv blocks applied to every (b, l) slice.

    Each entry of ``kernel`` creates one block: a ``HeteroConvResBlock`` when
    the matching ``hetero_flag`` entry is truthy, otherwise a plain
    ``ConvResBlock``. Every block except the last adds an identity shortcut;
    the last block is called with ``res_flag=False``.
    """

    def __init__(self, dim, kernel, hetero_flag):
        """
        Args:
            dim: channel count of the input's last axis and of every block.
            kernel: sequence of conv kernel sizes, one block per entry.
            hetero_flag: sequence indexed in parallel with ``kernel`` (must be
                at least as long); a truthy entry selects ``HeteroConvResBlock``
                for that position.
        """
        super().__init__()
        self.hetero_flag = hetero_flag
        # Blocks are built in kernel order, so parameter init order is unchanged.
        self.blocks = nn.ModuleList(
            [
                HeteroConvResBlock(dim, kernel=k) if hetero_flag[idx]
                else ConvResBlock(dim, kernel=k)
                for idx, k in enumerate(kernel)
            ]
        )

    def forward(self, x, infra):
        """Run all blocks on ``x``, folding the (B, L) axes into the batch axis.

        Args:
            x: 5-D tensor unpacked as (B, L, H, W, C); C must equal ``dim``.
                It does not need to be contiguous.
            infra: tensor that ``expand`` can broadcast to (H, W, C, B, L).
                Presumably shape (B, L) with 0/1 entries (TODO confirm with
                callers). Only hetero blocks read it.

        Returns:
            Tensor of shape (B, L, H', W', dim), where (H', W') is the spatial
            size of the last block's output. It equals (H, W) when the blocks
            preserve spatial size.
        """
        B, L, H, W, C = x.shape

        # (B, L, H, W, C) -> (B*L, C, H, W). reshape (not view) so that
        # non-contiguous inputs are copied instead of raising.
        x = x.permute(0, 1, 4, 2, 3).reshape(B * L, C, H, W)
        # Broadcast infra over H, W, C and lay it out exactly like x.
        infra_ = infra.expand(H, W, C, B, L).permute(3, 4, 2, 0, 1).reshape(B * L, C, H, W)

        last = len(self.blocks) - 1
        for idx, block in enumerate(self.blocks):
            res_flag = idx != last  # no shortcut on the final block
            if self.hetero_flag[idx]:
                x = block(x, infra_, res_flag)
            else:
                x = block(x, res_flag)

        # Use the real output spatial size so a size-changing last block
        # cannot be silently scrambled into the input's (H, W).
        out_h, out_w = x.shape[-2], x.shape[-1]
        return x.reshape(B, L, -1, out_h, out_w).permute(0, 1, 3, 4, 2)

class ConvResBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> optional identity shortcut -> ReLU.

    Input and output both carry ``dim`` channels. The padding is
    ``kernel // 2``, so with the default stride and dilation the spatial size
    is kept for odd kernels and grows by one for even kernels. With an even
    kernel the shortcut add would fail on the size mismatch.
    """

    def __init__(self, dim, kernel):
        super().__init__()
        pad = kernel // 2
        # Registration order (conv, norm, act) matches state_dict key order.
        self.conv = nn.Conv2d(dim, dim, kernel, padding=pad)
        self.norm = nn.BatchNorm2d(dim)
        self.act = nn.ReLU()

    def forward(self, x, res_flag=True):
        """Apply the block; add ``x`` after normalization when ``res_flag``."""
        y = self.norm(self.conv(x))
        if res_flag:
            y = y + x
        return self.act(y)
    
class HeteroConvResBlock(nn.Module):
    """Two-branch conv block whose branches are selected by a mask tensor.

    ``i_conv`` sees ``x`` with positions where ``infra_ == 0`` zeroed.
    ``v_conv`` sees ``x`` with positions where ``infra_ == 1`` zeroed. The two
    branch outputs are summed, batch-normalized, optionally added to ``x``,
    and passed through ReLU. When ``infra_`` holds only 0/1 values the
    branches see disjoint positions; any other value reaches both branches.
    """

    def __init__(self, dim, kernel):
        super().__init__()
        pad = kernel // 2
        # Neither conv has a bias; the following BatchNorm supplies the shift.
        self.i_conv = nn.Conv2d(dim, dim, kernel, padding=pad, bias=False)
        self.v_conv = nn.Conv2d(dim, dim, kernel, padding=pad, bias=False)

        self.norm = nn.BatchNorm2d(dim)
        self.act = nn.ReLU()

    def forward(self, x, infra_, res_flag=True):
        """
        Args:
            x: input of shape (N, dim, H, W).
            infra_: tensor broadcastable to ``x``, compared against 0 and 1.
            res_flag: if true, add ``x`` after normalization.
        """
        keep_for_i = x.masked_fill(infra_ == 0, 0)
        keep_for_v = x.masked_fill(infra_ == 1, 0)

        mixed = self.i_conv(keep_for_i) + self.v_conv(keep_for_v)
        normed = self.norm(mixed)
        if res_flag:
            normed = normed + x
        return self.act(normed)