import torch


def get_coord_grid(x_size, y_size, device=None):
    """Return a grid holding the (x, y) coordinate of every cell.

    Args:
        x_size: number of columns (extent along x).
        y_size: number of rows (extent along y).
        device: optional torch device on which the grid is created.

    Returns:
        A float tensor of shape ``(y_size, x_size, 2)`` such that
        ``grid[j, i, 0] == i`` (x / column index) and
        ``grid[j, i, 1] == j`` (y / row index).
    """
    xs = torch.arange(0, x_size, device=device)
    ys = torch.arange(0, y_size, device=device)

    # Broadcast the two axes instead of calling torch.meshgrid: meshgrid
    # without an explicit ``indexing`` argument emits a warning on recent
    # PyTorch versions, and its default is slated to change.
    x = xs.unsqueeze(0).expand(ys.shape[0], -1)  # x varies along columns
    y = ys.unsqueeze(1).expand(-1, xs.shape[0])  # y varies along rows

    return torch.stack([x, y], dim=-1).float()


def reconstruct_from_offset_unfold(hm, offset, ksize, expe_weight=0.8, shift=-10, slope=4):
    """Differentiably reconstruct a detection map by moving a heatmap along an offset field.

    Every source pixel is displaced by its offset; each output pixel then
    gathers the heatmap values of the source pixels inside a
    ``ksize x ksize`` window around it, weighted by a sigmoid of the
    distance between the output location and the displaced source location.

    Args:
        hm: heatmap at time t, shape ``(B, C, H, W)``.
        offset: motion offset between time t and t+1, shape ``(B, H, W, 2)``
            with the last dimension ordered ``(x, y)``, or ``None`` for no
            motion.
        ksize: odd size of the sliding window used for the reconstruction;
            motion length should be at most ``ksize / 2``.
        expe_weight: corresponds to lambda_r in the paper, trade-off between
            precision and differentiability. Scales the distance.
        shift: constant added to the scaled distance before the sigmoid.
        slope: multiplier applied to the distance before the sigmoid.

    Returns:
        The reconstructed map, shape ``(B, C, H, W)``.

    Raises:
        AssertionError: if ``ksize`` is even or ``offset`` does not match
            the batch / spatial dimensions of ``hm``.
    """
    assert ksize % 2 == 1, "reconstruction windows must be of uneven dimension"

    B, C, H, W = hm.size()

    if offset is not None:
        B_o, H_o, W_o, C_o = offset.size()

        assert B == B_o, "hm and offset must have the same batch size"
        assert C_o == 2, "offset must have 2 channels (x, y) in its last dimension"
        assert H == H_o, "hm and offset must have the same height"
        assert W == W_o, "hm and offset must have the same width"

    # Base coordinate grid, shape (1, H, W, 2). It is never modified, so it is
    # kept with a batch dimension of 1 and broadcast rather than copied B times.
    base_coord = get_coord_grid(W, H, hm.device).unsqueeze(0)

    # Location of every source pixel after applying the motion offset.
    updated_coord = base_coord if offset is None else base_coord + offset

    half = ksize // 2
    p2d = (half, half, half, half)
    stride = 1

    # Output locations, shape (1, 2, H, W, 1, 1).
    target_u = base_coord.permute(0, 3, 1, 2).unsqueeze(4).unsqueeze(5)

    # Displaced source locations of the window around each output pixel,
    # shape (B or 1, 2, H, W, ksize, ksize). Padded cells get coordinate
    # (0, 0); they carry no weight in the result because the heatmap is
    # zero-padded at the same positions.
    updated_coord_u = (
        torch.nn.functional.pad(updated_coord.permute(0, 3, 1, 2), p2d)
        .unfold(2, ksize, stride)
        .unfold(3, ksize, stride)
    )

    # Euclidean distance between each output location and every displaced
    # neighbour; the clamp keeps the sqrt gradient finite at zero distance.
    sq_dist = ((target_u - updated_coord_u) ** 2).sum(dim=1, keepdim=True)
    dist = torch.sqrt(torch.clamp(sq_dist, min=1e-8))

    # Soft assignment weights. torch.sigmoid is used instead of
    # exp(d) / (exp(d) + 1), which overflows to inf / inf = NaN for large d.
    weights = torch.sigmoid(-(dist * slope * expe_weight + shift))

    # Reorganize the heatmap so it has the same window layout as the weights:
    # (B, C, H, W, ksize, ksize).
    hm_u = (
        torch.nn.functional.pad(hm, p2d)
        .unfold(2, ksize, stride)
        .unfold(3, ksize, stride)
    )

    # Weighted sum over each window gives the final reconstruction.
    rec = (hm_u * weights).sum(dim=(4, 5))

    return rec