import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from torch import Tensor
import copy
from typing import Optional, Any, Union, Callable

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Module,MultiheadAttention,LayerNorm,Dropout


class EDTformerBlock(nn.Module):
    """Decoder block: self-attention over the queries, then cross-attention to memory.

    Each of the two attention sub-layers is wrapped in a residual connection
    and a LayerNorm. The block has no feed-forward sub-layer.

    Args:
        d_model: Embedding width shared by the queries and the memory.
        nhead: Number of attention heads (forwarded to ``nn.MultiheadAttention``).
        dropout: Dropout probability used inside both attention modules and
            on their outputs.
        norm_first: If True, LayerNorm is applied to the attention input
            (pre-norm); otherwise it is applied to the residual sum (post-norm).

    Attributes:
        latest_cross_attn: Per-head cross-attention weights from the most
            recent forward pass, shape ``[B, nhead, Lq, Lm]``, or ``None`` if
            the block has not been run yet. The tensor is stored exactly as
            returned by the attention module (it is not detached).
    """

    def __init__(self, d_model, nhead=8, dropout=0.1, norm_first=False):
        super().__init__()
        self.norm_first = norm_first

        # Self-attention sub-layer.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)

        # Cross-attention sub-layer.
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout)

        # Defined up front so the attribute can be read safely before the
        # first forward pass instead of raising AttributeError.
        self.latest_cross_attn = None

    def forward(self, query: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        """Refine ``query`` by attending to itself and then to ``memory``.

        Args:
            query: Query tokens, ``[B, Lq, C]``.
            memory: Key/value tokens for cross-attention, ``[B, Lm, C]``.

        Returns:
            Updated query tokens, ``[B, Lq, C]``.
        """
        if self.norm_first:
            # Pre-norm: normalise the sub-layer input, keep the residual raw.
            query = query + self._self_attention(self.norm1(query))
            query = query + self._cross_attention(self.norm2(query), memory)
        else:
            # Post-norm: normalise the residual sum.
            query = self.norm1(query + self._self_attention(query))
            query = self.norm2(query + self._cross_attention(query, memory))
        return query

    def _self_attention(self, x: torch.Tensor) -> torch.Tensor:
        """Return the dropout-regularised self-attention output for ``x``."""
        out, _ = self.self_attn(x, x, x, need_weights=False)
        return self.dropout1(out)

    def _cross_attention(self, x: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        """Attend from ``x`` to ``memory`` and record the per-head weights."""
        out, attn = self.cross_attn(x, memory, memory, need_weights=True, average_attn_weights=False)
        # average_attn_weights=False keeps the head dimension: [B, nhead, Lq, Lm].
        self.latest_cross_attn = attn
        return self.dropout2(out)


# 2. Full EDTformer architecture (refactored)
class EDTformer(nn.Module):
    """Aggregate a feature map and a CLS token into one L2-normalised descriptor.

    A set of learnable queries attends to the projected input tokens through a
    stack of ``EDTformerBlock`` layers. The resulting query tokens are then
    compressed along the channel axis and along the query axis, flattened and
    L2-normalised.

    Args:
        in_channels: Channel width ``C`` of the incoming features and CLS token.
        proj_channels: Channel width after ``channel_proj``.
        num_queries: Number of learnable query tokens.
        num_layers: Number of stacked ``EDTformerBlock`` layers.
        output_dim: Accepted for backward compatibility but not used; the
            actual descriptor length is ``proj_channels * num_out_tokens``
            (exposed as ``descriptor_dim``).
        nhead: Attention heads per block; ``in_channels`` must be divisible
            by it (requirement of ``nn.MultiheadAttention``).
        num_out_tokens: Number of tokens the query axis is reduced to by
            ``query_proj``.
        dropout: Dropout probability passed to every block.
        norm_first: Whether blocks use pre-norm instead of post-norm.
    """

    def __init__(self, in_channels=768, proj_channels=256, num_queries=64, num_layers=2,
                 output_dim=4096, *, nhead=16, num_out_tokens=16, dropout=0.1,
                 norm_first=False):
        super().__init__()

        # Learnable queries, initialised very close to zero. The randn draw is
        # immediately overwritten by normal_, but is kept so the global RNG is
        # consumed exactly as before and seeded initialisation of the layers
        # created below does not change.
        self.query = nn.Parameter(torch.randn(1, num_queries, in_channels))
        nn.init.normal_(self.query, std=1e-6)

        # Linear projection applied to the memory tokens (CLS + patches).
        self.proj = nn.Linear(in_channels, in_channels)

        # Stacked decoder blocks.
        self.blocks = nn.ModuleList([
            EDTformerBlock(d_model=in_channels, nhead=nhead, dropout=dropout, norm_first=norm_first)
            for _ in range(num_layers)
        ])

        # Compress channels, then compress the query axis.
        self.channel_proj = nn.Linear(in_channels, proj_channels)
        self.query_proj = nn.Linear(num_queries, num_out_tokens)

        # Length of the vector returned by forward().
        self.descriptor_dim = proj_channels * num_out_tokens

    def forward(self, x):
        """Compute the global descriptor.

        Args:
            x: Pair ``(features, cls_token)`` where ``features`` is a 4-D
                tensor ``[B, C, H, W]`` and ``cls_token`` is ``[B, C]``.

        Returns:
            Tensor ``[B, proj_channels * num_out_tokens]`` with unit L2 norm
            along dim 1.

        Raises:
            ValueError: If ``x`` is not a pair or ``features`` is not 4-D.
        """
        feats, cls_token = x
        if feats.dim() != 4:
            raise ValueError(
                f"expected features of shape [B, C, H, W], got {tuple(feats.shape)}"
            )
        batch = feats.size(0)

        tokens = feats.flatten(2).permute(0, 2, 1)                  # [B, H*W, C]
        memory = torch.cat([cls_token.unsqueeze(1), tokens], dim=1)  # [B, 1 + H*W, C]
        memory = self.proj(memory)

        # Share the same learnable queries across the batch (no copy).
        query = self.query.expand(batch, -1, -1)                     # [B, num_queries, C]

        for block in self.blocks:
            query = block(query, memory)                             # [B, num_queries, C]

        out = self.channel_proj(query)                               # [B, num_queries, proj_channels]
        out = out.permute(0, 2, 1)                                   # [B, proj_channels, num_queries]
        out = self.query_proj(out)                                   # [B, proj_channels, num_out_tokens]
        out = out.flatten(1)                                         # [B, proj_channels * num_out_tokens]
        return F.normalize(out, p=2, dim=1)



if __name__ == '__main__':
    # Smoke test: build the model and print the descriptor shape for random input.
    dim_v = 768          # channel width of the input features (in_channels)
    dim_t = 768          # channel width after channel projection (proj_channels)
    batch_size = 120
    patch_num = 16       # spatial side length of the feature map
    num_queries = 64
    num_layers = 2
    output_dim = 4096    # passed through, but the constructor does not use it

    # Create model instance
    model = EDTformer(in_channels=dim_v, proj_channels=dim_t, num_queries=num_queries, num_layers=num_layers, output_dim=output_dim)

    # Generate test input data; the CLS token must share the feature channel width.
    F_t = torch.randn(batch_size, dim_v, patch_num, patch_num)
    F_sc = torch.randn(batch_size, dim_v)

    # Only the output shape is inspected, so skip building the autograd graph.
    with torch.no_grad():
        print(model((F_t, F_sc)).shape)   # torch.Size([120, 12288]) = [batch_size, dim_t * 16]