
import torch
import torch.nn.functional as F

class Model(torch.nn.Module):
    """Multi-scale, multi-view weighted bilinear feature sampling.

    Every sampling point selects one camera view and one 2-D location. The
    feature vector at that location is bilinearly interpolated from every
    pyramid level of the selected view, scaled by the point's per-level
    weight, and summed over levels and points. The module has no parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, feat_list, spatial_shapes, sampling_loc, attn_weight):
        """Sample and aggregate features for every query.

        Args:
            feat_list: Non-empty sequence of per-level feature tensors, each
                laid out as (N, V, C, H, W). The number of views V is read
                from the first level; H and W may differ between levels.
            spatial_shapes: Accepted for interface compatibility but not
                used; H and W are read from the feature tensors themselves.
            sampling_loc: Tensor of shape (N, Q, P, 3) whose last axis holds
                (w, h, view). ``w`` and ``h`` are normalized so that 0 maps
                to the first pixel and 1 to the last pixel of a level.
                ``view`` is mapped to an integer index with
                ``round(view * (V - 1))``.
            attn_weight: Tensor of shape (N, Q, P, L); column ``l`` holds the
                weight applied to samples taken from ``feat_list[l]``.

        Returns:
            Tensor of shape (N, Q, C). Points whose view index falls outside
            ``[0, V - 1]`` contribute nothing. When no point selects a valid
            view, zeros are returned in the dtype and on the device of the
            first feature level.

        Raises:
            IndexError: If ``feat_list`` is empty.
        """
        ref = feat_list[0]
        device = ref.device
        num_views = ref.shape[1]
        N, Q, P = sampling_loc.shape[:3]

        sampling_loc = sampling_loc.to(device)
        attn_weight = attn_weight.to(device)

        # Integer view index selected by each sampling point: (N, Q, P).
        loc_v = torch.round(sampling_loc[..., 2] * (num_views - 1)).long()

        # grid_sample expects (x, y) in [-1, 1]. Together with
        # align_corners=True, a normalized coordinate t lands on pixel
        # t * (size - 1). The grid is identical for every level and view, so
        # it is built once.
        grid = (sampling_loc[..., :2] * 2 - 1).reshape(N, Q * P, 1, 2)

        # View masks do not depend on the level either. Views that no point
        # selects are dropped here so they are never sampled.
        active_views = []
        for v_idx in range(num_views):
            selected = loc_v == v_idx
            if selected.any():
                active_views.append((v_idx, selected))

        output = None
        for lvl, feat in enumerate(feat_list):
            C = feat.shape[2]
            weight = attn_weight[..., lvl].unsqueeze(-1)  # (N, Q, P, 1)
            # grid_sample requires the grid and the input to share a dtype.
            lvl_grid = grid.to(dtype=feat.dtype)

            for v_idx, selected in active_views:
                # feat[:, v_idx] is (N, C, H, W). All points are sampled
                # densely, then points belonging to other views are zeroed.
                sampled = F.grid_sample(
                    feat[:, v_idx], lvl_grid, align_corners=True
                )  # (N, C, Q*P, 1)
                sampled = sampled.reshape(N, C, Q, P).permute(0, 2, 3, 1)

                # masked_fill instead of multiplying by a 0/1 float mask:
                # it keeps the dtype of the term, and a NaN/Inf sampled from
                # a view the point did not select cannot leak into the sum.
                term = (sampled * weight).masked_fill(
                    ~selected.unsqueeze(-1), 0
                )

                # Reduce over points right away so the accumulator is only
                # (N, Q, C) rather than (N, Q, P, C).
                contribution = term.sum(dim=2)
                output = (
                    contribution if output is None else output + contribution
                )

        if output is None:
            return torch.zeros(
                N, Q, ref.shape[2], dtype=ref.dtype, device=device
            )
        return output

def get_init_inputs():
    """Return the positional constructor arguments for ``Model`` (none)."""
    init_args = list()
    return init_args

def get_inputs():
    """Build one random set of example arguments for ``Model.forward``.

    Returns:
        A list ``[feat_list, spatial_shapes, sampling_loc, attn_weight]``:
        two square feature levels of shape (1, 4, 16, s, s) for s in
        (10, 5), their (H, W) sizes, sampling locations of shape
        (1, 10, 4, 3) drawn uniformly from [0, 1), and attention weights of
        shape (1, 10, 4, 2) drawn uniformly from [0, 1).
    """
    batch, num_queries, num_points = 1, 10, 4
    num_views, channels = 4, 16
    level_sizes = (10, 5)

    # Random draws happen in the order: level features, locations, weights.
    pyramid = [
        torch.randn(batch, num_views, channels, side, side)
        for side in level_sizes
    ]
    shapes = [(side, side) for side in level_sizes]
    locations = torch.rand(batch, num_queries, num_points, 3)
    weights = torch.rand(batch, num_queries, num_points, len(level_sizes))

    return [pyramid, shapes, locations, weights]
