from torch import nn
import einops
import torch

# Public API of this module: only the model factory function is exported.
__all__ = ['gsca_mobilenet_v2']

# Mapping intended to hold download URLs for pretrained checkpoints.
# It is currently empty, so no checkpoint can be looked up through it.
model_urls = {

}


class Crossattention(nn.Module):
    """Parameter-free channel-group attention based on a per-position correlation.

    The channels of the input are split into ``heads`` equally sized groups.
    Inside each group, the channel vector found at every spatial position is
    compared with the group's globally average-pooled channel vector using
    the Pearson correlation coefficient (computed across the group's
    channels). The correlation ``r`` is turned into the gate
    ``(1 - sigmoid(r)) ** 2``, which scales every channel of that group at
    that position.

    Args:
        heads: Number of channel groups. The channel count of the input must
            be divisible by it.
        eps: Lower bound applied to each sum of squared deviations before the
            square root is taken. It prevents the 0/0 division (NaN output)
            and the NaN gradient of ``sqrt`` at zero when a channel vector
            has no variance. Results are unchanged whenever both sums of
            squares exceed ``eps``.
    """

    def __init__(self, heads=1, eps=1e-12):
        super(Crossattention, self).__init__()
        self.heads = heads
        self.eps = eps
        self.act = nn.Sigmoid()

    def forward(self, x):
        """Apply the correlation gate to a ``(batch, channels, height, width)`` tensor.

        Returns:
            A tensor with the same shape as ``x``.
        """
        batch, channels, height, width = x.shape
        # Split channels into (heads, channels_per_head); ``heads`` is the
        # outer factor, i.e. consecutive channels belong to the same head.
        grouped = x.reshape(batch, self.heads, channels // self.heads, height, width)

        # Deviation of every position's channel vector from its own mean.
        centered = grouped - grouped.mean(dim=2, keepdim=True)
        # sqrt(sum_i (x_i - mean(x))^2), floored to stay away from zero.
        norm = torch.sqrt(centered.pow(2).sum(dim=2, keepdim=True).clamp_min(self.eps))

        # Reference vector: spatial average of each channel, centred the same way.
        pooled = grouped.mean([-2, -1], keepdim=True)
        pooled_centered = pooled - pooled.mean(dim=2, keepdim=True)
        pooled_norm = torch.sqrt(
            pooled_centered.pow(2).sum(dim=2, keepdim=True).clamp_min(self.eps)
        )

        # Pearson correlation between each position and the pooled reference.
        covariance = (centered * pooled_centered).sum(dim=2, keepdim=True)
        correlation = covariance / (norm * pooled_norm)

        # Gate decreases monotonically as the correlation grows.
        gate = (1 - self.act(correlation)) ** 2

        gated = grouped * gate
        # Merge (heads, channels_per_head) back into a single channel axis.
        return gated.reshape(batch, channels, height, width)


class ConvBNReLU(nn.Sequential):
    """Bias-free convolution followed by batch normalisation and ReLU6.

    The three layers are registered, in that order, as children ``0``, ``1``
    and ``2`` of the ``nn.Sequential``. The convolution is padded by
    ``(kernel_size - 1) // 2`` on every side.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        kernel_size: Side length of the square convolution kernel.
        stride: Convolution stride.
        groups: Number of convolution groups.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            groups=groups,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_planes)
        activation = nn.ReLU6(inplace=True)
        super().__init__(conv, norm, activation)


class Atten_ConvBNReLU(nn.Sequential):
    """Bias-free convolution, batch norm, Crossattention gate, then ReLU6.

    The four layers are registered, in that order, as children ``0`` to ``3``
    of the ``nn.Sequential``. The convolution is padded by
    ``(kernel_size - 1) // 2`` on every side.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        kernel_size: Side length of the square convolution kernel.
        stride: Convolution stride.
        groups: Number of convolution groups.
        atten_heads: Number of heads handed to ``Crossattention``.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, atten_heads=1):
        conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            groups=groups,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_planes)
        attention = Crossattention(heads=atten_heads)
        activation = nn.ReLU6(inplace=True)
        super().__init__(conv, norm, attention, activation)


