import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np




class QueryGeneration(nn.Module):
    """Modulate a set of queries with gates computed from a feature map.

    Two gating signals are derived from the feature map and summed:

    * a channel gate: a squeeze/excite-style MLP (ending in a sigmoid) applied
      to every spatial token and then averaged over all tokens;
    * a spatial gate: for each of the ``num_queries`` slots, a softmax over
      spatial positions pools the tokens into one vector, which is then
      passed through a sigmoid.

    The queries are multiplied element-wise by the sum of both gates.

    Args:
        num_queries: Number of query slots ``L``.
        in_channels: Channel dimension ``C`` of both queries and features.
        squeeze_ratio: The hidden width of the channel MLP is
            ``max(1, int(in_channels * squeeze_ratio))``.
    """

    def __init__(self, num_queries, in_channels, squeeze_ratio=0.25):
        super().__init__()
        self.L = num_queries
        self.C = in_channels
        r = max(1, int(self.C * squeeze_ratio))

        # Per-token channel gate: C -> r -> C, squashed into (0, 1).
        self.ch_mlp = nn.Sequential(
            nn.Linear(self.C, r, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(r, self.C, bias=False),
            nn.Sigmoid()
        )

        # Per-token logits, one per query slot: C -> C // 2 -> L.
        self.attn_mlp = nn.Sequential(
            nn.Linear(self.C, self.C // 2),
            nn.ReLU(inplace=True),
            nn.Linear(self.C // 2, self.L)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, query, feature):
        """Gate ``query`` with statistics of ``feature``.

        Args:
            query: Either ``[L, C]`` (shared across the batch and broadcast
                to every sample) or ``[L, B, C]`` (one query set per sample).
            feature: ``[B, C, H, W]`` feature map with ``C == in_channels``.

        Returns:
            Gated queries of shape ``[L, B, C]``.

        Raises:
            ValueError: If ``query`` is neither 2-D nor 3-D.
        """
        _, C, _, _ = feature.shape  # unpacking also enforces a 4-D input
        assert C == self.C, f"expected {self.C} feature channels, got {C}"

        # Flatten the spatial grid into N = H * W tokens.
        region = feature.flatten(2).permute(0, 2, 1)        # [B, N, C]

        # Channel gate, averaged over tokens.
        ch_gate = self.ch_mlp(region).mean(dim=1, keepdim=True)  # [B, 1, C]

        # Spatial gate: softmax over the token axis (dim=1), so each query
        # slot holds a distribution over positions used to pool the tokens.
        attn_logits = self.attn_mlp(region)                  # [B, N, L]
        attn_weights = torch.softmax(attn_logits, dim=1)     # [B, N, L]
        attn_weights = attn_weights.permute(0, 2, 1)         # [B, L, N]
        sp_gate = self.sigmoid(torch.bmm(attn_weights, region))  # [B, L, C]

        # Fusion: the channel gate broadcasts over the L query slots.
        gating = sp_gate + ch_gate                           # [B, L, C]

        if query.dim() == 2:
            q = query.unsqueeze(0)                           # [1, L, C]
        elif query.dim() == 3:
            q = query.permute(1, 0, 2)                       # [B, L, C]
        else:
            raise ValueError(
                f"query must have shape [L, C] or [L, B, C], got {tuple(query.shape)}"
            )

        q = q * gating                                       # [B, L, C]
        return q.permute(1, 0, 2)                            # [L, B, C]


class DA2Block(nn.Module):
    """Self-attention over the queries followed by cross-attention to features.

    Each attention stage is wrapped in a residual connection and followed by
    a LayerNorm.

    Args:
        d_model: Embedding dimension of queries and features.
        nhead: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, nhead=8, dropout=0.1):
        super().__init__()
        attn_cfg = dict(embed_dim=d_model, num_heads=nhead, dropout=dropout)
        self.self_attn = nn.MultiheadAttention(**attn_cfg)
        self.norm1 = nn.LayerNorm(d_model)
        self.cross_attn = nn.MultiheadAttention(**attn_cfg)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, query, key_value):
        """Update the queries and return the cross-attention weights.

        Args:
            query: Sequence-first query tensor, expected ``[L, B, C]``.
            key_value: Batch-first feature tokens, expected ``[B, N, C]``;
                they are permuted to sequence-first before attention.

        Returns:
            Tuple ``(query, attn)``: the updated queries (same shape as the
            input) and the weights returned by the cross-attention module.
        """
        # Stage 1: queries attend to each other.
        self_out, _ = self.self_attn(query=query, key=query, value=query)
        refined = self.norm1(query + self_out)

        # Stage 2: queries attend to the feature tokens (sequence-first).
        memory = torch.permute(key_value, (1, 0, 2))
        cross_out, cross_weights = self.cross_attn(query=refined, key=memory, value=memory)
        refined = self.norm2(refined + cross_out)

        return refined, cross_weights

class DA2(nn.Module):
    """Query-based aggregation head producing an L2-normalised descriptor.

    A feature map is projected with a 1x1 convolution, a learned set of
    queries is gated by ``QueryGeneration`` and refined by a stack of
    ``DA2Block`` layers, and the refined queries are reduced to a flat
    vector.

    Args:
        in_channels: Channel count of the incoming feature map.
        hidden_dim: Channel count after the 1x1 projection; also the query
            embedding size.
        num_queries: Number of learned queries ``L``.
        num_layers: Number of ``DA2Block`` layers.
        output_dim: Accepted for interface compatibility but not used; the
            descriptor size is fixed at ``128 * 32 = 4096`` by the two final
            linear layers.
    """

    def __init__(self, in_channels=768, hidden_dim=768, num_queries=64, num_layers=2, output_dim=4096):
        super().__init__()
        self.proj_conv = nn.Conv2d(in_channels, hidden_dim, kernel_size=1, stride=1, padding='same')
        self.norm_input = nn.LayerNorm(hidden_dim)
        self.query = nn.Parameter(torch.randn(num_queries, hidden_dim))
        self.blocks = nn.ModuleList([DA2Block(hidden_dim) for _ in range(num_layers)])
        self.gene_query = QueryGeneration(num_queries=num_queries, in_channels=hidden_dim)
        self.dim_reduction = nn.Linear(hidden_dim, 128)     # channel axis: hidden_dim -> 128
        self.queries_adjustment = nn.Linear(num_queries, 32)  # query axis: L -> 32
        self.hidden_dim = hidden_dim

    def _project(self, feat):
        """Return the projected map ``[B, C, H, W]`` and its normalised tokens ``[B, H*W, C]``."""
        feat2d = self.proj_conv(feat)
        tokens = feat2d.flatten(2).permute(0, 2, 1)
        return feat2d, self.norm_input(tokens)

    def _refine_queries(self, feat2d, tokens):
        """Gate the learned queries and run them through every block.

        Returns the final queries ``[L, B, C]`` and the list of per-block
        cross-attention weights.
        """
        q = self.gene_query(self.query, feat2d)
        attn_maps = []
        for block in self.blocks:
            q, attn = block(q, tokens)
            attn_maps.append(attn)
        return q, attn_maps

    def _spatial_weights(self, feat):
        """Return the query-generation softmax over positions, ``[B, L, H*W]``.

        Mirrors the spatial branch of ``QueryGeneration.forward``: the
        projected (un-normalised) tokens go through ``attn_mlp`` and the
        softmax is taken over the token axis.
        """
        feat2d = self.proj_conv(feat)
        region = feat2d.flatten(2).permute(0, 2, 1)          # [B, N, C]
        attn_logits = self.gene_query.attn_mlp(region)       # [B, N, L]
        attn_weights = torch.softmax(attn_logits, dim=1)     # [B, N, L]
        return attn_weights.permute(0, 2, 1)                 # [B, L, N]

    def forward(self, x):
        """Compute the global descriptor.

        Args:
            x: A 2-tuple whose first element is the feature map
                ``[B, C, H, W]``; the second element is ignored.

        Returns:
            Tensor ``[B, 128 * 32]``, L2-normalised along dim 1.
        """
        feat, _ = x
        feat2d, tokens = self._project(feat)
        q, _ = self._refine_queries(feat2d, tokens)

        proj = self.dim_reduction(q)                     # [L, B, 128]
        proj = proj.permute(1, 2, 0)                     # [B, 128, L]
        proj = self.queries_adjustment(proj)             # [B, 128, 32]
        out = proj.flatten(1)                            # [B, 128*32]
        return F.normalize(out, p=2, dim=1)

    def visualize_attentionmap(self, x):
        """Return the cross-attention weights of the last block.

        Args:
            x: Same 2-tuple as accepted by ``forward``.

        Returns:
            The attention weights returned by the final ``DA2Block``; with
            the default ``nn.MultiheadAttention`` settings these are
            head-averaged and shaped ``[B, L, H*W]``.
        """
        feat, _ = x
        feat2d, tokens = self._project(feat)
        _, attn_maps = self._refine_queries(feat2d, tokens)
        return attn_maps[-1]

    def visualize_ddf_spatial(self, x):
        """Return the query-generation spatial weights, ``[B, L, H*W]``.

        Args:
            x: Same 2-tuple as accepted by ``forward``.
        """
        feat, _ = x
        return self._spatial_weights(feat)

    def get_ddf_topk_patches(self, x, k=1):
        """Return, for every query, the indices of its ``k`` strongest patches.

        Args:
            x: Same 2-tuple as accepted by ``forward``.
            k: Number of patches to select per query.

        Returns:
            Long tensor ``[B, L, k]`` of flattened patch indices in
            ``[0, H*W - 1]``.
        """
        feat, _ = x
        attn_weights = self._spatial_weights(feat)           # [B, L, N]
        _, topk_idx = attn_weights.topk(k, dim=-1)           # [B, L, k]
        return topk_idx



if __name__ == '__main__':
    # Smoke test: build the model, run one forward pass on random data and
    # print the shape of the resulting descriptor.
    dim_v = 768          # channels of the input feature map
    dim_t = 768          # hidden dimension after projection
    batch_size = 120
    patch_num = 16       # spatial side length of the feature map
    num_queries = 64
    num_layers = 2
    output_dim = 4096

    # Create model instance
    model = DA2(in_channels=dim_v, hidden_dim=dim_t, num_queries=num_queries, num_layers=num_layers, output_dim=output_dim)

    # Generate test input data; the second tuple element is ignored by DA2.
    F_t = torch.randn(batch_size, dim_v, patch_num, patch_num)
    F_sc = torch.randn(batch_size, patch_num * patch_num)

    # Only the output shape is inspected, so skip autograd bookkeeping.
    with torch.no_grad():
        print(model((F_t, F_sc)).shape)
