# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
import fvcore.nn.weight_init as weight_init
import torch
import torch.nn.functional as F
from torch import nn

from detectron2.layers import (
    CNNBlockBase,
    Conv2d,
    DeformConv,
    ModulatedDeformConv,
    ShapeSpec,
    get_norm,
)

from detectron2.modeling.backbone.backbone import Backbone
from detectron2.modeling.backbone.build import BACKBONE_REGISTRY
import pdb
import random

# Public API of this module: the names exported by a star-import.
__all__ = [
    "ResNetBlockBase",
    "BasicBlock",
    "BottleneckBlock",
    "DeformBottleneckBlock",
    "BasicStem",
    "ResNet",
    "make_stage",
    "build_resnet_backbone",
]

class InstanceWhitening(nn.Module):
    """
    Per-sample, per-channel standardization of a feature map using a
    non-affine :class:`torch.nn.InstanceNorm2d`.
    """

    def __init__(self, dim):
        """
        Args:
            dim (int): number of channels handed to ``nn.InstanceNorm2d``.
        """
        super().__init__()
        self.instance_standardization = nn.InstanceNorm2d(dim, affine=False)

    def forward(self, x):
        """
        Args:
            x (Tensor): input passed to the instance-norm layer.

        Returns:
            tuple: ``(standardized, standardized)`` -- the very same tensor
            object is returned twice.
        """
        standardized = self.instance_standardization(x)
        return standardized, standardized


class BasicBlock(CNNBlockBase):
    """
    Two-layer residual block from :paper:`ResNet` (the ResNet-18 / ResNet-34
    building block): 3x3 conv -> ReLU -> 3x3 conv, summed with a skip path
    that is a 1x1 projection whenever the channel count changes.
    """

    def __init__(self, in_channels, out_channels, *, stride=1, norm="BN"):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            stride (int): Stride for the first conv (and for the projection
                shortcut, when one is created).
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format.
        """
        super().__init__(in_channels, out_channels, stride)

        # Skip path: identity (None) unless the channel count changes.
        needs_projection = in_channels != out_channels
        self.shortcut = (
            Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=get_norm(norm, out_channels),
            )
            if needs_projection
            else None
        )

        # Both residual convs are 3x3, bias-free and padded by 1; only the
        # first one carries the stride.
        self.conv1 = Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )
        self.conv2 = Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )

        # MSRA fill, applied in the order conv1, conv2, shortcut.
        for conv in (self.conv1, self.conv2, self.shortcut):
            if conv is None:  # identity shortcut has nothing to initialize
                continue
            weight_init.c2_msra_fill(conv)

    def forward(self, x):
        """
        Args:
            x (Tensor): block input.

        Returns:
            Tensor: ``relu(conv2(relu(conv1(x))) + skip(x))``.
        """
        residual = self.conv2(F.relu_(self.conv1(x)))
        skip = x if self.shortcut is None else self.shortcut(x)
        residual += skip
        return F.relu_(residual)

class BottleneckBlock(CNNBlockBase):
    """
    Bottleneck residual block from :paper:`ResNet` (the ResNet-50/101/152
    building block): 1x1 -> 3x3 -> 1x1 convs plus a projection shortcut when
    the channel count changes.

    Unlike the stock block, ``forward`` consumes and produces a 4-item
    sequence ``[x, mag_ls, stat_ls, p]``; only ``x`` is transformed.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        *,
        bottleneck_channels,
        stride=1,
        num_groups=1,
        norm="BN",
        stride_in_1x1=False,
        dilation=1,
    ):
        """
        Args:
            bottleneck_channels (int): number of output channels for the 3x3
                "bottleneck" conv layers.
            num_groups (int): number of groups for the 3x3 conv layer.
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format.
            stride_in_1x1 (bool): when stride>1, whether to put stride in the
                first 1x1 convolution or the bottleneck 3x3 convolution.
            dilation (int): the dilation rate of the 3x3 conv layer.
        """
        super().__init__(in_channels, out_channels, stride)

        # Skip path: identity (None) unless the channel count changes.
        needs_projection = in_channels != out_channels
        self.shortcut = (
            Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=get_norm(norm, out_channels),
            )
            if needs_projection
            else None
        )

        # MSRA-style models stride in the first 1x1 conv; fb.torch.resnet and
        # Caffe2 ResNe[X]t stride in the 3x3 conv instead.
        if stride_in_1x1:
            stride_1x1, stride_3x3 = stride, 1
        else:
            stride_1x1, stride_3x3 = 1, stride

        self.conv1 = Conv2d(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=stride_1x1,
            bias=False,
            norm=get_norm(norm, bottleneck_channels),
        )
        self.conv2 = Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            kernel_size=3,
            stride=stride_3x3,
            padding=1 * dilation,
            bias=False,
            groups=num_groups,
            dilation=dilation,
            norm=get_norm(norm, bottleneck_channels),
        )
        self.conv3 = Conv2d(
            bottleneck_channels,
            out_channels,
            kernel_size=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )

        # MSRA fill, applied in the order conv1, conv2, conv3, shortcut.
        for conv in (self.conv1, self.conv2, self.conv3, self.shortcut):
            if conv is None:  # identity shortcut has nothing to initialize
                continue
            weight_init.c2_msra_fill(conv)

        # Zero-initializing the last norm of the residual branch (Sec 5.1 of
        # "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour") is
        # intentionally left disabled:
        # nn.init.constant_(self.conv3.norm.weight, 0)
        # TODO this somehow hurts performance when training GN models from scratch.
        # Add it as an option when we need to use this code to train a backbone.

    def forward(self, x_tuple):
        """
        Args:
            x_tuple (sequence): indexable whose items 0..3 are
                ``(x, mag_ls, stat_ls, p)``. Only ``x`` is used here.

        Returns:
            list: ``[out, mag_ls, stat_ls, p]`` with ``out`` the block output
            and the other three items passed through untouched.
        """
        x, mag_ls, stat_ls, p = x_tuple[0], x_tuple[1], x_tuple[2], x_tuple[3]

        out = F.relu_(self.conv1(x))
        out = F.relu_(self.conv2(out))
        out = self.conv3(out)

        skip = x if self.shortcut is None else self.shortcut(x)
        out += skip

        return [F.relu_(out), mag_ls, stat_ls, p]