class InvertedResidual(nn.Module):
    """Inverted residual block whose depthwise stage carries an attention gate.

    Layer sequence stored in ``self.conv``:

    1. (only when ``expand_ratio != 1``) 1x1 ``ConvBNReLU`` expanding ``inp``
       to ``hidden_dim = round(inp * expand_ratio)`` channels,
    2. 3x3 depthwise ``Atten_ConvBNReLU`` with the requested stride,
    3. 1x1 convolution without bias projecting to ``oup`` channels,
    4. batch normalisation (no activation afterwards).

    The input is added to the output when ``stride == 1`` and ``inp == oup``.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Stride of the depthwise convolution; must be 1 or 2.
        expand_ratio: Multiplier applied to ``inp`` to get the hidden width.
        atten_heads: Number of attention heads; must divide the hidden width.

    Raises:
        ValueError: If ``stride`` is not 1 or 2, or if the hidden width is
            not divisible by ``atten_heads``.
    """

    def __init__(self, inp, oup, stride, expand_ratio, atten_heads):
        super(InvertedResidual, self).__init__()
        # Explicit exception instead of ``assert`` so the check survives ``python -O``.
        if stride not in (1, 2):
            raise ValueError(f"stride should be 1 or 2 instead of {stride}")
        self.stride = stride

        hidden_dim = int(round(inp * expand_ratio))
        # The attention layer splits the hidden channels evenly across heads;
        # fail at construction time rather than on the first forward pass.
        if hidden_dim % atten_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by "
                f"atten_heads ({atten_heads})"
            )
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw: pointwise expansion
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            # dw: depthwise convolution with attention
            Atten_ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, atten_heads=atten_heads),
            # pw-linear: pointwise projection without activation
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Run the block, adding the skip connection when it is enabled."""
        if self.use_res_connect:
            return x + self.conv(x)
        return self.conv(x)


class GSCA_MobileNetV2(nn.Module):
    """MobileNetV2-style classifier built from attention-gated inverted residuals.

    ``self.features`` is a stride-2 3x3 ``ConvBNReLU`` stem on 3 input
    channels, followed by seven stages of ``InvertedResidual`` blocks and a
    final 1x1 ``ConvBNReLU``. ``self.classifier`` is dropout followed by a
    linear layer.

    Args:
        num_classes: Size of the classifier output.
        width_mult: Multiplier applied to the channel count of the stem and
            of every stage (truncated with ``int``). The final feature width
            is only scaled when the multiplier is greater than 1.
    """

    def __init__(self, num_classes=1000, width_mult=1.0):
        super().__init__()

        # One row per stage:
        # (expansion t, output channels c, repeats n, first stride s, attention heads)
        stage_specs = (
            (1, 16, 1, 1, 1),
            (6, 24, 2, 2, 1),
            (6, 32, 3, 2, 2),
            (6, 64, 4, 2, 3),
            (6, 96, 3, 1, 4),
            (6, 160, 3, 2, 8),
            (6, 320, 1, 1, 15),
        )

        in_channels = int(32 * width_mult)
        self.last_channel = int(1280 * max(1.0, width_mult))

        # Stem convolution.
        layers = [ConvBNReLU(3, in_channels, stride=2)]

        # Inverted residual stages; only the first block of a stage may stride.
        for expand, base_width, repeats, first_stride, heads in stage_specs:
            out_channels = int(base_width * width_mult)
            for block_idx in range(repeats):
                block_stride = first_stride if block_idx == 0 else 1
                layers.append(
                    InvertedResidual(
                        in_channels,
                        out_channels,
                        block_stride,
                        expand_ratio=expand,
                        atten_heads=heads,
                    )
                )
                in_channels = out_channels

        # Final pointwise convolution up to the classifier width.
        layers.append(ConvBNReLU(in_channels, self.last_channel, kernel_size=1))
        self.features = nn.Sequential(*layers)

        # Classification head.
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, num_classes),
        )

        # Weight initialisation, applied to every submodule in traversal order.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x):
        """Extract features, average over both spatial axes, and classify."""
        feature_map = self.features(x)
        # Average the last axis twice: width first, then height.
        pooled = feature_map.mean(-1).mean(-1)
        return self.classifier(pooled)


def gsca_mobilenet_v2(pretrained=False, progress=True, **kwargs):
    """Construct a ``GSCA_MobileNetV2`` model.

    Args:
        pretrained (bool): Kept for API compatibility. This module has no
            pretrained checkpoint wired up, so passing ``True`` only emits a
            ``UserWarning``; the returned model is still freshly initialised.
        progress (bool): Currently unused, because nothing is downloaded.
        **kwargs: Forwarded unchanged to ``GSCA_MobileNetV2``.

    Returns:
        The newly constructed ``GSCA_MobileNetV2`` instance.
    """
    if pretrained:
        import warnings

        # Do not let callers believe they received trained weights.
        warnings.warn(
            "gsca_mobilenet_v2: no pretrained weights are available; "
            "returning a randomly initialised model.",
            UserWarning,
            stacklevel=2,
        )
    return GSCA_MobileNetV2(**kwargs)


if __name__ == '__main__':
    # Manual smoke test: build the network, print it, and report its
    # complexity using the third-party ``thop`` profiler.
    from thop import profile, clever_format
    import torch

    model = gsca_mobilenet_v2()
    print(model)

    dummy_input = torch.randn(1, 3, 224, 224)
    raw_macs, raw_params = profile(model, inputs=(dummy_input,))
    print(model(dummy_input).shape)

    # Figures noted by the original author (not re-verified): 16865803776.0 3206976.0
    print('macs:', raw_macs, 'params:', raw_params)
    print('--------')

    # Figures noted by the original author (not re-verified): 16.866G 3.207M
    pretty_macs, pretty_params = clever_format([raw_macs, raw_params], "%.3f")
    print('macs:', pretty_macs, 'params:', pretty_params)




