import torch
import torch.nn as nn
import torch.nn.functional as F
from model.cross_attention import Node2BlockCrossAttn, Block2NodeCrossAttn

class BDGTransConvLayer(nn.Module):
    """One node<->block exchange step built from two cross-attention modules.

    The layer first runs ``Node2BlockCrossAttn`` to get an additive update for
    the block tensor, then runs ``Block2NodeCrossAttn`` against the updated
    blocks to get an additive update for the node tensor.

    NOTE(review): both attention modules are defined in
    ``model.cross_attention``; which input acts as query vs. key/value, and
    the tensor shapes they expect, cannot be told from this file.
    """

    def __init__(
        self,
        dim: int,
        k_blocks: int,
        num_heads: int,
        g_blocks: int = 10,
        hard: bool = True,
        dropout_attn: float = 0.2,
        eps: float = 1e-6,
    ):
        """Build the two cross-attention sub-modules.

        Args:
            dim: Feature width, forwarded to both attention modules as ``dim``.
            k_blocks: Stored as ``self.k``; not used by the code in this class.
            num_heads: Forwarded to both attention modules.
            g_blocks: Stored as ``self.g_blocks``; not used in this class.
            hard: Forwarded to both attention modules as ``hard``.
            dropout_attn: Forwarded to both attention modules as ``dropout``.
            eps: Stored as ``self.eps``; not used by the code in this class.
        """
        super().__init__()
        self.dim = int(dim)
        self.k = int(k_blocks)
        self.g_blocks = int(g_blocks)
        self.eps = float(eps)

        # Both attention modules are configured identically.
        attn_kwargs = dict(
            dim=self.dim,
            num_heads=num_heads,
            dropout=dropout_attn,
            use_ln=True,
            hard=hard,
        )
        self.n2b = Node2BlockCrossAttn(**attn_kwargs)
        self.b2n = Block2NodeCrossAttn(**attn_kwargs)

    def forward(self, x: torch.Tensor, B: torch.Tensor):
        """Update blocks from nodes, then nodes from the updated blocks.

        Args:
            x: Node tensor passed to both attention modules.
            B: Block tensor passed to ``n2b``.

        Returns:
            A 4-tuple ``(x_out, b_out, attn_bn, attn_nb)`` where
            ``b_out = B + n2b_update``, ``x_out = x + b2n_update`` and the two
            ``attn_*`` values are whatever the attention modules return as
            their second output when called with ``return_attn=True``.
        """
        block_delta, attn_bn = self.n2b(x, B, return_attn=True)
        updated_blocks = B + block_delta

        # The node update attends to the *updated* blocks, not the input B.
        node_delta, attn_nb = self.b2n(x, updated_blocks, return_attn=True)
        updated_nodes = x + node_delta

        return updated_nodes, updated_blocks, attn_bn, attn_nb

class BDGTrans(nn.Module):
    """Input projection followed by a stack of ``BDGTransConvLayer`` modules.

    After every forward pass the detached attention outputs of each layer are
    available on ``all_attn_bN`` and ``all_attn_Nb`` (one entry per layer, in
    layer order). Both lists are empty until ``forward`` has been called.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        activation,
        num_layers: int,
        num_heads: int,
        k_blocks: int,
        dropout: float = 0.0,
        g_blocks: int = 10,
        hard: bool = True,
    ):
        """Build the input projection and the layer stack.

        Args:
            in_channels: Input feature width of ``fc_in``.
            hidden_channels: Output width of ``fc_in``; also the ``dim`` of
                every layer and the normalized size of ``ln_in``.
            activation: Callable applied to the normalized projection.
            num_layers: Number of ``BDGTransConvLayer`` modules to stack.
            num_heads: Forwarded to every layer.
            k_blocks: Forwarded to every layer.
            dropout: Dropout probability used in ``forward`` (applied to the
                projected input and to each layer's node output).
            g_blocks: Forwarded to every layer.
            hard: Forwarded to every layer.
        """
        super().__init__()
        self.activation = activation
        self.dropout = float(dropout)
        self.num_layers = int(num_layers)
        self.fc_in = nn.Linear(in_channels, hidden_channels)
        self.ln_in = nn.LayerNorm(hidden_channels)

        self.layers = nn.ModuleList([
            BDGTransConvLayer(
                dim=hidden_channels,
                k_blocks=k_blocks,
                num_heads=num_heads,
                g_blocks=g_blocks,
                hard=hard,
            )
            for _ in range(self.num_layers)
        ])

        # Per-layer attention caches, refreshed by every forward() call.
        # Defined here so they can be read safely before the first forward.
        self.all_attn_bN = []
        self.all_attn_Nb = []

    def forward(self, x: torch.Tensor, B_init: torch.Tensor):
        """Project the node features and run them through every layer.

        Args:
            x: Node features whose last dimension is ``in_channels``.
            B_init: Initial block tensor handed to the first layer.

        Returns:
            ``(x, B_current)``: the node tensor after the last layer and the
            block tensor returned by the last layer (``B_init`` itself when
            there are no layers).
        """
        # Rebind fresh lists so references kept from a previous call are not
        # mutated by this one.
        self.all_attn_bN = []
        self.all_attn_Nb = []

        x = self.fc_in(x)
        x = self.ln_in(x)
        x = self.activation(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        B_current = B_init

        for layer in self.layers:
            residual = x
            h, B_next, attn_bN, attn_Nb = layer(x, B_current)
            h = F.dropout(h, p=self.dropout, training=self.training)
            # NOTE(review): BDGTransConvLayer.forward already returns
            # ``x + dx``, so adding ``residual`` here applies the skip
            # connection a second time. Kept as-is to preserve numerics;
            # confirm this is intended.
            x = residual + h
            B_current = B_next

            # Detach so the cached maps do not keep the autograd graph alive.
            self.all_attn_bN.append(attn_bN.detach())
            self.all_attn_Nb.append(attn_Nb.detach())

        return x, B_current