class BottleneckBlock_style(CNNBlockBase):
    """
    Bottleneck residual block (1x1 -> 3x3 -> 1x1 convs plus optional
    projection shortcut) whose output is additionally perturbed during
    training.

    ``forward`` consumes and produces ``[x, mag_ls, stat_ls, p]``. When the
    module is in training mode and ``p > 0.5`` the block output is replaced
    by the average of two normalization perturbations; when additionally
    ``p > 0.7`` Gaussian noise and block occlusion are applied on top. The
    perturbation helpers are referenced by name and defined outside this
    class.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        *,
        bottleneck_channels,
        stride=1,
        num_groups=1,
        norm="BN",
        stride_in_1x1=False,
        dilation=1,
    ):
        """
        Args:
            bottleneck_channels (int): number of output channels for the 3x3
                "bottleneck" conv layers.
            num_groups (int): number of groups for the 3x3 conv layer.
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format.
            stride_in_1x1 (bool): when stride>1, whether to put stride in the
                first 1x1 convolution or the bottleneck 3x3 convolution.
            dilation (int): the dilation rate of the 3x3 conv layer.
        """
        super().__init__(in_channels, out_channels, stride)

        # Skip path: identity (None) unless the channel count changes.
        needs_projection = in_channels != out_channels
        self.shortcut = (
            Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=get_norm(norm, out_channels),
            )
            if needs_projection
            else None
        )

        # MSRA-style models stride in the first 1x1 conv; fb.torch.resnet and
        # Caffe2 ResNe[X]t stride in the 3x3 conv instead.
        if stride_in_1x1:
            stride_1x1, stride_3x3 = stride, 1
        else:
            stride_1x1, stride_3x3 = 1, stride

        self.conv1 = Conv2d(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=stride_1x1,
            bias=False,
            norm=get_norm(norm, bottleneck_channels),
        )
        self.conv2 = Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            kernel_size=3,
            stride=stride_3x3,
            padding=1 * dilation,
            bias=False,
            groups=num_groups,
            dilation=dilation,
            norm=get_norm(norm, bottleneck_channels),
        )
        self.conv3 = Conv2d(
            bottleneck_channels,
            out_channels,
            kernel_size=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )

        # MSRA fill, applied in the order conv1, conv2, conv3, shortcut.
        for conv in (self.conv1, self.conv2, self.conv3, self.shortcut):
            if conv is None:  # identity shortcut has nothing to initialize
                continue
            weight_init.c2_msra_fill(conv)

        # Disabled experiment: self.ese = eSEModule(out_channels)
        #
        # Zero-initializing the last norm of the residual branch (Sec 5.1 of
        # "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour") is
        # intentionally left disabled:
        # nn.init.constant_(self.conv3.norm.weight, 0)
        # TODO this somehow hurts performance when training GN models from scratch.
        # Add it as an option when we need to use this code to train a backbone.

    def forward(self, x_tuple):
        """
        Args:
            x_tuple (sequence): indexable whose items 0..3 are
                ``(x, mag_ls, stat_ls, p)``. ``p`` is compared against the
                thresholds 0.5 and 0.7 to decide which perturbations run.

        Returns:
            list: ``[out, mag_ls, stat_ls, p]``; ``mag_ls``, ``stat_ls`` and
            ``p`` are passed through untouched.
        """
        x, mag_ls, stat_ls, p = x_tuple[0], x_tuple[1], x_tuple[2], x_tuple[3]

        out = F.relu_(self.conv1(x))
        out = F.relu_(self.conv2(out))
        out = self.conv3(out)

        skip = x if self.shortcut is None else self.shortcut(x)
        out += skip
        out = F.relu_(out)

        # No perturbation outside training or when p does not exceed 0.5.
        if not (p > 0.5 and self.training):
            return [out, mag_ls, stat_ls, p]

        # Equal-weight blend of the two normalization perturbations.
        # Alternatives that were tried and left disabled:
        #   out = instance_norm_aug(out)
        #   out = Normalization_Perturbation_Plus(out)
        perturbed_a = Normalization_Perturbation(out)
        perturbed_b = Normalization_Perturbation1(out)
        out = 0.5 * perturbed_a + 0.5 * perturbed_b

        # Second, nested condition on the same p: stronger augmentation.
        if p > 0.7:
            noise_std = random.uniform(0.05, 0.15)
            out = Gaussian_Noise_Perturbation(out, std_dev=noise_std)  # add Gaussian noise
            out = Block_Occlusion_Perturbation(out)

        return [out, mag_ls, stat_ls, p]


class DeformBottleneckBlock(CNNBlockBase):
    """
    Variant of :class:`BottleneckBlock` whose 3x3 convolution is a
    :paper:`deformable conv <deformconv>`. Offsets (and, in the modulated
    case, masks) are predicted by an extra regular 3x3 conv.

    NOTE(review): ``forward`` here takes a plain tensor, whereas
    :class:`BottleneckBlock` takes a 4-item sequence -- confirm the two are
    never chained in the same stage.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        *,
        bottleneck_channels,
        stride=1,
        num_groups=1,
        norm="BN",
        stride_in_1x1=False,
        dilation=1,
        deform_modulated=False,
        deform_num_groups=1,
    ):
        """
        Args:
            bottleneck_channels (int): number of channels of the 3x3 conv.
            num_groups (int): number of groups for the 3x3 conv layer.
            norm (str or callable): normalization for the conv layers.
                See :func:`layers.get_norm` for supported format.
            stride_in_1x1 (bool): when stride>1, whether to put stride in the
                first 1x1 convolution or the 3x3 convolution.
            dilation (int): the dilation rate of the 3x3 conv layer.
            deform_modulated (bool): use ``ModulatedDeformConv`` (27 predicted
                channels per deformable group) instead of ``DeformConv``
                (18 predicted channels per deformable group).
            deform_num_groups (int): number of deformable groups.
        """
        super().__init__(in_channels, out_channels, stride)
        self.deform_modulated = deform_modulated

        # Skip path: identity (None) unless the channel count changes.
        needs_projection = in_channels != out_channels
        self.shortcut = (
            Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=get_norm(norm, out_channels),
            )
            if needs_projection
            else None
        )

        if stride_in_1x1:
            stride_1x1, stride_3x3 = stride, 1
        else:
            stride_1x1, stride_3x3 = 1, stride

        self.conv1 = Conv2d(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=stride_1x1,
            bias=False,
            norm=get_norm(norm, bottleneck_channels),
        )

        # Predicted channels are (2, or 3 when modulated) * kernel_size * kernel_size.
        deform_conv_op, offset_channels = (
            (ModulatedDeformConv, 27) if deform_modulated else (DeformConv, 18)
        )

        self.conv2_offset = Conv2d(
            bottleneck_channels,
            offset_channels * deform_num_groups,
            kernel_size=3,
            stride=stride_3x3,
            padding=1 * dilation,
            dilation=dilation,
        )
        self.conv2 = deform_conv_op(
            bottleneck_channels,
            bottleneck_channels,
            kernel_size=3,
            stride=stride_3x3,
            padding=1 * dilation,
            bias=False,
            groups=num_groups,
            dilation=dilation,
            deformable_groups=deform_num_groups,
            norm=get_norm(norm, bottleneck_channels),
        )

        self.conv3 = Conv2d(
            bottleneck_channels,
            out_channels,
            kernel_size=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )

        # MSRA fill, applied in the order conv1, conv2, conv3, shortcut.
        for conv in (self.conv1, self.conv2, self.conv3, self.shortcut):
            if conv is None:  # identity shortcut has nothing to initialize
                continue
            weight_init.c2_msra_fill(conv)

        # The offset predictor starts at exactly zero (weights and bias).
        for param in (self.conv2_offset.weight, self.conv2_offset.bias):
            nn.init.constant_(param, 0)

    def forward(self, x):
        """
        Args:
            x (Tensor): block input.

        Returns:
            Tensor: block output after the final in-place ReLU.
        """
        out = F.relu_(self.conv1(x))

        predicted = self.conv2_offset(out)
        if self.deform_modulated:
            # Three equal channel chunks: x-offsets, y-offsets, mask logits.
            offset_x, offset_y, mask = torch.chunk(predicted, 3, dim=1)
            offset = torch.cat((offset_x, offset_y), dim=1)
            out = self.conv2(out, offset, mask.sigmoid())
        else:
            out = self.conv2(out, predicted)
        out = F.relu_(out)

        out = self.conv3(out)

        skip = x if self.shortcut is None else self.shortcut(x)
        out += skip
        return F.relu_(out)


