
import torch
import torch.nn.functional as F
import math

class Model(torch.nn.Module):
    """Pure-PyTorch reference for multi-scale deformable attention sampling.

    For every feature level, each head bilinearly samples its own slice of
    ``value`` at that head's sampling locations (via ``F.grid_sample``), the
    samples are weighted by ``attention_weights`` and summed over points, and
    the per-level results are accumulated.
    """

    def __init__(self):
        super().__init__()

    def forward(self, value, spatial_shapes, level_start_index, sampling_locations, attention_weights, im2col_step):
        """Compute the attention-weighted sum of bilinearly sampled values.

        Args:
            value: (N, Len_in, n_heads, C). The levels are concatenated along
                dim 1 in the order given by ``spatial_shapes``; level ``l``
                occupies ``H_l * W_l`` consecutive rows laid out row-major.
            spatial_shapes: (n_levels, 2) tensor (or sequence of pairs) holding
                ``(H, W)`` for each level.
            level_start_index: (n_levels,). Unused here; the per-level split
                is derived from ``spatial_shapes``.
            sampling_locations: (N, Len_q, n_heads, n_levels, n_points, 2).
                The last axis is ``(x, y)`` in *pixel units* of the level
                being sampled: ``x`` is divided by ``W`` and ``y`` by ``H``
                before being mapped to grid_sample's ``[-1, 1]`` range, so
                ``(0.5, 0.5)`` is the centre of the top-left pixel.
                NOTE(review): upstream deformable-attention implementations
                commonly take locations normalised to ``[0, 1]``; confirm the
                convention of whatever kernel this is compared against.
            attention_weights: (N, Len_q, n_heads, n_levels, n_points).
            im2col_step: Unused here; accepted for signature compatibility.

        Returns:
            Tensor of shape (N, Len_q, n_heads * C). Samples that fall outside
            a level's map contribute zeros (``padding_mode='zeros'``).
        """
        bs, _, n_heads, c = value.shape
        _, Len_q, _, _, n_points, _ = sampling_locations.shape

        # Read the level shapes to host once instead of converting tensor
        # elements to Python ints repeatedly inside the loop.
        if torch.is_tensor(spatial_shapes):
            spatial_shapes = spatial_shapes.tolist()
        level_shapes = [(int(h), int(w)) for h, w in spatial_shapes]

        # One (N, H*W, n_heads, C) chunk per level.
        value_list = value.split([h * w for h, w in level_shapes], dim=1)

        output = torch.zeros(bs, Len_q, n_heads, c, device=value.device, dtype=value.dtype)

        for lvl, (H, W) in enumerate(level_shapes):
            # Fold heads into the batch axis so every head samples only its
            # own channels: (N, H*W, heads, C) -> (N*heads, C, H, W).
            value_l = value_list[lvl].permute(0, 2, 3, 1).reshape(bs * n_heads, c, H, W)

            # Locations for this level: (N, Lq, heads, points, 2).
            loc_l = sampling_locations[:, :, :, lvl]

            # Pixel units -> grid_sample's [-1, 1] range. max(., 1) guards
            # against division by zero for a degenerate level.
            grid = torch.stack(
                (
                    2.0 * loc_l[..., 0] / max(W, 1) - 1.0,
                    2.0 * loc_l[..., 1] / max(H, 1) - 1.0,
                ),
                dim=-1,
            )
            # Match the head-folded value layout: (N*heads, Lq*points, 1, 2).
            grid = grid.permute(0, 2, 1, 3, 4).reshape(bs * n_heads, Len_q * n_points, 1, 2)

            # (N*heads, C, Lq*points, 1) -> (N, heads, C, Lq, points)
            sampled = F.grid_sample(
                value_l,
                grid,
                mode='bilinear',
                padding_mode='zeros',
                align_corners=False,
            ).reshape(bs, n_heads, c, Len_q, n_points)

            # (N, Lq, heads, points) -> (N, heads, 1, Lq, points); the
            # singleton axis broadcasts over channels.
            weight_l = attention_weights[:, :, :, lvl].permute(0, 2, 1, 3).unsqueeze(2)

            # Weighted sum over points, then back to (N, Lq, heads, C).
            out_l = (sampled * weight_l).sum(-1)
            output += out_l.permute(0, 3, 1, 2)

        # Merge heads and channels: (N, Len_q, n_heads * C).
        return output.flatten(2)

def get_init_inputs():
    """Return a fresh, empty list of constructor arguments (``Model()`` takes none)."""
    init_args = list()
    return init_args

def get_inputs():
    """Build one random argument list for ``Model.forward``.

    Two feature levels (30x30 and 15x15), batch 2, 8 heads, 32 channels per
    head, 100 queries and 4 sampling points per level.

    Sampling locations are generated in pixel units, i.e. ``x`` in ``[0, W)``
    and ``y`` in ``[0, H)`` of the level they belong to, which is the
    convention ``Model.forward`` normalises by. Each level is scaled by its
    own size so that samples land inside that level's map instead of mostly
    falling outside the smaller one.

    Attention weights are soft-maxed jointly over levels and points, so for
    every (batch, query, head) they sum to 1.

    Returns:
        list: ``[value, spatial_shapes, level_start_index, sampling_locations,
        attention_weights, im2col_step]``.
    """
    bs, n_heads, c = 2, 8, 32
    Len_q = 100  # number of queries
    n_points = 4

    # (H, W) of every feature level; everything below is derived from this.
    spatial_shapes = torch.tensor([[30, 30], [15, 15]], dtype=torch.long)
    n_levels = spatial_shapes.shape[0]

    # Offset of each level's first row inside the flattened value tensor.
    level_sizes = spatial_shapes.prod(dim=1)
    level_start_index = torch.cat((level_sizes.new_zeros(1), level_sizes.cumsum(0)[:-1]))

    Len_in = int(level_sizes.sum())
    value = torch.randn(bs, Len_in, n_heads, c)

    # Uniform [0, 1) draws scaled per level by (W, H) -> pixel coordinates.
    # The flip turns the stored (H, W) into the (x, y) = (W, H) order.
    extent_xy = spatial_shapes.flip(-1).to(value.dtype)  # (n_levels, 2)
    sampling_locations = torch.rand(bs, Len_q, n_heads, n_levels, n_points, 2)
    sampling_locations = sampling_locations * extent_xy[:, None, :]

    # Normalise over levels*points together so the weights of each
    # (batch, query, head) sum to 1 rather than to n_levels.
    logits = torch.rand(bs, Len_q, n_heads, n_levels, n_points)
    attention_weights = F.softmax(logits.flatten(-2), dim=-1).reshape(logits.shape)

    im2col_step = 64
    return [value, spatial_shapes, level_start_index, sampling_locations, attention_weights, im2col_step]
