import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# Silence every warning for the rest of the process (no category or module filter is given).
import warnings
warnings.filterwarnings('ignore')


# 平均池化获得重要性
def compute_importance_matrix(tensor):
    """Score each token by its mean dot-product similarity to every token.

    Args:
        tensor: Token features of shape (batch, seq_len, feature_dim).

    Returns:
        Tensor of shape (batch, seq_len, 1). Entry ``[b, i, 0]`` is the mean
        over ``j`` of ``dot(tensor[b, i], tensor[b, j])``.
    """
    # (b, l, dim) @ (b, dim, l) -> (b, l, l): pairwise token similarities.
    similarity = torch.matmul(tensor, tensor.transpose(-1, -2))
    # A plain row mean states the intent directly. The previous
    # F.avg_pool2d(..., (1, N)) call only worked because a 3-D input is
    # silently read as an unbatched (C, H, W) image. The trailing singleton
    # dimension it produced is kept so existing callers see the same shape.
    return similarity.mean(dim=-1, keepdim=True)

class Transformer_QK(nn.Module):
    """Token encoder consisting of a single learned linear projection."""

    def __init__(self, d_model):
        super().__init__()
        # One d_model -> d_model affine map applied along the last dimension.
        self.linear = nn.Linear(d_model, d_model)

    def forward(self, src):
        """Return ``src`` projected by the linear layer."""
        return self.linear(src)