class BasicStem(CNNBlockBase):
    """
    ResNet stem (the layers before the first residual block): a 7x7 stride-2
    conv, ReLU and a 3x3 stride-2 max-pool.

    ``forward`` consumes and produces ``[x, mag_ls, stat_ls, p]``. In training
    mode with ``p > 0.5`` the activations are replaced, before pooling, by the
    average of two normalization perturbations (helpers defined outside this
    class).
    """

    def __init__(self, in_channels=3, out_channels=64, norm="BN"):
        """
        Args:
            in_channels (int): number of input channels of the conv.
            out_channels (int): number of output channels of the conv.
            norm (str or callable): norm after the first conv layer.
                See :func:`layers.get_norm` for supported format.
        """
        # The stem's overall stride is declared as 4 (conv stride 2 x pool stride 2).
        super().__init__(in_channels, out_channels, 4)
        self.in_channels = in_channels
        self.conv1 = Conv2d(
            in_channels,
            out_channels,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
            norm=get_norm(norm, out_channels),
        )
        weight_init.c2_msra_fill(self.conv1)
        # Disabled experiment: self.ese = eSEModule(out_channels)

    def forward(self, x_tuple):
        """
        Args:
            x_tuple (sequence): indexable whose items 0..3 are
                ``(x, mag_ls, stat_ls, p)``.

        Returns:
            list: ``[x, mag_ls, stat_ls, p]`` with ``x`` the pooled stem
            output; the other three items are passed through untouched.
        """
        x, mag_ls, stat_ls, p = x_tuple[0], x_tuple[1], x_tuple[2], x_tuple[3]
        x = self.conv1(x)

        # Disabled experiment, kept as an inert string literal (it has no effect).
        '''
        mean, std = calc_ins_mean_std(x, eps=1e-12)
        x, mag = self.ese(x)
        mag_ls.append(mag)
        stat_ls.append([mean.squeeze(-1).squeeze(-1), std.squeeze(-1).squeeze(-1)])
        '''
        x = F.relu_(x)

        if p > 0.5 and self.training:
            # Equal-weight blend of the two normalization perturbations
            # (disabled alternative: x = instance_norm_aug(x)).
            perturbed_a = Normalization_Perturbation(x)
            perturbed_b = Normalization_Perturbation1(x)
            x = 0.5 * perturbed_a + 0.5 * perturbed_b

        pooled = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        return [pooled, mag_ls, stat_ls, p]

class Cbam_module(nn.Module):
    """
    Spatial attention map: the channel-wise mean and channel-wise max of the
    input are stacked into a 2-channel map, reduced to one channel by a single
    bias-free conv, and squashed with a sigmoid.
    """

    def __init__(self, kernel_size=7):
        """
        Args:
            kernel_size (int): kernel size of the conv; must be 3 or 7.
        """
        super().__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # Padding 3 for the 7x7 kernel, 1 for the 3x3 kernel.
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()
        weight_init.c2_msra_fill(self.conv1)

    def forward(self, x):
        """
        Args:
            x (Tensor): input reduced along dim 1.

        Returns:
            Tensor: sigmoid attention map with a single channel in dim 1.
        """
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        pooled = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv1(pooled))
    
# class SpectrumAttention(nn.Module):
#     def __init__(self, channel):
#         super(SpectrumAttention, self).__init__()
#         self.conv1 = nn.Conv2d(channel, channel, 1)
#         self.bn1 = nn.BatchNorm2d(channel)
#         self.relu = nn.ReLU(inplace=True)
#         self.attention_module1 = Cbam_module()
#         # weight_init.c2_msra_fill(self.conv1)
#         # self.conv2 = nn.Conv2d(channel, channel, 1)
#         # self.bn2 = nn.BatchNorm2d(channel)
#         # self.attention_module2 = Cbam_module()

#     def forward(self, x):
#         x = torch.fft.rfft2(x, dim=(2, 3), norm='ortho')
#         magnitude = torch.abs(x)
#         angle = torch.angle(x)  
#         magnitude = self.conv1(magnitude)
#         magnitude = self.relu(self.bn1(magnitude))
#         identity = magnitude
#         if self.training == True:
#             attention_weight = self.attention_module1(magnitude)
#             out = attention_weight * identity
#         else:
#             # pdb.set_trace()
#             out = identity
#         # Save the mask image
#         # cv2.imwrite('mask.png',(attention_weight.squeeze().cpu().numpy() * 255).astype(np.uint8))
#         # pdb.set_trace()
#         out_map = out * torch.exp(1j * angle)
#         out_img = torch.fft.irfft2(out_map, dim=(2, 3), norm='ortho')
#         return out_img

