import sys

import torch
from torch import nn
# sys.path.append("../")

#  EMA ， nn.Module。
class EMA(nn.Module):
    def __init__(self, channels, c2=None, factor=32):
        # 。
        super(EMA, self).__init__()
        # ， factor。
        self.groups = factor
        # 。
        assert channels // self.groups > 0
        #  softmax ，。
        self.softmax = nn.Softmax(-1)
        # ， (1,1) 。
        self.agp = nn.AdaptiveAvgPool2d((1, 1))
        # ， 1，。
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        # ， 1，。
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        # ， channels // groups。
        self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)
        # 1x1 ，。
        self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)
        # 3x3 ，。
        self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        #  x  (batch_size, channels, height, width)。
        b, c, h, w = x.size()
        #  x  (batch_size * groups, channels // groups, height, width)。
        group_x = x.reshape(b * self.groups, -1, h, w)
        # ， (batch_size * groups, channels // groups, height, 1) 。
        x_h = self.pool_h(group_x)
        # ， (batch_size * groups, channels // groups, 1, width) ，。
        x_w = self.pool_w(group_x).permute(0, 1, 3, 2)
        # ， 1x1 ，。
        hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))
        # ， x_h  x_w。
        x_h, x_w = torch.split(hw, [h, w], dim=2)
        #  sigmoid  x_h, x_w  group_x 。
        x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())
        #  3x3  group_x 。
        x2 = self.conv3x3(group_x)
        #  x1  softmax，。
        x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        #  x2  (batch_size * groups, channels // groups, height * width) 。
        x12 = x2.reshape(b * self.groups, c // self.groups, -1)
        #  x2  softmax，。
        x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        #  x1  (batch_size * groups, channels // groups, height * width) 。
        x22 = x1.reshape(b * self.groups, c // self.groups, -1)
        #  x11  x12, x21  x22 ， reshape  (batch_size * groups, 1, height, width)。
        weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)
        #  group_x ， reshape  (batch_size, channels, height, width)。
        return (group_x * weights.sigmoid()).reshape(b, c, h, w)