import torch
import torch.nn as nn
from einops import rearrange
    
class Heterogeneity_reflected_Convolution_V2(nn.Module):
    """Gate stacked feature maps with a per-slice spatial gate and a shared gate.

    ``forward`` takes features of shape ``(B, L, H, W, C)`` and applies two
    sigmoid gates in sequence:

    1. a per-slice spatial gate, built from the channel-wise max and mean of
       the (optionally overlap-masked) features of that slice and convolved
       by ``ho_c_layer``;
    2. a gate shared by all ``L`` slices, built from the max and mean over
       the ``L`` axis and convolved by ``ho_a_layer``.

    The gated features go through ReLU and are returned as ``(B, L, H, W, C)``.
    (The ``L`` axis presumably indexes agents/sources — TODO confirm with caller.)

    Args:
        dim: Channel count ``C`` of the input features.
        output_m: Stored on ``self.output_m``; not used by this module.
        ablation: If true, skip the overlap mask and compute the spatial gate
            from the unmasked features.
        conv: Unused; accepted for constructor compatibility.
    """

    def __init__(self, dim, output_m, ablation, conv):
        super().__init__()

        self.output_m = output_m
        self.ablation = ablation

        # Spatial gate: 2 maps (channel max, channel mean) -> 1 gate map.
        self.ho_c_layer = nn.Sequential(
            nn.Conv2d(2, 1, 3, 1, 1, bias=False),
            nn.BatchNorm2d(1),
        )

        # Shared gate: max and mean over L, concatenated (2 * dim) -> dim maps.
        self.ho_a_layer = nn.Sequential(
            nn.Conv2d(dim * 2, dim, 3, 1, 1, bias=False),
            nn.BatchNorm2d(dim),
        )

        self.act = nn.ReLU()
        self.sig = nn.Sigmoid()

    def generate_s_mask(self, infra, regroup_s_mask):
        """Split the coverage masks into heterogeneous / homogeneous / other regions.

        Args:
            infra: Unused (the heterogeneous-overlap logic that consumed it is
                disabled); kept for interface compatibility.
            regroup_s_mask: Tensor of shape ``(B, L, H, W)``; any nonzero entry
                counts as covered.

        Returns:
            ``(he_o, ho_o, not_o)``, boolean tensors of shape ``(B, L, H, W)``
            on the mask's device:

            - ``he_o`` is all False (its computation is disabled);
            - ``ho_o[b, l]`` is True where slice ``l`` is covered AND at least
              one other slice of the same sample covers the same pixel;
            - ``not_o`` is ``~he_o & ~ho_o``.

        Raises:
            ValueError: If ``regroup_s_mask`` is not 4-D.
        """
        if regroup_s_mask.dim() != 4:
            raise ValueError(
                f"regroup_s_mask must be (B, L, H, W), got shape {tuple(regroup_s_mask.shape)}"
            )

        covered = regroup_s_mask != 0
        # OR-ing covered[l] & covered[i] over every pair l != i is the same as
        # "l is covered and at least two slices cover this pixel".
        n_covered = covered.sum(dim=1, keepdim=True)
        ho_o = covered & (n_covered >= 2)
        he_o = torch.zeros_like(covered)
        not_o = torch.logical_and(~he_o, ~ho_o)

        return he_o, ho_o, not_o

    def forward(self, x, infra, regroup_s_mask=None):
        """Apply the per-slice spatial gate, then the shared gate, then ReLU.

        Args:
            x: Features of shape ``(B, L, H, W, C)`` with ``C == dim``.
            infra: Forwarded to ``generate_s_mask``, which ignores it.
            regroup_s_mask: ``(B, L, H, W)`` coverage mask whose ``H, W`` match
                ``x``. Required unless ``self.ablation`` is true.

        Returns:
            Non-negative tensor of shape ``(B, L, H, W, C)``.

        Raises:
            ValueError: If the mask is missing while ablation is disabled.
        """
        B, L, H, W, C = x.shape

        # B, L, H, W, C -> B, L, C, H, W
        x = x.permute(0, 1, 4, 2, 3)

        if self.ablation:
            ho_x = x
        else:
            if regroup_s_mask is None:
                raise ValueError("regroup_s_mask is required when ablation is disabled")
            _, ho_o, _ = self.generate_s_mask(infra, regroup_s_mask)
            # (B, L, H, W) -> (B, L, 1, H, W); masked_fill broadcasts over C.
            ho_x = x.masked_fill(~ho_o.unsqueeze(2), 0)

        # Per-slice spatial gate. Stacking on dim=2 gives (B, L, 2, H, W), so
        # after the view each (b, l) row holds its OWN [max, mean] pair.
        c_in = torch.stack((torch.max(ho_x, dim=2).values, torch.mean(ho_x, dim=2)), dim=2)
        c_gate = self.sig(self.ho_c_layer(c_in.view(B * L, 2, H, W)))
        out = x * c_gate.view(B, L, 1, H, W)

        # Shared gate: channel layout [max_0..max_{C-1}, mean_0..mean_{C-1}].
        a_in = torch.cat((torch.max(out, dim=1).values, torch.mean(out, dim=1)), dim=1)
        a_gate = self.sig(self.ho_a_layer(a_in))
        out = out * a_gate.unsqueeze(1)

        out = self.act(out)
        # B, L, C, H, W -> B, L, H, W, C
        return out.permute(0, 1, 3, 4, 2)