class SpectrumAttention(nn.Module):
    """
    Frequency-band attention with a gated residual connection.

    The input is transformed with ``torch.fft.rfft2``; its amplitude spectrum
    is passed through a 1x1 conv + BN + ReLU and split into ``num_bands``
    radial bands. An MLP predicts one sigmoid weight per (channel, band) that
    rescales the processed amplitude inside that band. The re-weighted
    amplitude is recombined with the original phase, transformed back with
    ``torch.fft.irfft2`` and added to the input after being scaled by a
    learned per-channel gate in (0, 1).
    """

    def __init__(self, channel, num_bands=3, reduction=4, use_spatial_context=True):
        """
        Args:
            channel (int): number of channels of the input feature map.
            num_bands (int): number of radial frequency bands.
            reduction (int): the hidden width of the band-attention MLP is
                its input width divided (floor) by this factor.
            use_spatial_context (bool): whether to append the globally
                average-pooled spatial feature to the band-attention MLP input.
        """
        super().__init__()
        self.channel = channel
        self.num_bands = num_bands
        self.reduction = reduction
        self.use_spatial_context = use_spatial_context

        # 1. Pre-processing of the amplitude spectrum.
        self.conv1 = nn.Conv2d(channel, channel, 1)
        self.bn1 = nn.BatchNorm2d(channel)
        self.relu = nn.ReLU(inplace=True)

        # 2. MLP producing the band attention weights. Its input holds a mean
        #    and a max per (channel, band), optionally followed by the pooled
        #    spatial feature; its output holds one logit per (channel, band).
        mlp_input_dim = channel * num_bands * 2
        if self.use_spatial_context:
            mlp_input_dim += channel
        self.attention_mlp = nn.Sequential(
            nn.Linear(mlp_input_dim, mlp_input_dim // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(mlp_input_dim // reduction, channel * num_bands),
        )

        # 3. Gate controlling how much of the frequency branch is added back.
        #    Hidden width is max(16, channel // 4), or `channel` itself for
        #    fewer than 4 channels so that it never becomes zero.
        gate_hidden_dim = max(16, channel // 4) if channel >= 4 else channel
        self.frequency_path_gate_mlp = nn.Sequential(
            nn.Linear(channel, gate_hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(gate_hidden_dim, channel),
            nn.Sigmoid(),  # gate values in (0, 1)
        )

    def _generate_frequency_band_masks(self, H_f, W_f, device):
        """
        Build ``num_bands`` disjoint radial masks over an ``H_f x W_f`` spectrum.

        Row indices above ``H_f // 2`` are folded to ``H_f - y`` before the
        squared distance to the origin is computed. Radii are spaced linearly
        between 0 and ``sqrt((H_f / 2)^2 + (W_f - 1)^2)``. Band 0 is closed at
        its lower bound (so it always holds the DC bin), every other band is
        open at its lower bound, and the last band has no upper bound; the
        masks therefore partition the spectrum.

        Args:
            H_f (int): spectrum height.
            W_f (int): spectrum width.
            device (torch.device): device on which the masks are created.

        Returns:
            Tensor: float 0/1 masks of shape ``[num_bands, 1, H_f, W_f]``.
        """
        y_coords = torch.arange(H_f, device=device).float()
        x_coords = torch.arange(W_f, device=device).float()
        y_folded = torch.where(y_coords > H_f // 2, H_f - y_coords, y_coords)
        dist_sq = y_folded.unsqueeze(1) ** 2 + x_coords.unsqueeze(0) ** 2

        if H_f == 0 or W_f == 0:  # degenerate spectrum
            max_dist_val = 0.0
        else:
            # Evaluated on the host: going through `device` would force a
            # device-to-host synchronization on every forward pass.
            max_dist_val = torch.tensor((H_f / 2.0) ** 2 + (W_f - 1.0) ** 2).sqrt().item()

        radii = torch.linspace(0, max_dist_val, self.num_bands + 1, device=device)

        masks = torch.zeros(self.num_bands, H_f, W_f, device=device)
        last_band = self.num_bands - 1
        for i in range(self.num_bands):
            r_low_sq = radii[i] ** 2
            # Inclusive lower bound only for the first band keeps bands disjoint.
            in_band = dist_sq >= r_low_sq if i == 0 else dist_sq > r_low_sq
            if i != last_band:
                in_band = in_band & (dist_sq <= radii[i + 1] ** 2)
            masks[i] = in_band.float()

        # No "empty band" repair is needed for the DC bin: radii[0] is 0 and
        # band 0 is closed at its lower bound, so dist_sq[0, 0] == 0 always
        # falls into band 0 whenever the spectrum is non-empty.
        return masks.unsqueeze(1)

    def forward(self, x_spatial):
        """
        Args:
            x_spatial (Tensor): feature map of shape ``[B, C, H, W]``.

        Returns:
            Tensor: ``x_spatial + gate * frequency_branch(x_spatial)``, same
            shape as the input.
        """
        B, C, H_orig, W_orig = x_spatial.shape
        H_f, W_f = H_orig, W_orig // 2 + 1  # rfft2 keeps only half of the last dim

        # Global spatial descriptor, shared by the band MLP and the gate.
        spatial_summary = F.adaptive_avg_pool2d(x_spatial, 1).view(B, C)  # [B, C]

        # Step 1: FFT -> amplitude and phase.
        x_fft = torch.fft.rfft2(x_spatial, dim=(2, 3), norm='ortho')
        magnitude = torch.abs(x_fft)
        angle = torch.angle(x_fft)

        # Step 2: amplitude pre-processing. [B, C, H_f, W_f]
        processed_magnitude = self.relu(self.bn1(self.conv1(magnitude)))

        # Step 3: radial band masks. [Nb, 1, H_f, W_f]
        band_masks = self._generate_frequency_band_masks(H_f, W_f, device=x_spatial.device)

        # Step 4: per-band mean and max of the processed amplitude.
        masked_all = processed_magnitude.unsqueeze(1) * band_masks  # [B, Nb, C, H_f, W_f]
        sum_mag_in_bands = torch.sum(masked_all, dim=(3, 4))  # [B, Nb, C]
        # The epsilon guards against division by zero for an empty band.
        num_elements_in_bands = torch.sum(band_masks, dim=(2, 3)) + 1e-12  # [Nb, 1]
        avg_mag_in_bands = sum_mag_in_bands / num_elements_in_bands.unsqueeze(0)  # [B, Nb, C]
        max_mag_in_bands = torch.amax(masked_all, dim=(3, 4))  # [B, Nb, C]

        band_features_rich = torch.cat([avg_mag_in_bands, max_mag_in_bands], dim=2)  # [B, Nb, 2C]
        concatenated_band_features = (
            band_features_rich.transpose(1, 2).contiguous().view(B, C * self.num_bands * 2)
        )

        # Step 5 (optional): append the global spatial context.
        if self.use_spatial_context:
            final_summary_for_mlp = torch.cat([concatenated_band_features, spatial_summary], dim=1)
        else:
            final_summary_for_mlp = concatenated_band_features

        # Step 6: one sigmoid weight per (channel, band). [B, C, Nb]
        mlp_out = self.attention_mlp(final_summary_for_mlp)
        band_channel_att_scores = torch.sigmoid(mlp_out.view(B, C, self.num_bands))

        # Step 7: re-weight each band and merge the bands back together.
        scores_expanded = band_channel_att_scores.permute(0, 2, 1).unsqueeze(-1).unsqueeze(-1)  # [B, Nb, C, 1, 1]
        attended_magnitude = torch.sum(scores_expanded * masked_all, dim=1)  # [B, C, H_f, W_f]

        # Step 8: recombine with the original phase and invert the FFT.
        out_fft_map = attended_magnitude * torch.exp(1j * angle)
        out_img_freq_processed = torch.fft.irfft2(
            out_fft_map, s=(H_orig, W_orig), dim=(2, 3), norm='ortho'
        )

        # Step 9: per-channel gate on the frequency branch, then residual add.
        gate_values = self.frequency_path_gate_mlp(spatial_summary)  # [B, C]
        gate_values = gate_values.unsqueeze(-1).unsqueeze(-1)  # [B, C, 1, 1] for broadcasting

        return x_spatial + gate_values * out_img_freq_processed
    
class ResNet(Backbone):
    """
    Implement :paper:`ResNet`.
    """

    def __init__(self, stem, stages, num_classes=None, out_features=None, freeze_at=0):
        """
        Args:
            stem (nn.Module): a stem module
            stages (list[list[CNNBlockBase]]): several (typically 4) stages,
                each contains multiple :class:`CNNBlockBase`.
            num_classes (None or int): if None, will not perform classification.
                Otherwise, will create a linear layer.
            out_features (list[str]): name of the layers whose outputs should
                be returned in forward. Can be anything in "stem", "linear", or "res2" ...
                If None, will return the output of the last layer.
            freeze_at (int): The number of stages at the beginning to freeze.
                see :meth:`freeze` for detailed explanation.
        """
        super().__init__()
        self.stem = stem
        self.num_classes = num_classes

        current_stride = self.stem.stride
        self._out_feature_strides = {"stem": current_stride}
        self._out_feature_channels = {"stem": self.stem.out_channels}

        # ``forward`` applies ``self.spec`` to the output of the first stage, so
        # size it from that stage's output width instead of assuming 256 (which
        # only holds for the bottleneck variants; R18/R34 emit 64 channels).
        # It is still registered before the stages so the module order (and
        # therefore parameter order) is the same as before for 256-wide models.
        if len(stages) > 0 and len(stages[0]) > 0:
            spec_channels = stages[0][-1].out_channels
        else:
            # No first stage to inspect: keep the historical default.
            spec_channels = 256
        self.spec = SpectrumAttention(channel=spec_channels)

        self.stage_names, self.stages = [], []

        if out_features is not None:
            # Avoid keeping unused layers in this module. They consume extra memory
            # and may cause allreduce to fail
            num_stages = max(
                [{"res2": 1, "res3": 2, "res4": 3, "res5": 4}.get(f, 0) for f in out_features]
            )
            stages = stages[:num_stages]
        for i, blocks in enumerate(stages):
            assert len(blocks) > 0, len(blocks)
            for block in blocks:
                assert isinstance(block, CNNBlockBase), block

            name = "res" + str(i + 2)
            stage = nn.Sequential(*blocks)

            # ``self.stages`` is a plain list, so each stage is registered
            # explicitly to make its parameters visible to the module.
            self.add_module(name, stage)
            self.stage_names.append(name)
            self.stages.append(stage)

            # The stride of a stage is the product of its blocks' strides,
            # accumulated on top of everything that came before it.
            self._out_feature_strides[name] = current_stride = int(
                current_stride * np.prod([k.stride for k in blocks])
            )
            self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels
        self.stage_names = tuple(self.stage_names)  # Make it static for scripting

        if num_classes is not None:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.linear = nn.Linear(curr_channels, num_classes)

            # Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
            # "The 1000-way fully-connected layer is initialized by
            # drawing weights from a zero-mean Gaussian with standard deviation of 0.01."
            nn.init.normal_(self.linear.weight, std=0.01)
            name = "linear"

        if out_features is None:
            # Default to the last layer that was built (last stage, or "linear").
            out_features = [name]
        self._out_features = out_features
        assert len(self._out_features)
        children = [x[0] for x in self.named_children()]
        for out_feature in self._out_features:
            assert out_feature in children, "Available children: {}".format(", ".join(children))
        self.freeze(freeze_at)
        
    def amplitude_drop(self, feature, layer_index, p, percent, layer_dropout_flag, m):
        """
        Track a running mean feature map per layer and optionally perturb the
        channels of ``feature`` that deviate most from it.

        Despite the name, this operates directly on ``feature`` (no FFT is
        taken here).

        Args:
            feature (Tensor): feature map, indexed as (N, C, H, W).
            layer_index (int): slot of ``self.amplitude_bank`` used for this layer.
            p (float): mode switch. ``p <= 0.5`` updates the bank and returns
                ``feature`` untouched; ``p > 0.5`` may perturb ``feature``.
            percent (float): fraction of channels kept as candidates, chosen as
                the top-k by spatial std of ``feature - bank``.
            layer_dropout_flag: perturbation only happens when this is truthy.
            m (float): fraction of channels (relative to C) randomly drawn from
                the candidates and actually perturbed.
        Returns:
            Tensor: ``feature`` itself, or a perturbed tensor of the same shape.
        """
        # The bank is created lazily so the method works even though
        # ``__init__`` does not define it.
        bank = getattr(self, "amplitude_bank", None)
        if bank is None:
            bank = self.amplitude_bank = [None] * len(self.stages)
        stored = bank[layer_index]

        if p <= 0.5:
            # Statistics only: detach so the bank never keeps an autograd graph
            # from a previous iteration alive.
            feature_mean = torch.mean(feature.detach(), dim=0)
            if stored is None:
                bank[layer_index] = feature_mean
            else:
                if stored.shape != feature_mean.shape:
                    # Resize the new mean to the spatial size already stored.
                    feature_mean = F.interpolate(
                        feature_mean.unsqueeze(0),
                        size=(stored.shape[1], stored.shape[2]),
                        mode='bilinear',
                        align_corners=True,
                    ).squeeze(0)
                # Exponential moving average, applied whether or not a resize
                # was needed (previously skipped when the shapes matched).
                bank[layer_index] = stored * 0.8 + feature_mean * 0.2
            return feature

        if stored is None or not layer_dropout_flag:
            return feature

        # Compare against the bank at the resolution of the current feature.
        if stored.shape[1] != feature.shape[2] or stored.shape[2] != feature.shape[3]:
            reference = F.interpolate(
                stored.unsqueeze(0),
                size=(feature.shape[2], feature.shape[3]),
                mode='bilinear',
                align_corners=True,
            ).squeeze(0)
        else:
            reference = stored
        diff_std = torch.std(feature - reference, dim=(2, 3))  # (N, C)

        num_channels = diff_std.shape[1]
        _, top_indices = torch.topk(diff_std, k=int(percent * num_channels), dim=1)
        # Randomly keep ``m * C`` of the candidate columns (same columns for
        # every sample; the channel ids behind them differ per sample).
        num_selected = int(m * num_channels)
        random_column_indices = torch.randperm(top_indices.size(1))[:num_selected]
        top_indices = top_indices[:, random_column_indices]

        # 1 for the channels to perturb, 0 elsewhere.
        mask_filter = torch.zeros((feature.shape[0], feature.shape[1]), device=feature.device)
        mask_filter.scatter_(1, top_indices, 1.0)
        channel_mask = mask_filter.unsqueeze(-1).unsqueeze(-1)

        domain_variant = feature * channel_mask
        domain_invariant = feature - domain_variant
        # Only the selected channels get their statistics perturbed.
        return domain_invariant + Normalization_Perturbation(domain_variant)

    def forward(self, x):
        """
        Run the stem and all stages, collecting the requested feature maps.

        Args:
            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
        Returns:
            tuple: ``(outputs, mag_ls, stat_ls)`` where ``outputs`` is a
            dict[str->Tensor] of the features named in ``self._out_features`` and
            ``mag_ls`` / ``stat_ls`` are items 1 and 2 of the state list
            returned by the last stage (or the initial empty lists when there
            are no stages).
        """
        assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!"

        # The stem and the stages exchange a 4-item list
        # [tensor, mag_ls, stat_ls, p]; ``p`` is one random draw per call.
        p = random.random()
        mag_ls, stat_ls = [], []
        state = self.stem([x, mag_ls, stat_ls, p])

        outputs = {}
        if "stem" in self._out_features:
            outputs["stem"] = state[0]

        for stage_idx, (name, stage) in enumerate(zip(self.stage_names, self.stages)):
            state = stage(state)
            mag_ls, stat_ls = state[1], state[2]

            # Spectrum attention is applied once, right after the first stage.
            if stage_idx == 0:
                state[0] = self.spec(state[0])

            if name in self._out_features:
                outputs[name] = state[0]

        if self.num_classes is not None:
            pooled = torch.flatten(self.avgpool(state[0]), 1)
            logits = self.linear(pooled)
            if "linear" in self._out_features:
                outputs["linear"] = logits
        return outputs, mag_ls, stat_ls

    def output_shape(self):
        """
        Returns:
            dict[str->ShapeSpec]: channels and stride of every feature listed in
            ``self._out_features``.
        """
        shapes = {}
        for name in self._out_features:
            shapes[name] = ShapeSpec(
                channels=self._out_feature_channels[name],
                stride=self._out_feature_strides[name],
            )
        return shapes

    def freeze(self, freeze_at=0):
        """
        Freeze the leading part of the network, as is common when fine-tuning.

        Following :paper:`FPN`, a "stage" is the group of layers that produce
        feature maps of the same spatial size.

        Args:
            freeze_at (int): number of stages to freeze.
                `1` means freezing the stem. `2` means freezing the stem and
                one residual stage, etc.
        Returns:
            nn.Module: this ResNet itself
        """
        if freeze_at >= 1:
            self.stem.freeze()

        # Residual stages are numbered from 2 ("res2"), the stem counting as 1.
        stages_to_freeze = [
            stage
            for stage_number, stage in enumerate(self.stages, start=2)
            if freeze_at >= stage_number
        ]
        for stage in stages_to_freeze:
            for block in stage.children():
                block.freeze()
        return self

    @staticmethod
    def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs):
        """
        Build the list of same-typed blocks that make up one ResNet stage.

        Args:
            block_class (type): a subclass of CNNBlockBase used for every block in the
                stage. Such a module must not change the spatial resolution of its
                input unless its stride != 1.
            num_blocks (int): number of blocks in this stage
            in_channels (int): input channels of the entire stage.
            out_channels (int): output channels of **every block** in the stage.
            kwargs: extra constructor arguments for `block_class`. An argument
                named "xx_per_block" must be a list with one value per block and
                is passed to block ``i`` as ``xx=value[i]``; any other argument
                is passed unchanged to every block.
        Returns:
            list[CNNBlockBase]: a list of block module.
        Examples:
        ::
            stage = ResNet.make_stage(
                BottleneckBlock, 3, in_channels=16, out_channels=64,
                bottleneck_channels=16, num_groups=1,
                stride_per_block=[2, 1, 1],
                dilations_per_block=[1, 1, 2]
            )
        Usually, layers that produce the same feature map spatial size are defined as one
        "stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should
        all be 1.
        """
        suffix = "_per_block"

        def kwargs_for_block(block_idx):
            # Resolve "xx_per_block" lists to the value for this block.
            resolved = {}
            for key, value in kwargs.items():
                if not key.endswith(suffix):
                    resolved[key] = value
                    continue
                assert len(value) == num_blocks, (
                    f"Argument '{key}' of make_stage should have the "
                    f"same length as num_blocks={num_blocks}."
                )
                base_key = key[: -len(suffix)]
                assert base_key not in kwargs, (
                    f"Cannot call make_stage with both {key} and {base_key}!"
                )
                resolved[base_key] = value[block_idx]
            return resolved

        blocks = []
        block_in_channels = in_channels
        for block_idx in range(num_blocks):
            block = block_class(
                in_channels=block_in_channels,
                out_channels=out_channels,
                **kwargs_for_block(block_idx),
            )
            blocks.append(block)
            # Every block after the first consumes the stage's output width.
            block_in_channels = out_channels
        return blocks

    @staticmethod
    def make_default_stages(depth, block_class=None, **kwargs):
        """
        Create the stages of a standard ResNet of the given depth
        (one of 18, 34, 50, 101, 152).
        For any other variant, use :meth:`make_stage` directly.

        Args:
            depth (int): depth of ResNet
            block_class (type): the CNN block class. Has to accept a
                `bottleneck_channels` argument for depth >= 50.
                Defaults to BasicBlock for depth < 50 and BottleneckBlock otherwise.
            kwargs:
                other arguments to pass to `make_stage`. Should not contain
                stride and channels, as they are predefined for each depth.
        Returns:
            list[list[CNNBlockBase]]: modules in all stages; see arguments of
                :class:`ResNet.__init__`.
        """
        num_blocks_per_stage = {
            18: [2, 2, 2, 2],
            34: [3, 4, 6, 3],
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3],
        }[depth]
        uses_bottleneck = depth >= 50
        if block_class is None:
            block_class = BottleneckBlock if uses_bottleneck else BasicBlock

        # (in_channels, out_channels) of res2..res5.
        if uses_bottleneck:
            channel_plan = [(64, 256), (256, 512), (512, 1024), (1024, 2048)]
        else:
            channel_plan = [(64, 64), (64, 128), (128, 256), (256, 512)]
        first_strides = [1, 2, 2, 2]

        stages = []
        for num_blocks, first_stride, (stage_in, stage_out) in zip(
            num_blocks_per_stage, first_strides, channel_plan
        ):
            if uses_bottleneck:
                # The bottleneck width is a quarter of the stage output width.
                kwargs["bottleneck_channels"] = stage_out // 4
            stage = ResNet.make_stage(
                block_class=block_class,
                num_blocks=num_blocks,
                # Only the first block of a stage may downsample.
                stride_per_block=[first_stride] + [1] * (num_blocks - 1),
                in_channels=stage_in,
                out_channels=stage_out,
                **kwargs,
            )
            stages.append(stage)
        return stages


# Keep the older ``ResNetBlockBase`` name importable (it is listed in ``__all__``);
# it is the same object as ``CNNBlockBase``.
ResNetBlockBase = CNNBlockBase
"""
Alias for backward compatibility.
"""


def make_stage(*args, **kwargs):
    """
    Deprecated alias kept for backward compatibility.

    Forwards every argument to :meth:`ResNet.make_stage` and returns its result.
    """
    stage_blocks = ResNet.make_stage(*args, **kwargs)
    return stage_blocks


@BACKBONE_REGISTRY.register()
def new_build_resnet_backbone(cfg, input_shape):
    """
    Build a :class:`ResNet` from ``cfg.MODEL.RESNETS`` / ``cfg.MODEL.BACKBONE``.

    For the bottleneck depths, the first stage (res2) is built from
    ``BottleneckBlock_style`` and the remaining stages from ``BottleneckBlock``,
    unless ``DEFORM_ON_PER_STAGE`` selects ``DeformBottleneckBlock`` for a stage.

    Returns:
        ResNet: a :class:`ResNet` instance.
    """
    # need registration of new blocks/stems?
    resnet_cfg = cfg.MODEL.RESNETS
    norm = resnet_cfg.NORM
    stem = BasicStem(
        in_channels=input_shape.channels,
        out_channels=resnet_cfg.STEM_OUT_CHANNELS,
        norm=norm,
    )

    freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
    out_features = resnet_cfg.OUT_FEATURES
    depth = resnet_cfg.DEPTH
    num_groups = resnet_cfg.NUM_GROUPS
    width_per_group = resnet_cfg.WIDTH_PER_GROUP
    bottleneck_channels = num_groups * width_per_group
    in_channels = resnet_cfg.STEM_OUT_CHANNELS
    out_channels = resnet_cfg.RES2_OUT_CHANNELS
    stride_in_1x1 = resnet_cfg.STRIDE_IN_1X1
    res5_dilation = resnet_cfg.RES5_DILATION
    deform_on_per_stage = resnet_cfg.DEFORM_ON_PER_STAGE
    deform_modulated = resnet_cfg.DEFORM_MODULATED
    deform_num_groups = resnet_cfg.DEFORM_NUM_GROUPS
    assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)

    num_blocks_per_stage = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
    }[depth]

    uses_basic_block = depth in [18, 34]
    if uses_basic_block:
        assert out_channels == 64, "Must set MODEL.RESNETS.RES2_OUT_CHANNELS = 64 for R18/R34"
        assert not any(
            deform_on_per_stage
        ), "MODEL.RESNETS.DEFORM_ON_PER_STAGE unsupported for R18/R34"
        assert res5_dilation == 1, "Must set MODEL.RESNETS.RES5_DILATION = 1 for R18/R34"
        assert num_groups == 1, "Must set MODEL.RESNETS.NUM_GROUPS = 1 for R18/R34"

    stages = []
    for idx, num_blocks in enumerate(num_blocks_per_stage):
        stage_number = idx + 2  # stages are named res2 .. res5
        # res5_dilation is used this way as a convention in R-FCN & Deformable Conv paper
        dilation = res5_dilation if stage_number == 5 else 1
        # res2 never downsamples; a dilated res5 keeps its resolution too.
        if idx == 0 or (stage_number == 5 and dilation == 2):
            first_stride = 1
        else:
            first_stride = 2

        stage_kargs = dict(
            num_blocks=num_blocks,
            stride_per_block=[first_stride] + [1] * (num_blocks - 1),
            in_channels=in_channels,
            out_channels=out_channels,
            norm=norm,
        )
        if uses_basic_block:
            # Use BasicBlock for R18 and R34.
            stage_kargs["block_class"] = BasicBlock
        else:
            stage_kargs.update(
                bottleneck_channels=bottleneck_channels,
                stride_in_1x1=stride_in_1x1,
                dilation=dilation,
                num_groups=num_groups,
            )
            if deform_on_per_stage[idx]:
                stage_kargs.update(
                    block_class=DeformBottleneckBlock,
                    deform_modulated=deform_modulated,
                    deform_num_groups=deform_num_groups,
                )
            else:
                stage_kargs["block_class"] = (
                    BottleneckBlock_style if idx == 0 else BottleneckBlock
                )

        stages.append(ResNet.make_stage(**stage_kargs))

        # Widths double from one stage to the next.
        in_channels, out_channels = out_channels, out_channels * 2
        bottleneck_channels *= 2
    return ResNet(stem, stages, out_features=out_features, freeze_at=freeze_at)

