
import itertools

import torch

class Model(torch.nn.Module):
    """Python reference for SpConv-style regular sparse-convolution indice generation.

    The module holds no parameters. :meth:`forward` maps active input sites to
    the active output sites they contribute to and records, for every kernel
    offset, which (input row, output row) pairs that offset connects.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _as_list(value, ndim, name):
        """Broadcast an int to ``ndim`` entries, or validate a per-dim sequence."""
        if isinstance(value, int):
            return [value] * ndim
        value = list(value)
        if len(value) != ndim:
            raise ValueError(
                f"{name} must have {ndim} entries (one per spatial dim), got {len(value)}"
            )
        return value

    @staticmethod
    def _output_coords(coords, k_off, stride, padding, dilation, out_shape):
        """Return the output coords reached from input ``coords`` via offset ``k_off``, or None.

        Uses the regular-convolution relation ``q * stride = p + padding - k * dilation``
        in every spatial dim. The offset contributes only if the right-hand side is
        divisible by the stride and the resulting ``q`` lies inside ``out_shape``.
        """
        q_coords = []
        for p, k, s, pad, dil, size in zip(coords, k_off, stride, padding, dilation, out_shape):
            numer = p + pad - k * dil
            # Python's % and // floor toward -inf, so a negative numerator either
            # fails the divisibility test or yields q < 0 and is rejected below.
            if numer % s:
                return None
            q = numer // s
            if not 0 <= q < size:
                return None
            q_coords.append(q)
        return q_coords

    def forward(self, indices_in, kernel_size, stride, padding, dilation, out_shape, transpose=False):
        """Generate output indices and per-offset gather/scatter pairs for a sparse conv.

        Args:
            indices_in: (N, D + 1) tensor; column 0 is the batch index and columns
                1..D are spatial coordinates (e.g. ``[b, z, y, x]`` for D = 3).
            kernel_size, stride, padding, dilation: an int (broadcast to all D
                spatial dims) or a sequence of D ints.
            out_shape: output spatial shape; an int (broadcast) or a sequence of D ints.
            transpose: transposed indice generation is not implemented; passing
                True raises rather than silently returning regular-conv indices.

        Returns:
            indices_out: (M, D + 1) tensor with the dtype/device of ``indices_in``.
                Rows are the unique active output sites, in first-discovery order.
            indice_pairs: (K, 2, L) int32 tensor, where K = prod(kernel_size) and L is
                the largest per-offset pair count. ``indice_pairs[k, 0, i]`` is an
                input row and ``indice_pairs[k, 1, i]`` is the output row it feeds
                through kernel offset ``k``. Unused slots hold -1. Offsets are
                flattened row-major (the last spatial dim varies fastest).
            indice_pair_num: (K,) int32 tensor of valid pair counts per offset.

        Raises:
            NotImplementedError: if ``transpose`` is True.
            ValueError: if a per-dim argument does not have D entries, or a stride
                entry is not positive.
        """
        if transpose:
            raise NotImplementedError(
                "transposed sparse-conv indice generation is not implemented"
            )

        ndim = indices_in.shape[1] - 1  # column 0 is the batch index
        device = indices_in.device

        kernel_size = self._as_list(kernel_size, ndim, "kernel_size")
        stride = self._as_list(stride, ndim, "stride")
        padding = self._as_list(padding, ndim, "padding")
        dilation = self._as_list(dilation, ndim, "dilation")
        out_shape = self._as_list(out_shape, ndim, "out_shape")
        if any(s <= 0 for s in stride):
            raise ValueError(f"stride entries must be positive, got {stride}")

        # Row-major kernel offsets: itertools.product varies the last dim fastest.
        kernel_offsets = list(itertools.product(*(range(k) for k in kernel_size)))
        num_offsets = len(kernel_offsets)

        out_to_idx = {}  # (b, q_0, ..., q_{D-1}) -> row index in indices_out
        indices_out_list = []
        pairs_in = [[] for _ in range(num_offsets)]
        pairs_out = [[] for _ in range(num_offsets)]

        # A plain Python loop is acceptable here: this is a reference implementation.
        for p_idx, (b, *coords) in enumerate(indices_in.tolist()):
            for k_idx, k_off in enumerate(kernel_offsets):
                q_coords = self._output_coords(
                    coords, k_off, stride, padding, dilation, out_shape
                )
                if q_coords is None:
                    continue
                q_key = (b, *q_coords)
                q_idx = out_to_idx.get(q_key)
                if q_idx is None:
                    q_idx = len(indices_out_list)
                    out_to_idx[q_key] = q_idx
                    indices_out_list.append(q_key)
                pairs_in[k_idx].append(p_idx)
                pairs_out[k_idx].append(q_idx)

        if indices_out_list:
            indices_out = torch.tensor(indices_out_list, device=device, dtype=indices_in.dtype)
        else:
            indices_out = torch.zeros(0, ndim + 1, device=device, dtype=indices_in.dtype)

        pair_counts = [len(p) for p in pairs_in]
        max_act = max(pair_counts, default=0)

        # Dense (K, 2, max_act) layout padded with -1; counts tell callers the valid prefix.
        pairs_tensor = torch.full((num_offsets, 2, max_act), -1, device=device, dtype=torch.int32)
        for k_idx, cnt in enumerate(pair_counts):
            if cnt:
                pairs_tensor[k_idx, 0, :cnt] = torch.tensor(pairs_in[k_idx], device=device, dtype=torch.int32)
                pairs_tensor[k_idx, 1, :cnt] = torch.tensor(pairs_out[k_idx], device=device, dtype=torch.int32)
        pair_counts_tensor = torch.tensor(pair_counts, device=device, dtype=torch.int32)

        return indices_out, pairs_tensor, pair_counts_tensor

def get_init_inputs():
    """Return the positional constructor arguments for ``Model`` (it takes none)."""
    init_args = []
    return init_args

def get_inputs():
    """Build sample positional arguments for ``Model.forward``.

    One active site at spatial (2, 2, 2) in batch 0, a 3x3x3 kernel with unit
    stride and dilation, no padding, and a 5x5x5 output grid.
    """
    sample_indices = torch.tensor([[0, 2, 2, 2]], dtype=torch.int32)
    kernel, strides = [3] * 3, [1] * 3
    pads, dilations = [0] * 3, [1] * 3
    spatial_out = [5] * 3
    return [sample_indices, kernel, strides, pads, dilations, spatial_out]
