import torch
import torch.nn as nn
import torch.nn.functional as F

class EMA3D(nn.Module):
    """Grouped attention gate for 5-D ``(batch, channels, depth, height, width)`` inputs.

    The channel axis is split into ``factor`` groups that are folded into the
    batch axis.  Two branches are then computed per group:

    * a *directional* branch: the group is average-pooled along each of the
      three spatial axes, the three pooled profiles are concatenated and passed
      through one shared 1x1x1 convolution, and the resulting per-axis profiles
      gate the group through sigmoids before a ``GroupNorm``;
    * a *local* branch: a 3x3x3 convolution of the group.

    Each branch is globally average-pooled and soft-maxed over channels, and
    that channel distribution is used to mix the *other* branch's flattened
    feature map.  The sum of the two mixes, squashed by a sigmoid, is a
    single-channel spatial gate that multiplies the grouped input.

    NOTE(review): the structure looks like a 3-D port of the 2-D "EMA"
    attention block -- inferred from the class name, confirm with the author.

    Args:
        channels: Number of channels of the tensors passed to ``forward``.
            Must be a positive multiple of ``factor``.
        factor: Number of channel groups.

    Raises:
        ValueError: If ``factor`` is not positive, or ``channels`` is not a
            positive multiple of ``factor``.  (Earlier revisions relied on an
            ``assert`` that is stripped under ``python -O`` and did not check
            divisibility, so a bad size only surfaced later as a reshape
            error inside ``forward``.)
    """

    def __init__(self, channels: int, factor: int = 32) -> None:
        super().__init__()
        if factor <= 0:
            raise ValueError(f"factor must be a positive integer, got {factor}")
        if channels <= 0 or channels % factor != 0:
            raise ValueError(
                f"channels ({channels}) must be a positive multiple of "
                f"factor ({factor})"
            )

        self.groups = factor
        group_channels = channels // factor

        self.softmax = nn.Softmax(-1)
        self.agp = nn.AdaptiveAvgPool3d((1, 1, 1))
        # ``None`` keeps that axis at its input size, so each pool averages
        # over the two *other* spatial axes.
        self.pool_d = nn.AdaptiveAvgPool3d((None, 1, 1))
        self.pool_h = nn.AdaptiveAvgPool3d((1, None, 1))
        self.pool_w = nn.AdaptiveAvgPool3d((1, 1, None))

        # num_groups == num_channels: every channel is normalised on its own.
        self.gn = nn.GroupNorm(group_channels, group_channels)
        # Attribute names and construction order are kept as-is so existing
        # checkpoints load and seeded initialisation is unchanged.
        self.conv1x1x1 = nn.Conv3d(group_channels, group_channels, kernel_size=1, stride=1, padding=0)
        self.conv3x3x3 = nn.Conv3d(group_channels, group_channels, kernel_size=3, stride=1, padding=1)

    def _channel_distribution(self, feat: torch.Tensor) -> torch.Tensor:
        """Return softmax over channels of the globally pooled ``feat`` as ``(N, 1, C)``."""
        pooled = self.agp(feat)  # (N, C, 1, 1, 1)
        return self.softmax(pooled.reshape(pooled.size(0), 1, -1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention gate.

        Args:
            x: Tensor of shape ``(b, c, d, h, w)``; ``c`` must be a multiple
                of ``self.groups``, otherwise the grouping reshape fails.

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, c, d, h, w = x.size()
        bg = b * self.groups
        cg = c // self.groups

        # Fold the channel groups into the batch axis: (bg, cg, d, h, w).
        group_x = x.reshape(bg, -1, d, h, w)

        # Per-axis average profiles, each moved onto axis 2 so the three can
        # be concatenated and share one 1x1x1 convolution.
        x_d = self.pool_d(group_x)                           # (bg, cg, d, 1, 1)
        x_h = self.pool_h(group_x).permute(0, 1, 3, 2, 4)    # (bg, cg, h, 1, 1)
        x_w = self.pool_w(group_x).permute(0, 1, 4, 2, 3)    # (bg, cg, w, 1, 1)
        dhw = self.conv1x1x1(torch.cat([x_d, x_h, x_w], dim=2))

        # Split back and return every profile to its own spatial axis so the
        # three gates broadcast against group_x.
        x_d, x_h, x_w = torch.split(dhw, [d, h, w], dim=2)
        x_h = x_h.permute(0, 1, 3, 2, 4)                     # (bg, cg, 1, h, 1)
        x_w = x_w.permute(0, 1, 3, 4, 2)                     # (bg, cg, 1, 1, w)

        # Directional branch and local 3x3x3 branch.
        x1 = self.gn(group_x * x_d.sigmoid() * x_h.sigmoid() * x_w.sigmoid())
        x2 = self.conv3x3x3(group_x)

        # Cross mixing: each branch's channel distribution weights the other
        # branch's flattened feature map -> (bg, 1, d*h*w).
        x11 = self._channel_distribution(x1)
        x12 = x2.reshape(bg, cg, -1)
        x21 = self._channel_distribution(x2)
        x22 = x1.reshape(bg, cg, -1)
        weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(bg, 1, d, h, w)

        return (group_x * weights.sigmoid()).reshape(b, c, d, h, w)