def calc_ins_mean_std(x, eps=1e-5):
    """
    Per-sample, per-channel mean and standard deviation of a 4-D feature map.

    Args:
        x (Tensor): 4-D input, reduced over everything after the first two dims.
        eps (float): added to the variance before the square root so the
            returned std is never exactly zero.
    Returns:
        tuple[Tensor, Tensor]: ``(mean, std)``, each reshaped to (N, C, 1, 1).
    """
    shape = x.size()
    assert (len(shape) == 4)
    batch, channels = shape[:2]

    # Flatten the trailing dims once and reduce over them.
    flat = x.contiguous().view(batch, channels, -1)
    mean = flat.mean(dim=2).view(batch, channels, 1, 1)
    std = (flat.var(dim=2) + eps).sqrt().view(batch, channels, 1, 1)
    return mean, std


def instance_norm_mix(content_feat, style_feat):
    """
    Re-normalize ``content_feat`` so that each (sample, channel) carries the
    mean and std of ``style_feat``.

    Both inputs must agree on their first two dimensions.
    """
    assert (content_feat.size()[:2] == style_feat.size()[:2])
    full_shape = content_feat.size()
    style_mean, style_std = calc_ins_mean_std(style_feat)
    content_mean, content_std = calc_ins_mean_std(content_feat)

    # Remove the content statistics ...
    centered = content_feat - content_mean.expand(full_shape)
    whitened = centered / content_std.expand(full_shape)
    # ... and put the style statistics in their place.
    restyled = whitened * style_std.expand(full_shape)
    return restyled + style_mean.expand(full_shape)

