
import torch
import torch.nn.functional as F

class ForwardModel(torch.nn.Module):
    """Reference forward pass for MSMV deformable sampling.

    Every sampling point carries an (x, y, view) location.  For each feature
    level the point is bilinearly sampled from the view it selects, scaled by
    that level's attention weight, and the contributions are summed over
    levels and over the points axis.
    """

    def __init__(self):
        super().__init__()

    def forward(self, feat_list, spatial_shapes, sampling_loc, attn_weight):
        """Sample the multi-level, multi-view features and reduce over points.

        Args:
            feat_list: Non-empty sequence of per-level feature tensors, each
                indexed as ``(N, V, C, H, W)``; ``V`` and ``C`` are read from
                the first level.
            spatial_shapes: Unused here; kept for interface compatibility.
            sampling_loc: Tensor ``(N, Q, P, 3)`` holding ``(x, y, view)``.
                ``x``/``y`` are mapped from ``[0, 1]`` to ``grid_sample``'s
                ``[-1, 1]`` range; ``view`` is scaled by ``V - 1`` and rounded
                to pick a view index.
            attn_weight: Tensor ``(N, Q, P, L)`` whose last axis is indexed by
                the level's position in ``feat_list``.

        Returns:
            Tensor ``(N, Q, C)``.  Points whose rounded view index is outside
            ``[0, V - 1]`` contribute nothing; if no point selects a valid
            view the result is all zeros (detached from autograd).
        """
        num_views = feat_list[0].shape[1]
        C = feat_list[0].shape[2]
        N, Q, P = sampling_loc.shape[:3]

        # The view coordinate is quantized, so no gradient flows through it.
        loc_v = torch.round(sampling_loc[..., 2] * (num_views - 1)).long()

        # The sampling grid depends on neither level nor view: build it once.
        # Layout (N, Q*P, 1, 2) with (x, y) last, as grid_sample expects.
        grid = (sampling_loc[..., :2] * 2 - 1).reshape(N, -1, 1, 2)

        # View-selection masks are level-independent as well; keep only the
        # views that at least one point actually selects.
        active_views = []
        for v_idx in range(num_views):
            mask = loc_v == v_idx
            if mask.any():
                active_views.append((v_idx, mask.unsqueeze(-1).float()))

        output = None
        for lvl, feat in enumerate(feat_list):
            # Inputs may live on a different device than this level's features.
            device = feat.device
            grid_d = grid.to(device)
            weight_d = attn_weight[..., lvl].to(device).unsqueeze(-1)

            for v_idx, mask_f in active_views:
                # Default bilinear interpolation with zero padding.
                sampled = F.grid_sample(feat[:, v_idx], grid_d, align_corners=True)
                # (N, C, Q*P, 1) -> (N, Q, P, C)
                sampled = sampled.view(N, C, Q, P).permute(0, 2, 3, 1)

                # Zero out the points that belong to a different view.
                term = sampled * weight_d * mask_f.to(device)
                output = term if output is None else output + term

        if output is None:
            # Mirror the dtype the regular path would produce:
            # features * weights * float32 mask.
            out_dtype = torch.promote_types(
                torch.promote_types(feat_list[0].dtype, attn_weight.dtype),
                torch.float32,
            )
            return torch.zeros(N, Q, C, device=feat_list[0].device, dtype=out_dtype)

        return output.sum(dim=2)


class Model(torch.nn.Module):
    """Backward pass for MSMV Deformable Conv, obtained by differentiating
    ``ForwardModel`` with autograd."""

    def __init__(self):
        super().__init__()
        self.fwd = ForwardModel()

    def forward(self, grad_output, feat_list, spatial_shapes, sampling_loc, attn_weight):
        """Compute the gradients of the forward pass w.r.t. its tensor inputs.

        The inputs are detached and cloned first, so the caller's tensors (and
        their ``.grad`` fields) are never touched.

        Args:
            grad_output: Upstream gradient, shaped like the forward output.
            feat_list: Per-level feature tensors passed to the forward model.
            spatial_shapes: Forwarded unchanged to the forward model.
            sampling_loc: Sampling locations passed to the forward model.
            attn_weight: Attention weights passed to the forward model.

        Returns:
            Tuple ``(grad_feats, grad_loc, grad_w)``: a list with one gradient
            per entry of ``feat_list``, followed by the gradients for
            ``sampling_loc`` and ``attn_weight``.  An input that did not take
            part in the computation receives an all-zero gradient.
        """
        # Build the graph even when the caller has disabled gradient tracking
        # (e.g. runs this module under torch.no_grad()).
        with torch.enable_grad():
            feats = [f.detach().clone().requires_grad_(True) for f in feat_list]
            loc = sampling_loc.detach().clone().requires_grad_(True)
            w = attn_weight.detach().clone().requires_grad_(True)
            leaves = [*feats, loc, w]

            out = self.fwd(feats, spatial_shapes, loc, w)

            if out.requires_grad:
                grads = list(
                    torch.autograd.grad(
                        out, leaves, grad_outputs=grad_output, allow_unused=True
                    )
                )
            else:
                # The forward pass returned its constant-zero fallback, which
                # is not connected to any input: every gradient is zero.
                grads = [None] * len(leaves)

        grads = [
            torch.zeros_like(leaf) if g is None else g
            for leaf, g in zip(leaves, grads)
        ]

        grad_feats = grads[: len(feats)]
        grad_loc, grad_w = grads[-2], grads[-1]
        return grad_feats, grad_loc, grad_w

def get_init_inputs():
    """Return an empty list: no constructor arguments are supplied."""
    init_args = []
    return init_args

def get_inputs():
    """Build one random set of arguments for ``Model.forward``.

    Returns:
        ``[grad_output, feat_list, spatial_shapes, sampling_loc, attn_weight]``
        for batch size 1, 10 queries, 4 points, 4 views, 16 channels and two
        square feature levels of side 10 and 5.
    """
    batch, queries, points = 1, 10, 4
    views = 4
    channels = 16
    level_sizes = [10, 5]

    # The draw order below determines the values produced for a given seed.
    feat_list = [
        torch.randn(batch, views, channels, side, side) for side in level_sizes
    ]
    spatial_shapes = [(side, side) for side in level_sizes]

    sampling_loc = torch.rand(batch, queries, points, 3)
    attn_weight = torch.rand(batch, queries, points, len(level_sizes))
    grad_output = torch.randn(batch, queries, channels)

    return [grad_output, feat_list, spatial_shapes, sampling_loc, attn_weight]