# Transformer Encoder 模块
class TransformerEncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer followed by an output projection.

    The layer applies self-attention and a 4x-wide ReLU feed-forward block,
    each wrapped in a residual connection and LayerNorm, then a final linear
    map.

    Args:
        d_model: Feature width of the tokens.
        nhead: Number of attention heads.
        dropout: Drop probability used on both residual branches. The default
            of 0.1 is the value that used to be hard-coded.
        batch_first: When True the input is (batch, seq, d_model). The default
            False keeps the previous (seq, batch, d_model) layout.
    """

    def __init__(self, d_model, nhead, dropout=0.1, batch_first=False):
        super().__init__()
        # Sub-modules are created in the original order so that seeded
        # parameter initialisation is unchanged.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=batch_first)
        self.linear1 = nn.Linear(d_model, d_model * 4)
        self.linear2 = nn.Linear(d_model * 4, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.linear3 = nn.Linear(d_model, d_model)

    def forward(self, src):
        """Encode ``src`` and return a tensor of the same shape."""
        # Self-attention branch with residual connection and normalisation.
        attn_out, _ = self.self_attn(src, src, src)
        src = self.norm1(src + self.dropout(attn_out))
        # Position-wise feed-forward branch.
        ff_out = self.linear2(F.relu(self.linear1(src)))
        src = self.norm2(src + self.dropout(ff_out))
        return self.linear3(src)


# 与所有patch特征进行cross-attention，计算sigma
class CrossAttention(nn.Module):
    """Single-head cross-attention that reduces each query token to one scalar.

    Queries attend over ``key`` (used for both keys and values); the attended
    feature is mapped to a single value per query, used by callers as a raw
    sigma.
    """

    def __init__(self, embed_dim):
        super().__init__()
        self.query_proj = nn.Linear(embed_dim, embed_dim)
        self.key_proj = nn.Linear(embed_dim, embed_dim)
        self.value_proj = nn.Linear(embed_dim, embed_dim)
        # Collapses the attended feature vector to one scalar per query.
        self.fc = nn.Linear(embed_dim, 1)

    def forward(self, query, key):
        """Return a (batch, num_queries) tensor of unconstrained sigma values.

        Args:
            query: Tensor of shape (batch, num_queries, channels).
            key: Tensor of shape (batch, num_keys, channels).
        """
        # The unpacking requires a 3-D query; only the channel width is used.
        _, _, channels = query.shape
        projected_q = self.query_proj(query)
        projected_k = self.key_proj(key)
        projected_v = self.value_proj(key)
        # Scaled dot-product scores, shape (batch, num_queries, num_keys).
        scores = torch.matmul(projected_q, projected_k.transpose(1, 2)) / (channels ** 0.5)
        context = torch.matmul(F.softmax(scores, dim=-1), projected_v)
        return self.fc(context).squeeze(-1)
    

class RadiationMaskModule(nn.Module):
    """Mask multi-scale feature maps with a learned Gaussian "radiation" field.

    For every feature level the module:
      1. scores each spatial location by token similarity,
      2. samples radiation centres among the highest-scoring locations,
      3. predicts a Gaussian width per centre with cross-attention,
      4. sums the Gaussians into an intensity map, and
      5. zeroes a random subset of the highest-intensity locations (hard mask)
         and scales mid-intensity locations by a factor in (0, 1] (soft mask).

    Four encoder/attention pairs are created, so at most four levels can be
    processed.

    Args:
        feature_dim: Channel count of every feature map.
        mask_percentage: Base fraction of locations that may be hard-masked;
            grows by 0.1 per level.
        hard_mask_percentage: Base fraction of above-threshold locations that
            are hard-masked; grows by 0.1 per level.
        total_epochs: Epoch count used to decay the threshold shift.
        num_heads: Accepted for interface compatibility; the current linear
            token encoders do not use it.
    """

    # Fraction of locations regarded as "important" candidates.
    IMPORTANT_PERCENTAGE = 0.3
    # Fraction of the important locations sampled as radiation centres.
    RADIATION_PERCENTAGE = 0.7

    def __init__(self, feature_dim, mask_percentage, hard_mask_percentage=0.6, total_epochs=30, num_heads=8):
        super().__init__()
        self.encoders = nn.ModuleList([Transformer_QK(feature_dim) for _ in range(4)])
        self.cross_attn = nn.ModuleList([CrossAttention(feature_dim) for _ in range(4)])
        self.mask_percentage = mask_percentage
        self.hard_mask_percentage = hard_mask_percentage
        self.total_epochs = total_epochs

    def forward(self, src_list, current_epoch):
        """Return one masked feature map per entry of ``src_list``.

        Args:
            src_list: Sequence of tensors shaped (B, C, H, W), one per level.
            current_epoch: Current training epoch, used to decay thresholds.

        Returns:
            List of tensors with the same shapes as the inputs.
        """
        return [
            self._mask_level(level, feat, current_epoch)
            for level, feat in enumerate(src_list)
        ]

    def _mask_level(self, level, feat, current_epoch):
        """Apply hard and soft radiation masking to one (B, C, H, W) map."""
        device = feat.device
        B, C, H, W = feat.shape
        N = H * W
        mask_total_num = int(N * (self.mask_percentage + level * 0.1))
        hard_mask_percentage = self.hard_mask_percentage + level * 0.1
        # Threshold shift; decays linearly from 0.5 to 0 over training.
        threshold_shift = 0.5 * (1 - current_epoch / self.total_epochs)

        # Token view of the map, (B, N, C). reshape (not view) also accepts
        # non-contiguous inputs such as channels_last tensors.
        tokens = feat.reshape(B, C, N).transpose(1, 2)
        importance_flat = compute_importance_matrix(self.encoders[level](tokens)).reshape(B, N)

        # Radiation centres: a random subset of the most important locations.
        _, order = torch.sort(importance_flat, dim=1, descending=True)
        topk_num = int(N * self.IMPORTANT_PERCENTAGE)
        radiation_num = int(topk_num * self.RADIATION_PERCENTAGE)
        perms = torch.rand(B, topk_num, device=device).argsort(dim=1)
        radiation_indices = torch.gather(order[:, :topk_num], 1, perms[:, :radiation_num])  # (B, R)

        # Amplitude of each centre is its importance relative to the row max.
        importance_max = importance_flat.max(dim=1, keepdim=True)[0]
        importance_norm = importance_flat / (importance_max + 1e-6)
        batch_idx = torch.arange(B, device=device).unsqueeze(1)
        amplitude = importance_norm[batch_idx, radiation_indices]  # (B, R)

        # Coordinates are normalised to [0, 1]. A one-pixel axis uses a
        # denominator of 1 to avoid the 0/0 = NaN the plain (H - 1) gives.
        y_scale = max(H - 1, 1)
        x_scale = max(W - 1, 1)
        y = (radiation_indices // W).float() / y_scale
        x = (radiation_indices % W).float() / x_scale
        coords = torch.stack([x, y], dim=2)  # (B, R, 2)

        # Features at the centres query all tokens to predict each sigma.
        radiation_feat = tokens[batch_idx, radiation_indices, :]  # (B, R, C)
        sigma = self.cross_attn[level](radiation_feat, tokens)  # (B, R)
        sigma = F.softplus(sigma) + 1e-6  # keep sigma strictly positive

        # Coordinates of every location, broadcast over the batch.
        idxs = torch.arange(N, device=device)
        all_coords = torch.stack(
            [(idxs % W).float() / x_scale, (idxs // W).float() / y_scale], dim=-1
        )  # (N, 2)

        # Sum of Gaussians centred on the radiation points.
        coef = amplitude.unsqueeze(1) / (2 * torch.pi * sigma.unsqueeze(1) ** 2 + 1e-8)
        coord_diff = all_coords.view(1, N, 1, 2) - coords.unsqueeze(1)  # (B, N, R, 2)
        dist_squared = (coord_diff ** 2).sum(dim=-1)  # (B, N, R)
        gaussian = coef * torch.exp(-dist_squared / (2 * sigma.unsqueeze(1) ** 2 + 1e-8))
        intensity = gaussian.sum(dim=-1)  # (B, N)

        # Min-max normalise each sample, then derive the two thresholds.
        intensity_min = intensity.min(dim=1, keepdim=True)[0]
        intensity_max = intensity.max(dim=1, keepdim=True)[0]
        intensity = (intensity - intensity_min) / (intensity_max - intensity_min + 1e-6)
        intensity_mean = intensity.mean(dim=1, keepdim=True)
        if N > 1:
            intensity_std = intensity.std(dim=1, keepdim=True)
        else:
            # The sample std of a single value is NaN; treat it as no spread.
            intensity_std = torch.zeros_like(intensity_mean)
        upper_threshold = intensity_mean + (0.5 + threshold_shift) * intensity_std
        lower_threshold = intensity_mean - (0.5 - threshold_shift) * intensity_std

        # Hard mask: zero a random subset of the above-threshold locations,
        # capped by both the hard-mask fraction and the per-level budget.
        mask = torch.zeros(B, N, device=device)
        for b in range(B):
            candidates = (intensity[b] > upper_threshold[b]).nonzero(as_tuple=True)[0]
            m = candidates.size(0)
            # One permutation sliced to the smaller cap is distributed like
            # the former two-stage sampling and cannot index out of range
            # when the hard-mask fraction exceeds 1.
            keep = min(int(m * hard_mask_percentage), mask_total_num)
            chosen = candidates[torch.randperm(m, device=device)[:keep]]
            mask[b, chosen] = 1.0

        # Soft mask: locations between the thresholds are scaled by their
        # inverse intensity, normalised so the largest factor is 1.
        soft_mask = (intensity >= lower_threshold) & (intensity <= upper_threshold)
        soft_intensity = (1 / (intensity + 1e-6)) * soft_mask
        # If no location falls in the band the row max is 0; clamping avoids
        # a 0/0 that would turn the whole sample into NaN.
        soft_intensity_max = soft_intensity.max(dim=1, keepdim=True)[0].clamp_min(1e-6)
        soft_mask_values = (1 - soft_mask.float()) + soft_intensity / soft_intensity_max

        # Hard-masked locations become 0; soft-masked ones are attenuated.
        mask = mask.view(B, 1, H, W)
        soft_mask_values = soft_mask_values.reshape(B, 1, H, W)
        return feat * (1 - mask) * soft_mask_values

    
# Feature extractor backed by ResNet-50.
# A Swin Transformer backbone needs more elaborate configuration, so ResNet stands in for it here for demonstration purposes.
class FeatureExtractor(nn.Module):
    """Convolutional trunk of a pretrained ResNet-50 used as a feature extractor."""

    def __init__(self):
        super().__init__()
        backbone = models.resnet50(pretrained=True)
        # Keep every child module except the last two (the final pooling
        # and fully connected layers).
        trunk_layers = list(backbone.children())[:-2]
        self.features = nn.Sequential(*trunk_layers)

    def forward(self, x):
        """Return the feature map for ``x``, shaped (B, C, H, W)."""
        return self.features(x)

# 读取并预处理图片
def load_image(image_path):
    """Load an image from disk and prepare it as a model input.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A tuple ``(img_tensor, img)``: the image resized to 224x224 and
        converted to a tensor with a leading batch dimension, and the
        RGB PIL image at its original size.
    """
    # Image.open is lazy and keeps the file open. convert() returns an
    # independent in-memory image, so the handle is closed straight away
    # instead of lingering until garbage collection.
    with Image.open(image_path) as source:
        img = source.convert('RGB')
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    img_tensor = preprocess(img).unsqueeze(0)  # add the batch dimension
    return img_tensor, img

# 计算重要性矩阵，这里使用特征图的L2范数作为重要性
def compute_importance(feature_map):
    """Return the L2 norm of ``feature_map`` taken over dimension 1.

    For a (B, C, H, W) feature map this gives a (B, H, W) importance map.
    """
    channel_axis = 1
    return feature_map.norm(p=2, dim=channel_axis)

# 可视化函数
def visualize_results(img, radiation_intensity, masked_feature_map):
    """Plot the input image, radiation intensity and masked features side by side.

    Args:
        img: Image accepted by ``plt.imshow``.
        radiation_intensity: Tensor that is 2-D after ``squeeze()``.
        masked_feature_map: Tensor that is 3-D after ``squeeze()``; it is
            averaged over its first remaining dimension before display
            (presumably channels of a batch-of-one map; confirm with caller).
    """

    def draw_panel(position, data, title, cmap=None, colorbar=False):
        # Draw one of the three panels in the single-row layout.
        plt.subplot(1, 3, position)
        plt.imshow(data, cmap=cmap)
        plt.title(title)
        if colorbar:
            plt.colorbar()
        plt.axis('off')

    plt.figure(figsize=(15, 5))

    # Original picture.
    draw_panel(1, img, 'Original Image')

    # Radiation intensity heat map.
    draw_panel(
        2,
        radiation_intensity.squeeze().cpu().numpy(),
        'Radiation Intensity',
        cmap='jet',
        colorbar=True,
    )

    # Masked features collapsed to a 2-D image by averaging dimension 0.
    draw_panel(
        3,
        masked_feature_map.squeeze().mean(dim=0).cpu().numpy(),
        'Masked Feature Map',
        cmap='gray',
    )

    plt.show()

# 主函数
def main():
    """Smoke-test RadiationMaskModule on randomly generated feature maps."""
    # Shape of the synthetic input.
    num_levels = 1
    batch_size = 2
    channels = 3
    height = 4
    width = 5

    # One random (B, C, H, W) feature map per level.
    feature_maps = [
        torch.rand(batch_size, channels, height, width) for _ in range(num_levels)
    ]

    net = RadiationMaskModule(feature_dim=channels, mask_percentage=0.1, total_epochs=30, num_heads=1)

    # Pretend training is half-way through the 30 configured epochs.
    current_epoch = 15

    # Run Gaussian radiation-field modelling and masking. The result is not
    # inspected; the call only verifies that the pipeline executes.
    net(feature_maps, current_epoch)
        
        
# 定义单层的高斯辐射场建模函数
def dynamic_gaussian_radiation_modeling_single_layer(feat, importance, mask_percentage, current_epoch, total_epochs=30):
    """Mask one feature map with a Gaussian radiation field.

    Radiation centres are sampled among the most important locations. A
    randomly initialised (untrained) cross-attention predicts a Gaussian width
    per centre. The summed Gaussians form an intensity map that drives a hard
    mask (locations set to 0) and a soft mask (locations scaled by a factor
    in (0, 1]).

    Args:
        feat: Tensor of shape (B, C, H, W).
        importance: Tensor of shape (B, H, W) with per-location importance.
        mask_percentage: Maximum fraction of locations that are hard-masked.
        current_epoch: Current training epoch, used to decay the thresholds.
        total_epochs: Epoch count over which the threshold shift decays.

    Returns:
        Tuple ``(masked_feat, radiation_intensity)`` shaped (B, C, H, W) and
        (B, H, W). The intensity is min-max normalised per sample.
    """
    device = feat.device
    B, C, H, W = feat.shape
    N = H * W
    mask_total_num = int(N * mask_percentage)  # hard-mask budget per sample
    important_percentage = 0.3  # share of locations regarded as important
    k = 0.5 * (1 - current_epoch / total_epochs)  # threshold shift, decays with epochs

    # Rank locations by importance.
    importance_flat = importance.reshape(B, -1).to(device)
    _, indices = torch.sort(importance_flat, dim=1, descending=True)

    # Radiation centres: a random 70% of the most important locations.
    topk_num = int(N * important_percentage)
    important_indices = indices[:, :topk_num]
    radiation_num = int(topk_num * 0.7)
    # Created on the feature device: gather needs index and source tensors
    # on the same device.
    perms = torch.rand(B, topk_num, device=device).argsort(dim=1)
    radiation_indices = torch.gather(important_indices, 1, perms[:, :radiation_num])  # (B, R)

    # Amplitude of each centre is its importance relative to the row max.
    importance_max = importance_flat.max(dim=1, keepdim=True)[0]
    importance_norm = importance_flat / (importance_max + 1e-6)
    batch_idx = torch.arange(B, device=device).unsqueeze(1)
    amplitude = importance_norm[batch_idx, radiation_indices]  # (B, R)

    # Coordinates are normalised to [0, 1]. A one-pixel axis uses a
    # denominator of 1 to avoid the 0/0 = NaN the plain (H - 1) gives.
    y_scale = max(H - 1, 1)
    x_scale = max(W - 1, 1)
    y = (radiation_indices // W).float() / y_scale
    x = (radiation_indices % W).float() / x_scale
    coords = torch.stack([x, y], dim=2)  # (B, R, 2)

    # Features at the centres act as queries over all locations.
    feat_flat = feat.reshape(B, C, N).transpose(1, 2)  # (B, N, C)
    radiation_feat = feat_flat[batch_idx, radiation_indices, :]  # (B, R, C)

    # Untrained single-head cross-attention predicting sigma. The layers are
    # created in the order query, key, value, output and moved to the device
    # and dtype of the features so non-CPU inputs work.
    query_proj = nn.Linear(C, C)
    key_proj = nn.Linear(C, C)
    value_proj = nn.Linear(C, C)
    fc = nn.Linear(C, 1)
    for layer in (query_proj, key_proj, value_proj, fc):
        layer.to(device=device, dtype=feat.dtype)
    with torch.no_grad():
        Q = query_proj(radiation_feat)  # (B, R, C)
        K = key_proj(feat_flat)         # (B, N, C)
        V = value_proj(feat_flat)       # (B, N, C)
        attn_weights = F.softmax(torch.matmul(Q, K.transpose(1, 2)) / (C ** 0.5), dim=-1)
        sigma = fc(torch.matmul(attn_weights, V)).squeeze(-1)  # (B, R)
        sigma = F.softplus(sigma) + 1e-6  # keep sigma strictly positive

    # Coordinates of every location, broadcast over the batch.
    idxs = torch.arange(N, device=device)
    all_coords = torch.stack(
        [(idxs % W).float() / x_scale, (idxs // W).float() / y_scale], dim=-1
    )  # (N, 2)

    # Sum of Gaussians centred on the radiation points.
    coef = amplitude.unsqueeze(1) / (2 * torch.pi * sigma.unsqueeze(1) ** 2 + 1e-8)
    coord_diff = all_coords.view(1, N, 1, 2) - coords.unsqueeze(1)  # (B, N, R, 2)
    dist_squared = (coord_diff ** 2).sum(dim=-1)  # (B, N, R)
    gaussian = coef * torch.exp(-dist_squared / (2 * sigma.unsqueeze(1) ** 2 + 1e-8))
    intensity = gaussian.sum(dim=-1)  # (B, N)

    # Min-max normalise each sample, then derive the two thresholds.
    intensity_min = intensity.min(dim=1, keepdim=True)[0]
    intensity_max = intensity.max(dim=1, keepdim=True)[0]
    intensity_norm = (intensity - intensity_min) / (intensity_max - intensity_min + 1e-6)
    intensity_mean = intensity_norm.mean(dim=1, keepdim=True)
    if N > 1:
        intensity_std = intensity_norm.std(dim=1, keepdim=True)
    else:
        # The sample std of a single value is NaN; treat it as no spread.
        intensity_std = torch.zeros_like(intensity_mean)

    upper_threshold = intensity_mean + (0.5 + k) * intensity_std
    lower_threshold = intensity_mean - (0.5 - k) * intensity_std

    # Hard mask: above-threshold locations, randomly thinned to the budget.
    mask = torch.zeros(B, N, device=device)
    for b in range(B):
        hard_mask_indices = (intensity_norm[b] > upper_threshold[b]).nonzero(as_tuple=True)[0]
        m = hard_mask_indices.size(0)
        if m > mask_total_num:
            hard_mask_indices = hard_mask_indices[torch.randperm(m, device=device)[:mask_total_num]]
        mask[b, hard_mask_indices] = 1.0

    # Soft mask: locations between the thresholds are scaled by their inverse
    # intensity, normalised so the largest factor is 1.
    soft_mask = (intensity_norm >= lower_threshold) & (intensity_norm <= upper_threshold)
    soft_intensity = (1 / (intensity_norm + 1e-6)) * soft_mask
    # If no location falls in the band the row max is 0; clamping avoids a
    # 0/0 that would turn the whole sample into NaN.
    soft_intensity_max = soft_intensity.max(dim=1, keepdim=True)[0].clamp_min(1e-6)
    soft_mask_values = (1 - soft_mask.float()) + soft_intensity / soft_intensity_max

    # Hard-masked locations become 0; soft-masked ones are attenuated.
    mask = mask.view(B, 1, H, W)
    soft_mask_values = soft_mask_values.reshape(B, 1, H, W)
    masked_feat = feat * (1 - mask) * soft_mask_values

    # The normalised intensity is returned for visualisation.
    radiation_intensity = intensity_norm.view(B, H, W)
    return masked_feat, radiation_intensity

# Run the demo only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()