def instance_norm_aug(feat):
    """
    Randomly rescale the per-channel statistics of ``feat``.

    With ``mean`` the per-sample, per-channel mean from
    :func:`calc_ins_mean_std` (computed on a detached copy, so no gradient
    flows through it), the result is::

        scale_1 * (feat - mean) + scale_2 * mean

    where ``scale_1`` is uniform in [0, 2) and ``scale_2`` is ``scale_1`` minus
    another uniform draw in [0, 2). Both are sampled independently for every
    (sample, channel) pair.

    Args:
        feat (Tensor): 4-D feature map.
    Returns:
        Tensor: perturbed features, same shape as ``feat``.
    """
    # Only the mean is needed; the std and the cross-batch scale that used to
    # be computed here were never used.
    mean, _ = calc_ins_mean_std(feat.detach())

    scale_1 = torch.rand(mean.shape, device=mean.device) * 2
    scale_2 = scale_1 - torch.rand(mean.shape, device=mean.device) * 2

    return scale_1 * feat - scale_1 * mean + scale_2 * mean

def Normalization_Perturbation_Plus(feat):
    """
    Perturb channel statistics, scaling the mean perturbation by how much each
    channel's mean varies across the batch.

    With ``mu`` the spatial mean of ``feat`` (kept as (B, C, 1, 1)), returns::

        alpha * (feat - mu) + beta * mu

    where ``alpha ~ N(1, 0.75)`` and ``beta = 1 + N(0, 0.75) * mean_scale``.
    ``mean_scale`` is the std of ``mu`` over the batch, divided by its largest
    value and multiplied by 1.5.

    A batch of one sample, or a batch in which every channel mean is identical,
    has no cross-batch variation. ``mean_scale`` is then 0 (``beta == 1``)
    instead of NaN.

    Args:
        feat (Tensor): features of size (B, C, H, W).
    Returns:
        Tensor: perturbed features, same shape as ``feat``.
    """
    feat_mean = feat.mean((2, 3), keepdim=True)
    ones_mat = torch.ones_like(feat_mean)
    zeros_mat = torch.zeros_like(feat_mean)

    if feat_mean.shape[0] > 1:
        mean_diff = torch.std(feat_mean, 0, keepdim=True)
    else:
        # The unbiased std of a single sample is NaN; treat it as "no variation".
        mean_diff = feat_mean.new_zeros((1,) + tuple(feat_mean.shape[1:]))

    # Guard the 0 / 0 case without forcing a host sync.
    peak = mean_diff.max().clamp_min(torch.finfo(mean_diff.dtype).tiny)
    mean_scale = mean_diff / peak * 1.5

    alpha = torch.normal(ones_mat, 0.75 * ones_mat)
    beta = 1 + torch.normal(zeros_mat, 0.75 * ones_mat) * mean_scale
    return alpha * feat - alpha * feat_mean + beta * feat_mean

