from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv2d
# from models.dcn import DeformableConv2d as Conv2d


class FFCSE_block(nn.Module):
    """Squeeze-and-excitation gate for a (local, global) feature pair.

    The channels are split into ``int(channels * ratio_g)`` global channels
    and the remainder as local channels. Both branches are pooled together to
    a 1x1 descriptor, reduced by a fixed factor of 16, and each existing
    branch is rescaled by its own sigmoid gate.

    Args:
        channels: total number of channels (local + global).
        ratio_g: fraction of ``channels`` that belongs to the global branch.
    """

    def __init__(self, channels, ratio_g):
        super(FFCSE_block, self).__init__()
        in_cg = int(channels * ratio_g)
        in_cl = channels - in_cg
        r = 16  # channel reduction factor of the bottleneck

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.conv1 = nn.Conv2d(channels, channels // r,
                               kernel_size=1, bias=True)
        self.relu1 = nn.ReLU(inplace=True)
        # A branch with zero channels gets no gate; forward() returns 0 for it.
        self.conv_a2l = None if in_cl == 0 else nn.Conv2d(
            channels // r, in_cl, kernel_size=1, bias=True)
        self.conv_a2g = None if in_cg == 0 else nn.Conv2d(
            channels // r, in_cg, kernel_size=1, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Gate each branch of ``x``.

        Args:
            x: either a single tensor (treated as the local branch) or a
                ``(local, global)`` tuple in which a missing branch is given
                as a non-tensor placeholder such as ``0``.

        Returns:
            ``(x_l, x_g)``; a branch without a gate convolution is returned
            as the integer ``0``.
        """
        id_l, id_g = x if isinstance(x, tuple) else (x, 0)

        # Squeeze over whichever branches actually carry a tensor. Previously
        # a global-only pair (0, x_g) was passed to torch.cat and crashed.
        if not torch.is_tensor(id_g):
            squeezed = id_l
        elif not torch.is_tensor(id_l):
            squeezed = id_g
        else:
            squeezed = torch.cat([id_l, id_g], dim=1)

        squeezed = self.avgpool(squeezed)
        squeezed = self.relu1(self.conv1(squeezed))

        x_l = 0 if self.conv_a2l is None else id_l * \
            self.sigmoid(self.conv_a2l(squeezed))
        x_g = 0 if self.conv_a2g is None else id_g * \
            self.sigmoid(self.conv_a2g(squeezed))
        return x_l, x_g


class FourierUnit(nn.Module):
    """1x1 convolution + BatchNorm + ReLU applied in the 2-D Fourier domain.

    The input is transformed with a 2-D FFT, its real and imaginary parts are
    interleaved on the channel axis, processed point-wise, and transformed
    back to a real tensor with the input's spatial size.

    Args:
        in_channels: channels of the (real) input tensor.
        out_channels: channels of the (real) output tensor.
        groups: ``groups`` argument of the 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels, groups=1):
        # bn_layer not used
        super().__init__()
        self.groups = groups
        # Real and imaginary parts share the channel axis, hence the factor 2.
        self.conv_layer = torch.nn.Conv2d(
            in_channels=in_channels * 2,
            out_channels=out_channels * 2,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=self.groups,
            bias=False,
        )
        self.bn = torch.nn.BatchNorm2d(out_channels * 2)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the spectrally filtered version of a 4-D tensor ``x``."""
        batch, _, height, width = x.size()

        # NOTE(review): fft2 yields the full (h, w) spectrum, whereas irfft2
        # below is documented to trim its input to w // 2 + 1 columns. The
        # legacy call this replaced was torch.rfft; confirm whether rfft2 was
        # intended before changing it (BatchNorm batch statistics would shift).
        spectrum = torch.fft.fft2(x, dim=(-2, -1), norm="forward")

        # Convert the complex tensor to a real array:
        # (batch, c, 2, h, w) -> (batch, 2 * c, h, w), real/imag interleaved.
        feats = torch.stack((spectrum.real, spectrum.imag), dim=2)
        feats = feats.reshape(batch, -1, height, width)

        feats = self.conv_layer(feats)
        feats = self.bn(feats)
        feats = self.relu(feats)

        # Undo the interleaving: (batch, 2 * c, h, w) -> (batch, c, 2, h, w).
        feats = feats.view(batch, -1, 2, height, width)
        filtered = torch.complex(feats[:, :, 0], feats[:, :, 1])

        return torch.fft.irfft2(
            filtered, s=(height, width), dim=(-2, -1), norm="forward")


class SpectralTransform(nn.Module):
    """Spectral branch: 1x1 reduce -> Fourier unit(s) -> 1x1 expand.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor; the Fourier units work
            on ``out_channels // 2`` channels.
        stride: ``2`` inserts a 2x2 average pooling in front, any other value
            leaves the resolution untouched.
        groups: ``groups`` argument forwarded to the convolutions.
        enable_lfu: also run a second Fourier unit on a 2x2 spatial tiling of
            the first quarter of the reduced channels.
    """

    def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True):
        # bn_layer not used
        super(SpectralTransform, self).__init__()
        self.enable_lfu = enable_lfu
        if stride == 2:
            self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        else:
            self.downsample = nn.Identity()

        self.stride = stride
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False),
            nn.BatchNorm2d(out_channels // 2),
            nn.ReLU(inplace=True)
        )
        self.fu = FourierUnit(
            out_channels // 2, out_channels // 2, groups)
        if self.enable_lfu:
            self.lfu = FourierUnit(
                out_channels // 2, out_channels // 2, groups)
        self.conv2 = torch.nn.Conv2d(
            out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)

    def _local_fourier(self, x):
        """Run the local Fourier unit on a 2x2 tiling of ``x[:, :c // 4]``."""
        n, c, h, w = x.shape
        split_no = 2
        split_s_h = h // split_no
        split_s_w = w // split_no
        # Stack the spatial tiles on the channel axis: c // 4 -> c channels.
        # NOTE(review): presumably expects h and w divisible by split_no and
        # c divisible by 4; other sizes are not handled here.
        xs = torch.cat(torch.split(
            x[:, :c // 4], split_s_h, dim=-2), dim=1).contiguous()
        xs = torch.cat(torch.split(xs, split_s_w, dim=-1),
                       dim=1).contiguous()
        xs = self.lfu(xs)
        # Tile the result back to the full spatial size.
        return xs.repeat(1, 1, split_no, split_no).contiguous()

    def forward(self, x):
        """Apply the spectral transform.

        The computation itself runs in float32 with CUDA autocast disabled.
        The result is cast back to the dtype of ``x`` (previously it was
        always cast to float16, which silently downgraded float32 inputs).
        """
        in_dtype = x.dtype
        x = x.float()

        with torch.cuda.amp.autocast(enabled=False):
            x = self.downsample(x)
            x = self.conv1(x)
            output = self.fu(x)
            xs = self._local_fourier(x) if self.enable_lfu else 0
            output = self.conv2(x + output + xs)

        # Only restore floating dtypes; a non-float input keeps the float32 result.
        if in_dtype.is_floating_point:
            output = output.to(in_dtype)
        return output


class FFC(nn.Module):
    """Fast Fourier Convolution over a (local, global) channel split.

    Input and output channels are each split into a global part
    (``int(channels * ratio)``) and a local part (the remainder). Four paths
    connect them: local->local, local->global and global->local are ordinary
    convolutions, global->global is a ``SpectralTransform``. A path whose
    input or output side has zero channels is an ``nn.Identity`` placeholder.

    Args:
        in_channels, out_channels: total channel counts.
        kernel_size, padding, dilation, groups, bias: forwarded to the
            ordinary convolutions.
        ratio_gin: fraction of input channels that is global.
        ratio_gout: fraction of output channels that is global.
        stride: 1 or 2.
        enable_lfu: forwarded to ``SpectralTransform``.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 ratio_gin, ratio_gout, stride=1, padding=0,
                 dilation=1, groups=1, bias=False, enable_lfu=True):
        super(FFC, self).__init__()

        assert stride == 1 or stride == 2, "Stride should be 1 or 2."
        self.stride = stride

        in_cg = int(in_channels * ratio_gin)
        in_cl = in_channels - in_cg
        out_cg = int(out_channels * ratio_gout)
        out_cl = out_channels - out_cg

        self.ratio_gin = ratio_gin
        self.ratio_gout = ratio_gout
        # Remember the real output widths; forward() gates on these rather
        # than on the ratios, which can round down to zero channels.
        self.out_cl = out_cl
        self.out_cg = out_cg

        module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
        self.convl2l = module(in_cl, out_cl, kernel_size, stride, padding, dilation, groups, bias)
        module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
        self.convl2g = module(in_cl, out_cg, kernel_size, stride, padding, dilation, groups, bias)
        module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
        self.convg2l = module(in_cg, out_cl, kernel_size, stride, padding, dilation, groups, bias)
        module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
        self.convg2g = module(in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu)

    def forward(self, x):
        """Return ``(out_local, out_global)``.

        ``x`` is either a tensor (treated as the local input) or a
        ``(local, global)`` tuple. An output branch with zero channels is
        returned as the integer ``0``.
        """
        x_l, x_g = x if isinstance(x, tuple) else (x, 0)
        out_xl, out_xg = 0, 0

        # Gate on channel counts: with e.g. out_channels=3, ratio_gout=0.25
        # the global width is 0 and both global paths are Identity, which
        # would otherwise leak the input through as "global output".
        if self.out_cl > 0:
            out_xl = self.convl2l(x_l) + self.convg2l(x_g)
        if self.out_cg > 0:
            out_xg = self.convl2g(x_l) + self.convg2g(x_g)

        return out_xl, out_xg


class FFC_BN_ACT(nn.Module):
    """``FFC`` followed by a per-branch normalization and activation.

    Args:
        in_channels, out_channels, kernel_size, ratio_gin, ratio_gout,
        stride, padding, dilation, groups, bias, enable_lfu: forwarded to
            ``FFC``.
        norm_layer: callable taking a channel count and returning the
            normalization module.
        activation_layer: callable accepting ``inplace=True`` and returning
            the activation module.

    A branch is replaced by ``nn.Identity`` when ``ratio_gout`` is exactly 1
    (local branch) or exactly 0 (global branch).
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size, ratio_gin, ratio_gout,
                 stride=1, padding=0, dilation=1, groups=1, bias=False,
                 norm_layer=nn.BatchNorm2d, activation_layer=nn.Identity,
                 enable_lfu=True):
        super(FFC_BN_ACT, self).__init__()
        self.ffc = FFC(in_channels, out_channels, kernel_size,
                       ratio_gin, ratio_gout, stride, padding, dilation,
                       groups, bias, enable_lfu)

        # Use the same split arithmetic as FFC (global = int(out * ratio),
        # local = remainder). int(out * (1 - ratio)) can round differently,
        # e.g. 64 channels at ratio 0.3 gives 44 instead of 45, and the norm
        # layer would then reject the FFC output.
        out_cg = int(out_channels * ratio_gout)
        out_cl = out_channels - out_cg

        lnorm = nn.Identity if ratio_gout == 1 else norm_layer
        gnorm = nn.Identity if ratio_gout == 0 else norm_layer
        self.bn_l = lnorm(out_cl)
        self.bn_g = gnorm(out_cg)

        lact = nn.Identity if ratio_gout == 1 else activation_layer
        gact = nn.Identity if ratio_gout == 0 else activation_layer
        self.act_l = lact(inplace=True)
        self.act_g = gact(inplace=True)

    def forward(self, x):
        """Return the normalized and activated ``(local, global)`` pair."""
        x_l, x_g = self.ffc(x)
        x_l = self.act_l(self.bn_l(x_l))
        x_g = self.act_g(self.bn_g(x_g))
        return x_l, x_g


class MMFFC(nn.Module):
    """Two-stream (RGB + depth) FFC block with norm/activation per branch.

    Each stream has its own ``FFC``, normalization and activation modules.
    After both streams are processed, the depth features are added onto the
    RGB features.

    Args:
        in_channels, out_channels, kernel_size, ratio_gin, ratio_gout,
        stride, padding, dilation, groups, bias, enable_lfu: forwarded to
            both ``FFC`` instances.
        norm_layer: callable taking a channel count and returning the
            normalization module.
        activation_layer: callable accepting ``inplace=True`` and returning
            the activation module.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size, ratio_gin, ratio_gout,
                 stride=1, padding=0, dilation=1, groups=1, bias=False,
                 norm_layer=nn.BatchNorm2d, activation_layer=nn.Identity,
                 enable_lfu=True):
        super(MMFFC, self).__init__()

        # Use the same split arithmetic as FFC (global = int(out * ratio),
        # local = remainder). int(out * (1 - ratio)) can round differently,
        # e.g. 64 channels at ratio 0.3 gives 44 instead of 45, and the norm
        # layer would then reject the FFC output.
        out_cg = int(out_channels * ratio_gout)
        out_cl = out_channels - out_cg

        self.rgb_ffc = FFC(in_channels, out_channels, kernel_size,
                           ratio_gin, ratio_gout, stride, padding, dilation,
                           groups, bias, enable_lfu)
        lnorm = nn.Identity if ratio_gout == 1 else norm_layer
        gnorm = nn.Identity if ratio_gout == 0 else norm_layer
        self.bn_l = lnorm(out_cl)
        self.bn_g = gnorm(out_cg)

        lact = nn.Identity if ratio_gout == 1 else activation_layer
        gact = nn.Identity if ratio_gout == 0 else activation_layer
        self.act_l = lact(inplace=True)
        self.act_g = gact(inplace=True)

        self.depth_ffc = FFC(in_channels, out_channels, kernel_size,
                             ratio_gin, ratio_gout, stride, padding, dilation,
                             groups, bias, enable_lfu)
        self.bn_depth_l = lnorm(out_cl)
        self.bn_depth_g = gnorm(out_cg)
        self.act_depth_l = lact(inplace=True)
        self.act_depth_g = gact(inplace=True)

    def forward(self, x):
        """Process one ``(rgb, depth)`` pair.

        Args:
            x: ``(x_rgb, x_depth)``; each element is whatever ``FFC`` accepts
                (a tensor or a ``(local, global)`` tuple).

        Returns:
            ``((rgb_l, rgb_g), (depth_l, depth_g))`` where the RGB pair
            already has the depth pair added to it.
        """
        x_rgb, x_depth = x
        x_rgb_l, x_rgb_g = self.rgb_ffc(x_rgb)
        x_rgb_l = self.act_l(self.bn_l(x_rgb_l))
        x_rgb_g = self.act_g(self.bn_g(x_rgb_g))

        x_depth_l, x_depth_g = self.depth_ffc(x_depth)
        x_depth_l = self.act_depth_l(self.bn_depth_l(x_depth_l))
        x_depth_g = self.act_depth_g(self.bn_depth_g(x_depth_g))

        # Fuse depth into RGB; the depth stream itself is passed on unchanged.
        x_rgb_l = x_rgb_l + x_depth_l
        x_rgb_g = x_rgb_g + x_depth_g

        return (x_rgb_l, x_rgb_g), (x_depth_l, x_depth_g)


class ConcatTupleLayer(nn.Module):
    """Merge the RGB ``(local, global)`` pair of a two-stream output.

    Input is ``(x_rgb, x_depth)`` where ``x_rgb`` is a ``(local, global)``
    tuple; ``x_depth`` is not used. At least one of the two RGB branches must
    be a tensor. If only one is, that tensor is returned as-is; otherwise
    both are concatenated along the channel axis.
    """

    def forward(self, x):
        x_rgb, x_depth = x
        assert isinstance(x_rgb, tuple)
        x_l, x_g = x_rgb
        assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
        if not torch.is_tensor(x_g):
            return x_l
        # The assert above allows a global-only pair; torch.cat cannot take
        # the non-tensor placeholder, so return the global tensor directly.
        if not torch.is_tensor(x_l):
            return x_g
        return torch.cat(x_rgb, dim=1)



class SeparableConv2d(nn.Module):
    """Depthwise-separable convolution with normalization after each stage.

    Args:
        inplanes: input channels (also the depthwise group count).
        planes: output channels of the 1x1 pointwise convolution.
        kernel_size, stride, dilation: depthwise convolution settings.
        relu_first: if true, a single (non in-place) ReLU precedes the
            depthwise convolution; otherwise an in-place ReLU follows each
            normalization layer.
        bias: whether both convolutions carry a bias.
        norm_layer: callable taking a channel count and returning the
            normalization module.
    """

    def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, relu_first=True,
                 bias=False, norm_layer=nn.BatchNorm2d):
        super().__init__()
        # NOTE(review): padding equals the dilation, which keeps the spatial
        # size only for 3x3 kernels at stride 1.
        stages = [
            ('depthwise', nn.Conv2d(inplanes, inplanes, kernel_size,
                                    stride=stride, padding=dilation,
                                    dilation=dilation, groups=inplanes, bias=bias)),
            ('bn_depth', norm_layer(inplanes)),
            ('pointwise', nn.Conv2d(inplanes, planes, 1, bias=bias)),
            ('bn_point', norm_layer(planes)),
        ]

        if relu_first:
            # Pre-activation variant: ReLU, then the four stages untouched.
            stages.insert(0, ('relu', nn.ReLU()))
        else:
            # Post-activation variant: ReLU after each normalization.
            stages.insert(2, ('relu1', nn.ReLU(inplace=True)))
            stages.append(('relu2', nn.ReLU(inplace=True)))

        self.block = nn.Sequential(OrderedDict(stages))

    def forward(self, x):
        return self.block(x)


class ImageDepthFusionModule(nn.Module):
    """Fuse two feature maps with a separable conv and a channel gate.

    Three entry points share the same parameters:

    * ``forward(c, att_map)``: residual fusion of ``c`` with ``c * att_map``.
    * ``forward2(rgb, depth)``: add ``depth`` to ``rgb`` weighted by a
      one-channel sigmoid attention map.
    * ``forward3(rgb, depth)``: gated fusion of ``rgb`` and ``depth`` without
      the residual.

    Args:
        norm_layer: normalization factory passed to ``SeparableConv2d``.
        inplane: channels of each of the two inputs.
        outplane: channels of the fused output.
    """

    def __init__(self, norm_layer=nn.BatchNorm2d, inplane=256, outplane=256):
        super(ImageDepthFusionModule, self).__init__()
        self.conv1 = SeparableConv2d(inplane * 2, outplane, 3, norm_layer=norm_layer, relu_first=False)
        self.fc1 = nn.Conv2d(outplane, outplane // 16, kernel_size=1)
        self.fc2 = nn.Conv2d(outplane // 16, outplane, kernel_size=1)

        self.attn_conv = nn.Conv2d(inplane*2, 1, kernel_size=1)
        # Not referenced by any forward variant in this class; kept so the
        # module's parameters (and state_dict keys) stay unchanged.
        self.out_conv = nn.Conv2d(inplane*2, outplane, kernel_size=1)

    def _fuse_and_gate(self, first, second):
        """Concatenate two maps, mix them with conv1 and apply the channel gate."""
        x = torch.cat([first, second], 1)  # 2 * inplane channels
        x = self.conv1(x)  # outplane channels
        # Global average over the full spatial extent. The former
        # avg_pool2d(x, x.size(2)) was a global pool only for square maps.
        weight = F.adaptive_avg_pool2d(x, 1)
        weight = F.relu(self.fc1(weight))
        weight = torch.sigmoid(self.fc2(weight))
        return x * weight

    def forward(self, c, att_map):
        """Return ``c`` plus the gated fusion of ``c`` and ``c * att_map``."""
        return c + self._fuse_and_gate(c, c * att_map)

    def forward2(self, rgb, depth):
        """Return ``rgb + attn * depth`` with a learned one-channel ``attn``."""
        attn = self.attn_conv(torch.cat([rgb, depth], 1))
        attn = torch.sigmoid(attn)
        rgb = rgb + attn * depth
        return rgb

    def forward3(self, rgb, depth):
        """Return the gated fusion of ``rgb`` and ``depth`` (no residual)."""
        return self._fuse_and_gate(rgb, depth)


class MMFFCLayer(nn.Module):
    """RGB/depth layer with a local conv branch and a spectral branch per stream.

    Each stream is the sum of its input, a 3x3 conv branch and a
    ``SpectralTransform`` branch (both followed by BatchNorm and ReLU). The
    two stream sums are passed to ``ImageDepthFusionModule`` and the RGB input
    is added to the result.

    Args:
        in_channels: channels of both inputs.
        out_channels: channels produced by every branch.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        groups = 1
        # Creation order is kept (rgb local, rgb global, depth local, depth
        # global, fusion) so parameter ordering and seeded init are unchanged.
        self.rgb_local_conv = self._local_branch(in_channels, out_channels, groups)
        self.rgb_global_conv = self._global_branch(in_channels, out_channels, groups)
        self.depth_local_conv = self._local_branch(in_channels, out_channels, groups)
        self.depth_global_conv = self._global_branch(in_channels, out_channels, groups)
        self.idf = ImageDepthFusionModule(inplane=out_channels, outplane=out_channels)

    @staticmethod
    def _local_branch(in_channels, out_channels, groups):
        """3x3 convolution -> BatchNorm -> ReLU."""
        conv = Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, groups=groups, bias=False)
        return nn.Sequential(conv, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True))

    @staticmethod
    def _global_branch(in_channels, out_channels, groups):
        """SpectralTransform (LFU disabled) -> BatchNorm -> ReLU."""
        spectral = SpectralTransform(in_channels, out_channels, stride=1, groups=groups, enable_lfu=False)
        return nn.Sequential(spectral, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True))

    def forward(self, x):
        """Fuse an ``(x_rgb, x_depth)`` pair and return a single tensor."""
        rgb, depth = x

        rgb_local = self.rgb_local_conv(rgb)
        rgb_global = self.rgb_global_conv(rgb)
        depth_local = self.depth_local_conv(depth)
        depth_global = self.depth_global_conv(depth)

        rgb_feats = rgb + rgb_local + rgb_global
        depth_feats = depth + depth_local + depth_global

        fused = self.idf(rgb_feats, depth_feats)
        return fused + rgb


if __name__ == '__main__':
    # Smoke test: push a random RGB/depth pair through three MMFFC stages
    # followed by ConcatTupleLayer.
    in_channels = 3
    channels = 64
    out_channel = 64
    alpha = 0.5

    # Settings shared by every stage.
    common = dict(stride=1, padding=1, dilation=1, groups=1, bias=False, enable_lfu=False)
    # (in_channels, out_channels, ratio_gin, ratio_gout) per stage.
    stage_specs = [
        (in_channels, channels, 0, alpha),
        (channels, channels, alpha, alpha),
        (channels, out_channel, alpha, 0),
    ]
    stages = [
        MMFFC(c_in, c_out, 3, g_in, g_out, **common)
        for c_in, c_out, g_in, g_out in stage_specs
    ]
    ffclayers = nn.Sequential(*stages, ConcatTupleLayer())

    # Test
    x_rgb = torch.randn(1, 3, 224, 224)
    x_depth = torch.randn(1, 3, 224, 224)
    out = ffclayers((x_rgb, x_depth))
