from typing import Optional, Tuple

import torch
import torch.nn as nn

from sliding_attention import *
from modules import *


class ConformerConvModule(nn.Module):
    """
    Convolution sub-module of a Conformer block.

    Layers, in the order they are placed in ``self.sequential``:
    LayerNorm -> Transpose -> PointwiseConv -> GLU -> DepthwiseConv ->
    BatchNorm -> Swish -> PointwiseConv -> Dropout.
    ``forward`` swaps dims 1 and 2 of the sequential output before returning.

    Args:
        in_channels: feature size used for the LayerNorm, the BatchNorm1d and
            the channel counts of the convolutions.
        kernel_size: kernel size of the depthwise convolution; must be odd.
            ``(kernel_size - 1) // 2`` is passed as its padding.
        expansion_factor: multiplier for the output channels of the first
            pointwise convolution. Only 2 is accepted.
        dropout_p: probability given to the final ``nn.Dropout``.
    """

    def __init__(self, in_channels: int, kernel_size: int = 5, expansion_factor: int = 2, dropout_p: float = 0.1):
        super().__init__()
        assert kernel_size % 2 == 1, "kernel_size should be odd"
        assert expansion_factor == 2, "here only supports expansion_factor=2"

        widened_channels = in_channels * expansion_factor
        half_kernel = (kernel_size - 1) // 2

        # Layers are created in the same order they run, so the indices inside
        # ``self.sequential`` (and therefore the state_dict keys) follow this list.
        stages = [
            nn.LayerNorm(in_channels),
            Transpose(shape=(1, 2)),  # [B, dim, time]
            PointwiseConv1d(in_channels, widened_channels),
            nn.GLU(dim=1),  # gates along dim 1, halving its size
            DepthwiseConv1d(in_channels, in_channels, kernel_size, padding=half_kernel),
            nn.BatchNorm1d(in_channels),
            Swish(),
            PointwiseConv1d(in_channels, in_channels),
            nn.Dropout(dropout_p),
        ]
        self.sequential = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the layer stack and swap dims 1 and 2 of the result."""
        channels_first = self.sequential(x)
        return channels_first.transpose(1, 2)


class ConformerBlock(nn.Module):
    """
    One Conformer block operating on three chains (H, L, Ag) at once.

    Per chain the block applies: feed-forward (residual, module_factor=0.5) ->
    shared triple sliding attention -> convolution module (residual) ->
    feed-forward (residual, module_factor=0.5) -> LayerNorm. Each chain has
    its own feed-forward, convolution and LayerNorm modules; the attention
    module is shared and receives all three chains together.

    Args:
        d_model: feature size handed to the feed-forward, attention
            (``embed_dim``), convolution and LayerNorm modules.
        dim_ff: only used to derive the feed-forward expansion factor,
            ``max(dim_ff // d_model, 4)``.
        conv_kernel: kernel size for the convolution modules.
        n_heads: ``num_heads`` of the sliding attention.
        dim_head: accepted but not referenced in this class.
        dropout: dropout value for the feed-forward and convolution modules.
        sigma: accepted but not referenced in this class.
        device: stored as ``self.device``; not otherwise used here.
        min_bandwidth, max_bandwidth, sliding_steps, scale: forwarded to
            ``SlidingInterAttention`` (``sliding_steps`` as ``steps``).
        H_weight: forwarded to ``TripleInterAttention``.

    NOTE(review): because of the ``max(..., 4)`` floor, the default
    ``dim_ff=1280`` / ``d_model=640`` yields an expansion factor of 4, not 2 —
    confirm this is intended.
    """

    def __init__(
        self,
        d_model=640,
        dim_ff=1280,
        conv_kernel=5,
        n_heads=10,
        dim_head=64,
        dropout=0.1,
        sigma=1.0,
        device='cuda',
        min_bandwidth=48,
        max_bandwidth=144,
        sliding_steps=3,
        scale: int = 3,
        H_weight: float = 0.5
    ):
        super().__init__()
        self.device = device
        expansion_factor = max(dim_ff // d_model, 4)

        def make_ffn():
            # Feed-forward wrapped in a residual connection scaled by 0.5.
            return ResidualConnectionModule(
                FeedForwardModule(d_model, expansion_factor, dropout),
                module_factor=0.5,
            )

        def make_conv():
            # Convolution module wrapped in a residual connection.
            return ResidualConnectionModule(
                ConformerConvModule(d_model, kernel_size=conv_kernel, dropout_p=dropout)
            )

        # Sub-modules are created in the same sequence as they are applied in
        # ``forward`` so registration order stays stable.

        # First feed-forward, one per chain
        self.ffn1_H = make_ffn()
        self.ffn1_L = make_ffn()
        self.ffn1_Ag = make_ffn()

        # Sliding attention, wrapped by the triple (H / L / Ag) attention
        self.sliding_attn = SlidingInterAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            min_bandwidth=min_bandwidth,
            max_bandwidth=max_bandwidth,
            steps=sliding_steps,
            scale=scale,
        )
        self.triple_attn = TripleInterAttention(self.sliding_attn, H_weight=H_weight)

        # Convolution module, one per chain
        self.conv_H = make_conv()
        self.conv_L = make_conv()
        self.conv_Ag = make_conv()

        # Second feed-forward, one per chain
        self.ffn2_H = make_ffn()
        self.ffn2_L = make_ffn()
        self.ffn2_Ag = make_ffn()

        # Final normalisation, one per chain
        self.ln_H = nn.LayerNorm(d_model)
        self.ln_L = nn.LayerNorm(d_model)
        self.ln_Ag = nn.LayerNorm(d_model)

    def forward(self, H, L, Ag, h_pad_mask=None, l_pad_mask=None, ag_pad_mask=None):
        """
        Process the three chains through the block.

        The pad masks are passed positionally to the triple attention and are
        not used anywhere else in this method.

        Returns:
            A 6-tuple ``(out_H, out_L, out_Ag, attn_h_ag, attn_l_ag,
            combined_att)``: the three processed chains followed by the last
            three values returned by the triple attention, unchanged.
        """
        # Stage 1: first feed-forward
        ffn_H, ffn_L, ffn_Ag = self.ffn1_H(H), self.ffn1_L(L), self.ffn1_Ag(Ag)

        # Stage 2: triple sliding attention over all chains
        attended = self.triple_attn(ffn_H, ffn_L, ffn_Ag, h_pad_mask, l_pad_mask, ag_pad_mask)
        out_H, out_L, out_Ag, attn_h_ag, attn_l_ag, combined_att = attended

        # Stages 3-5 are evaluated stage by stage (H, then L, then Ag) rather
        # than chain by chain, keeping the module call order fixed.
        out_H, out_L, out_Ag = self.conv_H(out_H), self.conv_L(out_L), self.conv_Ag(out_Ag)
        out_H, out_L, out_Ag = self.ffn2_H(out_H), self.ffn2_L(out_L), self.ffn2_Ag(out_Ag)
        out_H, out_L, out_Ag = self.ln_H(out_H), self.ln_L(out_L), self.ln_Ag(out_Ag)

        return out_H, out_L, out_Ag, attn_h_ag, attn_l_ag, combined_att


class ConformerEncoder(nn.Module):
    """
    Conformer encoder for H/L/Ag interface prediction with configurable sliding-attention parameters.

    A single input projection and a single classifier are shared by all three
    chains; the stacked ``ConformerBlock`` modules hold the per-chain weights.

    Args:
        d_model: feature size of the input projection, the blocks and the
            classifier input.
        dim_ff, n_heads, dim_head, conv_kernel, dropout, sigma, device,
        min_bandwidth, max_bandwidth, sliding_steps, scale, H_weight:
            forwarded unchanged to every ``ConformerBlock``.
        n_blocks: number of stacked blocks.
        num_classes: output size of the per-residue classifier.

    Inputs:
      - H_embed: [B, L_h, D]
      - L_embed: [B, L_l, D]
      - Ag_embed: [B, L_ag, D]
      - pad masks for each chain: boolean [B, L_*] where True = padding

    Outputs:
      - logits_H: [B, L_h, num_classes]
      - logits_L: [B, L_l, num_classes]
      - logits_Ag: [B, L_ag, num_classes]
    """
    def __init__(
        self,
        d_model: int = 640,
        dim_ff: int = 1280,
        n_heads: int = 10,
        dim_head: int = 64,
        conv_kernel: int = 5,
        n_blocks: int = 6,
        dropout: float = 0.1,
        sigma: float = 1.0,
        num_classes: int = 2,
        device: str = 'cuda',
        min_bandwidth: int = 48,
        max_bandwidth: int = 144,
        sliding_steps: int = 3,
        scale: int = 3,
        H_weight: float = 0.5
    ):
        super().__init__()
        # Stored for reference only; nothing in this class moves tensors to it.
        self.device = device

        # store sliding params
        self.min_bandwidth = min_bandwidth
        self.max_bandwidth = max_bandwidth
        self.sliding_steps = sliding_steps
        self.scale = scale

        # input projection (one module, applied to H, L and Ag alike)
        self.input_proj = Linear(d_model, d_model)

        # stack of ConformerBlock modules, all built with the same hyper-parameters
        self.blocks = nn.ModuleList([
            ConformerBlock(
                d_model=d_model,
                dim_ff=dim_ff,
                conv_kernel=conv_kernel,
                n_heads=n_heads,
                dim_head=dim_head,
                dropout=dropout,
                sigma=sigma,
                device=device,
                min_bandwidth=min_bandwidth,
                max_bandwidth=max_bandwidth,
                sliding_steps=sliding_steps,
                scale=scale,
                H_weight=H_weight
            ) for _ in range(n_blocks)
        ])

        # final classifier (per-residue, one module shared by the three chains)
        self.classifier = Linear(d_model, num_classes)

    def forward(
        self,
        H_embed: torch.Tensor,
        L_embed: torch.Tensor,
        Ag_embed: torch.Tensor,
        h_pad_mask: Optional[torch.Tensor] = None,
        l_pad_mask: Optional[torch.Tensor] = None,
        ag_pad_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass:
          - project inputs
          - pass through stacked blocks
          - classifier per-residue (raw logits)

        The pad masks are handed to every block unchanged. The three
        attention outputs each block returns are discarded here.

        Returns:
            ``(logits_H, logits_L, logits_Ag)``.
        """
        H = self.input_proj(H_embed)  # [B, L_h, D]
        L = self.input_proj(L_embed)  # [B, L_l, D]
        Ag = self.input_proj(Ag_embed)  # [B, L_ag, D]

        for block in self.blocks:
            # Each block returns six values; only the updated chains are kept.
            H, L, Ag, _, _, _ = block(
                H, L, Ag,
                h_pad_mask=h_pad_mask,
                l_pad_mask=l_pad_mask,
                ag_pad_mask=ag_pad_mask
            )

        logits_H = self.classifier(H)  # [B, L_h, num_classes]
        logits_L = self.classifier(L)  # [B, L_l, num_classes]
        logits_Ag = self.classifier(Ag)  # [B, L_ag, num_classes]

        return logits_H, logits_L, logits_Ag