import torch
import torch.nn as nn
import torch.nn.functional as F

class Transformer_QK(nn.Module):
    """Token projection: a single ``d_model -> d_model`` linear layer.

    Despite the name, no attention is computed here; the layer is applied
    independently to the last dimension of its input.
    """

    def __init__(self, d_model):
        super().__init__()
        self.linear = nn.Linear(d_model, d_model)

    def forward(self, src):
        """Return ``src`` projected by ``self.linear`` (same shape as ``src``)."""
        projected = self.linear(src)
        return projected

# Transformer encoder module
class TransformerEncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer followed by an extra linear projection.

    Args:
        d_model: token feature size.
        nhead: number of attention heads (must divide ``d_model``).
        dropout: dropout probability on the attention and feed-forward residual
            branches. Defaults to 0.1.
        dim_feedforward: hidden size of the feed-forward block. Defaults to
            ``4 * d_model``.
        batch_first: if True (default), ``forward`` expects
            ``(batch, seq_len, d_model)``; pass False for
            ``(seq_len, batch, d_model)`` input.
    """

    def __init__(self, d_model, nhead, dropout=0.1, dim_feedforward=None, batch_first=True):
        super(TransformerEncoderLayer, self).__init__()
        if dim_feedforward is None:
            dim_feedforward = d_model * 4
        # The callers in this file feed (batch, seq_len, dim) tokens. With the
        # nn.MultiheadAttention default (batch_first=False) attention would run
        # across the batch axis, mixing different samples, instead of over tokens.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, batch_first=batch_first)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.linear3 = nn.Linear(d_model, d_model)

    def forward(self, src):
        """Encode ``src``; the output has the same shape as the input."""
        # Self-attention sub-layer with residual connection, then LayerNorm.
        src2, _ = self.self_attn(src, src, src)
        src = self.norm1(src + self.dropout(src2))
        # Position-wise feed-forward sub-layer with residual connection, then LayerNorm.
        src2 = self.linear2(F.relu(self.linear1(src)))
        src = self.norm2(src + self.dropout(src2))
        # Final token-wise projection.
        return self.linear3(src)

# Mean pooling over token similarities gives a per-token importance score.
def compute_importance_matrix(tensor):
    """Score every token by its mean dot-product similarity to all tokens.

    Args:
        tensor: token features of shape ``(..., seq_len, feature_dim)``,
            typically ``(batch_size, seq_len, feature_dim)``.

    Returns:
        Tensor of shape ``(..., seq_len, 1)`` whose entry ``j`` is the mean over
        ``k`` of ``tensor[..., j, :] @ tensor[..., k, :]``.
    """
    # (..., L, D) @ (..., D, L) -> (..., L, L) pairwise similarity matrix.
    similarity = torch.matmul(tensor, tensor.transpose(-1, -2))
    # Row-wise mean over all L tokens; keepdim retains the trailing singleton
    # axis, so the result is (..., L, 1) and callers squeeze it to (..., L).
    return similarity.mean(dim=-1, keepdim=True)

# Masking operation: zero out only the selected token positions; the order of
# the other tokens is left unchanged.
def _random_subset(indices, k):
    """Return ``k`` entries drawn uniformly without replacement from each row of ``indices``."""
    if k <= 0:
        return indices[:, :0]
    # Sorting i.i.d. uniform keys yields an independent random permutation per row.
    order = torch.rand(indices.shape, device=indices.device).argsort(dim=1)[:, :k]
    return torch.gather(indices, 1, order)


def apply_mask(tokens, importance, total_mask_rate, current_important_mask_rate):
    """Zero out a fraction of token positions, split between important and other tokens.

    Tokens are ranked by ``importance`` (descending). The top
    ``int(N * total_mask_rate)`` positions (``N = tokens.shape[1]``) form the
    important set and also define the per-sample mask budget. A fraction
    ``current_important_mask_rate`` of that budget is drawn at random from the
    important set; the rest of the budget is drawn at random from the remaining
    tokens (capped by how many there are). Selected positions are set to 0
    across all trailing feature dimensions; token order is unchanged.

    Args:
        tokens: tensor of shape ``(batch_size, N, ...)``, typically ``(B, N, C)``.
        importance: scores of shape ``(batch_size, N)``; a 1-D ``(N,)`` tensor is
            treated as a batch of one.
        total_mask_rate: fraction of tokens to mask; effectively clamped to [0, 1].
        current_important_mask_rate: fraction of the mask budget taken from the
            important set; effectively clamped to [0, 1].

    Returns:
        A new tensor with the same shape and dtype as ``tokens`` (memory layout
        preserved via ``clone``). ``tokens`` itself is not modified: writing into
        it in place would also change the caller's feature map it is a view of,
        and breaks autograd when a preceding layer saved it for backward.
    """
    if importance.dim() == 1:
        # A batch of one that the caller squeezed down to (N,).
        importance = importance.unsqueeze(0)

    batch_size, num_tokens = tokens.shape[0], tokens.shape[1]

    # Clamp all counts so out-of-range rates (e.g. an epoch-driven rate > 1)
    # can never yield negative slice bounds.
    num_total = min(max(int(num_tokens * total_mask_rate), 0), num_tokens)
    num_important = min(max(int(current_important_mask_rate * num_total), 0), num_total)
    num_non_important = min(num_total - num_important, num_tokens - num_total)

    sorted_indices = torch.argsort(importance, dim=-1, descending=True)
    important_tokens = sorted_indices[:, :num_total]
    non_important_tokens = sorted_indices[:, num_total:]

    selected = torch.cat(
        [
            _random_subset(important_tokens, num_important),
            _random_subset(non_important_tokens, num_non_important),
        ],
        dim=1,
    )

    # Boolean (batch, N) mask of positions to zero.
    mask = torch.zeros(batch_size, num_tokens, dtype=torch.bool, device=tokens.device)
    rows = torch.arange(batch_size, device=tokens.device).unsqueeze(1)
    mask[rows, selected.to(tokens.device)] = True

    masked = tokens.clone()
    # Broadcast the mask over every trailing feature dimension.
    trailing = [1] * (tokens.dim() - 2)
    masked.masked_fill_(mask.view(batch_size, num_tokens, *trailing), 0)
    return masked

# Adaptive Masking NoAttention Module
class AdaptiveMaskNoAttnModule(nn.Module):
    """Adaptive token masking scored by a per-level linear projection (no attention).

    For every feature level in ``src_list`` the flattened tokens are projected
    with ``Transformer_QK``, scored with ``compute_importance_matrix`` and then
    partially zeroed by ``apply_mask``.

    Args:
        feature_dim: channel count ``C`` of every input feature map.
        K: total mask rate, in percent, for level 0; level ``i`` uses
            ``K + i * k_step``.
        important_mask_rate: base percentage of the mask budget drawn from the
            most important tokens; grows by one percentage point per epoch.
        num_levels: number of feature levels (one encoder each). Defaults to 4.
        k_step: per-level increase of the total mask rate, in percent.
            Defaults to 10.
    """

    def __init__(self, feature_dim, K, important_mask_rate, num_levels=4, k_step=10):
        super(AdaptiveMaskNoAttnModule, self).__init__()
        # One encoder per feature level.
        self.encoders = nn.ModuleList([Transformer_QK(feature_dim) for _ in range(num_levels)])
        self.K = K  # total mask rate (percent) for level 0
        self.k_step = k_step  # per-level increase of K (percent)
        self.important_mask_rate = important_mask_rate

    def forward(self, src_list, current_epoch):
        """Mask every feature map in ``src_list``.

        Args:
            src_list: sequence of at most ``num_levels`` feature maps, each of
                shape ``(B, C, H, W)`` with ``C == feature_dim``.
            current_epoch: current training epoch; raises the important-token
                share of the mask budget by one percentage point per epoch.

        Returns:
            List of masked feature maps with the same shapes as the inputs.
        """
        masked_src = []
        for i, src in enumerate(src_list):
            # Flatten (B, C, H, W) -> (B, H*W, C); reshape also handles
            # non-contiguous inputs where view would fail.
            B, C, H, W = src.shape
            src_flat = src.reshape(B, C, H * W).permute(0, 2, 1)

            encoded_tokens = self.encoders[i](src_flat)

            # (B, L, 1) -> (B, L). Squeeze only the pooled axis so that a batch
            # of one keeps its batch dimension.
            # NOTE(review): the scores are only used for ranking (argsort) in
            # apply_mask, so no gradient appears to reach self.encoders —
            # confirm this is intended.
            importance_vector = compute_importance_matrix(encoded_tokens).squeeze(-1)

            # The total mask rate grows per level and the important share grows
            # per epoch; both are clamped to valid fractions.
            current_K = self.K + i * self.k_step
            total_mask_rate = min(max(current_K / 100, 0.0), 1.0)
            current_important_mask_rate = min(
                max((self.important_mask_rate + current_epoch) / 100, 0.0), 1.0
            )

            masked_tokens = apply_mask(
                src_flat, importance_vector, total_mask_rate, current_important_mask_rate
            )

            # Restore (B, C, H, W).
            masked_src.append(masked_tokens.permute(0, 2, 1).reshape(B, C, H, W))

        return masked_src

# Adaptive Masking Module
class AdaptiveMaskModule(nn.Module):
    """Adaptive token masking scored by a per-level Transformer encoder layer.

    For every feature level in ``src_list`` the flattened tokens are encoded
    with ``TransformerEncoderLayer``, scored with ``compute_importance_matrix``
    and then partially zeroed by ``apply_mask``.

    Args:
        feature_dim: channel count ``C`` of every input feature map.
        K: total mask rate, in percent, for level 0; level ``i`` uses
            ``K + i * k_step``.
        important_mask_rate: base percentage of the mask budget drawn from the
            most important tokens; grows by one percentage point per epoch.
        num_heads: attention heads of each encoder layer.
        num_levels: number of feature levels (one encoder each). Defaults to 4.
        k_step: per-level increase of the total mask rate, in percent.
            Defaults to 10.
    """

    def __init__(self, feature_dim, K, important_mask_rate, num_heads, num_levels=4, k_step=10):
        super(AdaptiveMaskModule, self).__init__()
        # One encoder per feature level.
        self.encoders = nn.ModuleList(
            [TransformerEncoderLayer(feature_dim, num_heads) for _ in range(num_levels)]
        )
        self.K = K  # total mask rate (percent) for level 0
        self.k_step = k_step  # per-level increase of K (percent)
        self.important_mask_rate = important_mask_rate

    def forward(self, src_list, current_epoch):
        """Mask every feature map in ``src_list``.

        Args:
            src_list: sequence of at most ``num_levels`` feature maps, each of
                shape ``(B, C, H, W)`` with ``C == feature_dim``.
            current_epoch: current training epoch; raises the important-token
                share of the mask budget by one percentage point per epoch.

        Returns:
            List of masked feature maps with the same shapes as the inputs.
        """
        masked_src = []
        for i, src in enumerate(src_list):
            # Flatten (B, C, H, W) -> (B, H*W, C); reshape also handles
            # non-contiguous inputs where view would fail.
            B, C, H, W = src.shape
            src_flat = src.reshape(B, C, H * W).permute(0, 2, 1)

            encoded_tokens = self.encoders[i](src_flat)

            # (B, L, 1) -> (B, L). Squeeze only the pooled axis so that a batch
            # of one keeps its batch dimension.
            # NOTE(review): the scores are only used for ranking (argsort) in
            # apply_mask, so no gradient appears to reach self.encoders —
            # confirm this is intended.
            importance_vector = compute_importance_matrix(encoded_tokens).squeeze(-1)

            # The total mask rate grows per level and the important share grows
            # per epoch; both are clamped to valid fractions.
            current_K = self.K + i * self.k_step
            total_mask_rate = min(max(current_K / 100, 0.0), 1.0)
            current_important_mask_rate = min(
                max((self.important_mask_rate + current_epoch) / 100, 0.0), 1.0
            )

            masked_tokens = apply_mask(
                src_flat, importance_vector, total_mask_rate, current_important_mask_rate
            )

            # Restore (B, C, H, W).
            masked_src.append(masked_tokens.permute(0, 2, 1).reshape(B, C, H, W))

        return masked_src

# Smoke test
def test():
    """Run AdaptiveMaskModule on four small random feature maps and print output shapes."""
    # Realistic multi-scale inputs, kept for reference:
    #   (4, 256, 101, 104), (4, 256, 51, 52), (4, 256, 26, 26), (4, 256, 13, 13)
    level_shape = (4, 3, 4, 4)
    src_layers = [torch.randn(*level_shape) for _ in range(4)]

    model = AdaptiveMaskModule(feature_dim=3, K=20, important_mask_rate=60, num_heads=1)
    # Attention-free alternative:
    # model = AdaptiveMaskNoAttnModule(feature_dim=3, K=20, important_mask_rate=60)
    epoch = 5  # pretend we are at training epoch 5
    masked_output = model(src_layers, current_epoch=epoch)

    i = 0
    while i < len(masked_output):
        layer = masked_output[i]
        print(f"Layer {i+1} output shape: {layer.shape}")
        i += 1


if __name__ == '__main__':
    # Run the smoke test when executed as a script.
    test()