class Hsigmoid(nn.Module):
    """Hard sigmoid: ``relu6(x + 3) / 6``, i.e. a piecewise-linear map onto [0, 1]."""

    def __init__(self, inplace=True):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        # ``inplace`` only affects the temporary ``x + 3``; ``x`` is not modified.
        shifted = x + 3.
        clipped = F.relu6(shifted, inplace=self.inplace)
        return clipped / 6.

def Normalization_Perturbation(feat):
    """
    Randomly perturb the per-channel statistics of ``feat``.

    With ``mu`` the spatial mean of each (sample, channel), returns
    ``alpha * (feat - mu) + beta * mu`` where ``alpha`` and ``beta`` are drawn
    independently from N(1, 0.75) for every (sample, channel).

    Args:
        feat (Tensor): input features of size (B, C, H, W).
    Returns:
        Tensor: perturbed features of size (B, C, H, W).
    """
    channel_mean = feat.mean((2, 3), keepdim=True)  # (B, C, 1, 1)
    unit = torch.ones_like(channel_mean)
    spread = 0.75 * unit

    alpha = torch.normal(unit, spread)  # scales the centered features
    beta = torch.normal(unit, spread)  # scales the mean

    centered = alpha * feat - alpha * channel_mean
    return centered + beta * channel_mean

