import torch
import torch.nn as nn
import torch.nn.functional as F

class Classification(nn.Module):
    """Fuse global, local and frequency-domain CNN features into one embedding.

    Three branches each produce a 1024-d vector:
      * ``global_cnn`` (SmallCNN) processes the whole image,
      * ``local_cnn`` (LocalCNN) processes the four image quadrants,
      * ``freq_cnn`` (LargeCNN) processes the log-magnitude Fourier spectrum
        returned by ``get_frequency_domain_tensor``.
    The concatenated 3072-d vector is projected by
    Linear(3072, 1536) -> GELU -> Linear(1536, output_dim).

    Args:
        output_dim: Size of the returned embedding (default 768).
    """

    def __init__(self, output_dim=768):
        super().__init__()
        self.global_cnn = SmallCNN()  # whole-image branch
        self.local_cnn = LocalCNN()   # per-quadrant branch
        self.freq_cnn = LargeCNN()    # frequency-spectrum branch

        # 3 branches x 1024 features = 3072 inputs.
        self.fc1 = nn.Linear(3072, 1536)
        # Previously hard-coded to 768, silently ignoring `output_dim`;
        # the default keeps existing checkpoints loadable.
        self.fc2 = nn.Linear(1536, output_dim)
        self.gelu = nn.GELU()

    def forward(self, img):
        """Embed a batch of images.

        Args:
            img: Tensor of shape [B, C, H, W]. When C == 3 the channels are
                averaged with equal weights to produce a single grayscale
                channel; otherwise the input must already have C == 1, since
                every branch expects exactly one input channel.

        Returns:
            Tensor of shape [B, output_dim].
        """
        if img.shape[1] == 3:
            # Plain channel mean, not a luminance-weighted conversion.
            img = img.mean(dim=1, keepdim=True)

        global_feat = self.global_cnn(img)
        local_feat = self.local_cnn(img)
        freq_feat = self.freq_cnn(get_frequency_domain_tensor(img))

        fused = torch.cat([global_feat, local_feat, freq_feat], dim=1)  # [B, 3072]
        hidden = self.gelu(self.fc1(fused))  # [B, 1536]
        return self.fc2(hidden)  # [B, output_dim]

class SmallCNN(nn.Module):
    """Encode whole-image (global) context into a fixed-length vector.

    Three 3x3 stride-2 convolutions (1 -> 64 -> 128 -> 256 channels), each
    followed by ReLU, then global average pooling and a linear projection.

    Args:
        output_dim: Size of the output feature vector (default 1024).

    Input:  tensor of shape [b, 1, h, w].
    Output: tensor of shape [b, output_dim].
    """

    def __init__(self, output_dim=1024):
        super().__init__()
        widths = (1, 64, 128, 256)
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            # Each stride-2 conv halves the spatial resolution.
            stages += [nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1), nn.ReLU()]
        stages.append(nn.AdaptiveAvgPool2d((1, 1)))  # -> [b, 256, 1, 1]
        # Attribute names and layer order are kept so state_dict keys match.
        self.conv_layers = nn.Sequential(*stages)
        self.fc = nn.Linear(widths[-1], output_dim)

    def forward(self, x):
        pooled = self.conv_layers(x)
        return self.fc(torch.flatten(pooled, start_dim=1))

class LocalCNN(nn.Module):
    """Encode local detail by processing the four image quadrants separately.

    The input is split at h // 2 and w // 2 into top-left, top-right,
    bottom-left and bottom-right quadrants (with odd sizes the bottom/right
    pieces are one pixel larger). One shared stack of 3x3 stride-1 convs
    (1 -> 64 -> 128 -> 256 channels) with global average pooling turns each
    quadrant into a 256-d vector. The four vectors are concatenated in that
    order and projected to `output_dim`.

    Args:
        output_dim: Size of the output feature vector (default 1024).
    """

    def __init__(self, output_dim=1024):
        super().__init__()
        widths = (1, 64, 128, 256)
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.extend((nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1), nn.ReLU()))
        layers.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.conv = nn.Sequential(*layers)  # shared by all four quadrants
        self.fc = nn.Linear(widths[-1] * 4, output_dim)  # 4 quadrants x 256

    def forward(self, x):
        _, _, height, width = x.shape
        mid_h, mid_w = height // 2, width // 2
        row_spans = ((0, mid_h), (mid_h, height))
        col_spans = ((0, mid_w), (mid_w, width))
        # Row-major order: top-left, top-right, bottom-left, bottom-right.
        quadrant_vecs = [
            self.conv(x[:, :, r0:r1, c0:c1])
            for r0, r1 in row_spans
            for c0, c1 in col_spans
        ]
        stacked = torch.cat(quadrant_vecs, dim=1)  # [b, 1024, 1, 1]
        return self.fc(torch.flatten(stacked, start_dim=1))

class LargeCNN(nn.Module):
    """Deeper CNN encoder; `Classification` feeds it the frequency spectrum.

    Five stride-2 convolutions (a 7x7 stem, then four 3x3 layers;
    1 -> 64 -> 128 -> 256 -> 512 -> 1024 channels), each followed by ReLU,
    then global average pooling and a linear projection.

    Args:
        output_dim: Size of the output feature vector (default 1024).
    """

    def __init__(self, output_dim=1024):
        super().__init__()
        # (in_channels, out_channels, kernel_size); padding = k // 2 keeps the
        # "same"-style padding of 3 for the 7x7 stem and 1 for the 3x3 layers.
        specs = [(1, 64, 7), (64, 128, 3), (128, 256, 3), (256, 512, 3), (512, 1024, 3)]
        modules = []
        for c_in, c_out, k in specs:
            modules.append(nn.Conv2d(c_in, c_out, kernel_size=k, stride=2, padding=k // 2))
            modules.append(nn.ReLU())
        modules.append(nn.AdaptiveAvgPool2d((1, 1)))  # -> [b, 1024, 1, 1]
        self.conv_layers = nn.Sequential(*modules)
        self.fc = nn.Linear(1024, output_dim)

    def forward(self, x):
        return self.fc(torch.flatten(self.conv_layers(x), start_dim=1))

def get_frequency_domain_tensor(image):
    """Return the log-magnitude, center-shifted 2-D Fourier spectrum of `image`.

    Computes ``log(1 + |fftshift(fft2(image))|)`` over the last two dims, so
    the zero-frequency (DC) component sits at the center of each map.

    Args:
        image: Tensor of shape [b, 1, h, w].

    Returns:
        Real-valued tensor with the same shape as `image`.

    Raises:
        ValueError: if `image` is None or is not 4-D with exactly one channel.
    """
    if image is None:
        raise ValueError("输入的图像为 None")
    if len(image.shape) != 4 or image.shape[1] != 1:
        raise ValueError(f"输入的图像张量形状应为 [b, 1, h, w]，但得到了 {image.shape}")

    spectrum = torch.fft.fftshift(torch.fft.fft2(image, dim=(-2, -1)), dim=(-2, -1))
    # log1p(x) == log(1 + x), but stays accurate for tiny magnitudes where
    # 1 + x would round to exactly 1 in float32 and the value would be lost.
    return torch.log1p(torch.abs(spectrum))
