# PLuG Multi‑Scale Deformable Attention
# This file is a drop‑in replacement for the original `Deformable-DETR/models/ops/modules/ms_deform_attn.py`

# https://github.com/fundamentalvision/Deformable-DETR/blob/main/models/ops/modules/ms_deform_attn.py

from __future__ import absolute_import, division, print_function

import math
import warnings

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import constant_, xavier_uniform_

from ..functions import MSDeformAttnFunction



def _is_power_of_2(n: int) -> bool:
    if (not isinstance(n, int)) or (n < 0):
        raise ValueError(f"invalid input for _is_power_of_2: {n} (type: {type(n)})")
    return (n & (n - 1) == 0) and n != 0



class MSDeformAttn(nn.Module):
    """Multi‑Scale Deformable Attention with a PLuG gating branch.

    Same interface as the original Deformable-DETR module. The difference is
    that the pre-softmax attention logits are rescaled by a learned,
    query-dependent scalar gate per (head, level):

        logits' = logits + logits * gate,   gate = gateA * gateB,
        [gateA, gateB] = gate_linear(gate_proj(query))   (per head, level)

    With the initialisation in ``_reset_parameters`` the gate is exactly 0 at
    start, so ``logits' == logits`` and the module initially behaves like
    plain MS-Deformable Attention.

    Args:
        d_model (int): hidden dimension C
        n_levels (int): number of feature pyramid levels L
        n_heads (int): number of heads H
        n_points (int): sampling points per head per level P
    """

    def __init__(self,
                 d_model: int = 256,
                 n_levels: int = 4,
                 n_heads: int = 8,
                 n_points: int = 4):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads, but got "
                             f"{d_model} and {n_heads}")

        _d_per_head = d_model // n_heads
        if not _is_power_of_2(_d_per_head):
            warnings.warn("For optimal CUDA efficiency, make each head dim a power of two.")

        self.d_model = d_model
        self.n_levels = n_levels
        self.n_heads = n_heads
        self.n_points = n_points
        # Forwarded unchanged to MSDeformAttnFunction.apply.
        self.im2col_step = 64

        # (x, y) offset for every (head, level, point).
        self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
        # One attention logit for every (head, level, point).
        self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
        self.value_proj = nn.Linear(d_model, d_model)
        self.output_proj = nn.Linear(d_model, d_model)

        # ------------------------------------------------------------------ gating branch
        # One *scalar* gate per (head, level).  gate_proj output: [N, Len_q, H * L]
        self.gate_proj = nn.Linear(d_model, n_heads * n_levels)
        # Maps each scalar to the pair (gateA, gateB); gate = gateA * gateB.
        self.gate_linear = nn.Linear(1, 2)

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialise all projections.

        The sampling-offset bias is set to a fixed starting pattern. Head ``h``
        points in direction ``2*pi*h/H``, normalised so that the larger
        absolute component is 1. Point ``i`` sits ``i + 1`` steps along that
        direction, identically for every level.

        The gating branch is set up so that the gate is exactly 0 at start but
        can still learn.
        """
        constant_(self.sampling_offsets.weight.data, 0.)
        thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
        grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) \
            .view(self.n_heads, 1, 1, 2) \
            .repeat(1, self.n_levels, self.n_points, 1)
        for i in range(self.n_points):
            grid_init[:, :, i, :] *= i + 1
        with torch.no_grad():
            # Copy in place rather than assigning a fresh nn.Parameter. This
            # keeps the existing parameter object (any optimizer reference to
            # it, plus its device/dtype) intact.
            self.sampling_offsets.bias.copy_(grid_init.view(-1))

        constant_(self.attention_weights.weight.data, 0.)
        constant_(self.attention_weights.bias.data, 0.)
        xavier_uniform_(self.value_proj.weight.data)
        constant_(self.value_proj.bias.data, 0.)
        xavier_uniform_(self.output_proj.weight.data)
        constant_(self.output_proj.bias.data, 0.)

        # ----- gating branch params – start as (gate = 0) but stay trainable.
        # gate_linear: weight = 0, bias = [1, 0]. Then gateA = 1 and gateB = 0,
        # so gate = gateA * gateB = 0 for ANY gate_proj output.
        #
        # gate_proj must therefore NOT be zero-initialised. With both weight
        # matrices at zero:
        #   - d(gate)/d(gate_linear.weight) is proportional to gate_proj(query) = 0;
        #   - d(gate)/d(gate_proj.*) is proportional to gate_linear.weight = 0.
        # No gating weight would ever move. Only gate_linear.bias would train,
        # leaving one constant gate shared by all queries, heads and levels.
        xavier_uniform_(self.gate_proj.weight.data)
        constant_(self.gate_proj.bias.data, 0.)
        constant_(self.gate_linear.weight.data, 0.)
        with torch.no_grad():
            self.gate_linear.bias.data.zero_()
            self.gate_linear.bias.data[0] = 1.0

    def forward(self,
                query: torch.Tensor,  # (N, Len_q, C)
                reference_points: torch.Tensor,  # (N, Len_q, L, 2|4)
                input_flatten: torch.Tensor,  # (N, Len_in, C)
                input_spatial_shapes: torch.Tensor,  # (L, 2)
                input_level_start_index: torch.Tensor,  # (L,)
                input_padding_mask: torch.Tensor = None):  # (N, Len_in)
        """Apply gated multi-scale deformable attention.

        See the original MS‑DeformAttn for argument semantics.

        Args:
            query: (N, Len_q, C) query features.
            reference_points: (N, Len_q, L, 2) reference locations, or
                (N, Len_q, L, 4). In the 4-wide case the first two entries are
                the location and the last two scale the offsets by half their value.
            input_flatten: (N, Len_in, C) flattened features of all levels.
            input_spatial_shapes: (L, 2) per-level spatial sizes, (H, W) as in
                upstream Deformable-DETR. Their products must sum to Len_in.
            input_level_start_index: (L,) start index of each level inside
                ``input_flatten``.
            input_padding_mask: optional (N, Len_in) boolean mask. Positions
                where it is True have their projected values zeroed.

        Returns:
            ``output_proj`` applied to the result of ``MSDeformAttnFunction``.

        Raises:
            ValueError: if the last dim of ``reference_points`` is not 2 or 4.
        """
        N, Len_q, _ = query.shape
        N, Len_in, _ = input_flatten.shape
        assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in

        value = self.value_proj(input_flatten)
        if input_padding_mask is not None:
            value = value.masked_fill(input_padding_mask[..., None], float(0))
        value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)

        sampling_offsets = self.sampling_offsets(query) \
            .view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
        raw_attn = self.attention_weights(query) \
            .view(N, Len_q, self.n_heads, self.n_levels, self.n_points)

        # ------------------------------------------------------------------ GATING BRANCH
        # One scalar per (head, level), fed to gate_linear as a length-1 feature.
        gate_logits = self.gate_proj(query)
        gate_logits = gate_logits.view(N, Len_q, self.n_heads, self.n_levels, 1)

        gate_out = self.gate_linear(gate_logits)  # (N, Len_q, H, L, 2)
        gateA, gateB = gate_out.chunk(2, dim=-1)
        gate = gateA * gateB  # (N, Len_q, H, L, 1)

        # Share each (head, level) gate across all P sampling points.
        gate = gate.expand(-1, -1, -1, -1, self.n_points)
        # Equivalent to raw_attn * (1 + gate); gate == 0 leaves the logits unchanged.
        mod_attn = raw_attn + raw_attn * gate

        # Softmax jointly over all L * P (level, point) pairs of each head.
        attn_flat = mod_attn.view(N, Len_q, self.n_heads, -1)
        attn_flat = F.softmax(attn_flat, dim=-1)
        attention_weights = attn_flat.view(N, Len_q, self.n_heads, self.n_levels, self.n_points)

        if reference_points.shape[-1] == 2:
            # Divide (x, y) offsets by (shape[..., 1], shape[..., 0]) of each level.
            offset_norm = torch.stack([input_spatial_shapes[..., 1],
                                       input_spatial_shapes[..., 0]], -1)
            sampling_locations = reference_points[:, :, None, :, None, :] \
                + sampling_offsets / offset_norm[None, None, None, :, None, :]
        elif reference_points.shape[-1] == 4:
            # Offsets are scaled by half of the last two reference entries and by 1/P.
            sampling_locations = reference_points[:, :, None, :, None, :2] \
                + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
        else:
            raise ValueError("Last dim of reference_points must be 2 or 4, got "
                             f"{reference_points.shape[-1]}")

        output = MSDeformAttnFunction.apply(
            value, input_spatial_shapes, input_level_start_index,
            sampling_locations, attention_weights, self.im2col_step)
        output = self.output_proj(output)
        return output