def Normalization_Perturbation1(feat):
    """
    Randomly perturb ``feat`` around its per-sample mean.

    The mean is taken over dims (1, 2, 3), so for a 4-D input there is one
    mean (and one ``alpha`` / ``beta`` pair) per sample, not per channel.
    ``alpha`` and ``beta`` are uniform in [0.1, 2.1). Returns
    ``alpha * (feat - mean) + beta * mean``.

    Args:
        feat (Tensor): input features of size (B, C, H, W).
    Returns:
        Tensor: perturbed features of size (B, C, H, W).
    """
    sample_mean = feat.mean((1, 2, 3), keepdim=True)  # (B, 1, 1, 1) for 4-D input

    # Drawn with the default (CPU) generator, then moved next to the features.
    alpha = torch.rand(sample_mean.shape).to(sample_mean.device) * 2 + 0.1
    beta = torch.rand(sample_mean.shape).to(sample_mean.device) * 2 + 0.1

    centered = alpha * feat - alpha * sample_mean
    return centered + beta * sample_mean

def Gaussian_Noise_Perturbation(feat, std_dev=0.1):
    """
    Add zero-mean Gaussian noise to ``feat``.

    Args:
        feat (Tensor): input features.
        std_dev (float): standard deviation of the noise.
    Returns:
        Tensor: ``feat`` plus element-wise noise; ``feat`` itself is not modified.
    """
    scaled_noise = std_dev * torch.randn_like(feat)
    return feat + scaled_noise

def Block_Occlusion_Perturbation(feat,
                                 num_blocks_range=(1, 3),
                                 block_size_min_ratio=0.05,
                                 block_size_max_ratio=0.15,
                                 occlusion_value=0.0):
    """
    Overwrite a few random rectangles of a feature map with a constant.

    The same rectangles are applied to every sample and every channel.

    Args:
        feat (Tensor): features of shape (B, C, H, W); not modified.
        num_blocks_range (tuple[int, int]): inclusive range for the number of
            rectangles to draw.
        block_size_min_ratio (float): smallest rectangle side, as a fraction of H / W.
        block_size_max_ratio (float): largest rectangle side, as a fraction of H / W.
        occlusion_value (float): value written into the rectangles.
    Returns:
        Tensor: an occluded copy of ``feat``.
    """
    _, _, H, W = feat.shape
    occluded = feat.clone()  # work on a copy, leave the input intact

    low, high = num_blocks_range
    num_blocks = torch.randint(low, high + 1, (1,)).item()
    ratio_span = block_size_max_ratio - block_size_min_ratio

    def random_side(extent):
        # Side length in pixels, clamped to [1, extent].
        raw = int(torch.rand(1).item() * ratio_span * extent + block_size_min_ratio * extent)
        return max(1, min(raw, extent))

    for _ in range(num_blocks):
        block_h = random_side(H)
        block_w = random_side(W)

        # Largest top-left corner that keeps the rectangle inside the map.
        last_top = H - block_h
        last_left = W - block_w
        if last_top < 0 or last_left < 0:
            # Only possible for an empty spatial dimension.
            continue

        top = torch.randint(0, last_top + 1, (1,)).item()
        left = torch.randint(0, last_left + 1, (1,)).item()
        occluded[:, :, top:top + block_h, left:left + block_w] = occlusion_value

    return occluded

class eSEModule(nn.Module):
    """
    Channel gate driven by per-channel mean and std.

    The (mean, std) pair of every channel is concatenated, mapped back to
    ``channel`` values by a 1x1 convolution, squashed with a hard sigmoid and
    used to rescale the input.
    """

    def __init__(self, channel, reduction=4):
        # ``reduction`` is accepted but not used by this module.
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # kept as a submodule; forward does not call it
        self.fc = nn.Conv2d(channel * 2, channel, kernel_size=1, padding=0)
        self.hsigmoid = Hsigmoid()

    def forward(self, x):
        """
        Returns:
            tuple[Tensor, Tensor]: the gated input, and the pre-activation gate
            values with their two trailing singleton dims squeezed away.
        """
        mean, std = calc_ins_mean_std(x)
        stats = torch.cat([mean, std], 1)
        gate_logits = self.fc(stats)
        gate = self.hsigmoid(gate_logits)
        return x * gate, gate_logits.squeeze(-1).squeeze(-1)
