import torch
import torch.nn.functional as F


def grid_sample(image, optical):
    N, C, IH, IW = image.shape
    _, H, W, _ = optical.shape

    ix = optical[..., 0]
    iy = optical[..., 1]

    ix = ((ix + 1) / 2) * (IW-1);
    iy = ((iy + 1) / 2) * (IH-1);
    with torch.no_grad():
        ix_nw = torch.floor(ix);
        iy_nw = torch.floor(iy);
        ix_ne = ix_nw + 1;
        iy_ne = iy_nw;
        ix_sw = ix_nw;
        iy_sw = iy_nw + 1;
        ix_se = ix_nw + 1;
        iy_se = iy_nw + 1;

    nw = (ix_se - ix)    * (iy_se - iy)
    ne = (ix    - ix_sw) * (iy_sw - iy)
    sw = (ix_ne - ix)    * (iy    - iy_ne)
    se = (ix    - ix_nw) * (iy    - iy_nw)

    with torch.no_grad():
        torch.clamp(ix_nw, 0, IW-1, out=ix_nw)
        torch.clamp(iy_nw, 0, IH-1, out=iy_nw)

        torch.clamp(ix_ne, 0, IW-1, out=ix_ne)
        torch.clamp(iy_ne, 0, IH-1, out=iy_ne)

        torch.clamp(ix_sw, 0, IW-1, out=ix_sw)
        torch.clamp(iy_sw, 0, IH-1, out=iy_sw)

        torch.clamp(ix_se, 0, IW-1, out=ix_se)
        torch.clamp(iy_se, 0, IH-1, out=iy_se)

    image = image.view(N, C, IH * IW)

    nw_val = torch.gather(image, 2, (iy_nw * IW + ix_nw).long().view(N, 1, H * W).repeat(1, C, 1))
    ne_val = torch.gather(image, 2, (iy_ne * IW + ix_ne).long().view(N, 1, H * W).repeat(1, C, 1))
    sw_val = torch.gather(image, 2, (iy_sw * IW + ix_sw).long().view(N, 1, H * W).repeat(1, C, 1))
    se_val = torch.gather(image, 2, (iy_se * IW + ix_se).long().view(N, 1, H * W).repeat(1, C, 1))

    out_val = (nw_val.view(N, C, H, W) * nw.view(N, 1, H, W) + 
               ne_val.view(N, C, H, W) * ne.view(N, 1, H, W) +
               sw_val.view(N, C, H, W) * sw.view(N, 1, H, W) +
               se_val.view(N, C, H, W) * se.view(N, 1, H, W))

    return out_val


if __name__ == '__main__':
    image = torch.Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).view(1, 3, 1, 3)
    optical = torch.Tensor([0.9, 0.5, 0.6, -0.7]).view(1, 1, 2, 2)
    image.requires_grad_(True)
    optical.requires_grad_(True)

    
    my_output = grid_sample(image, optical)
    print('\033[1;32m' + 'my_output: ' + '\033[0m', my_output)
    official_output = F.grid_sample(image, optical, padding_mode='border', align_corners=True)
    print('\033[1;32m' + 'official_output: ' + '\033[0m', official_output)

    
    d_output = torch.ones_like(official_output, requires_grad=False)
    my_gradients = torch.autograd.grad(outputs=my_output,
                                       inputs=[image, optical],
                                       grad_outputs=d_output,
                                       create_graph=True,
                                       retain_graph=True,
                                       only_inputs=True)[0]
    print('\033[1;32m' + 'my_gradients: ' + '\033[0m', my_gradients)
    official_gradients = torch.autograd.grad(outputs=official_output,
                                             inputs=[image, optical],
                                             grad_outputs=d_output,
                                             create_graph=True,
                                             retain_graph=True,
                                             only_inputs=True)[0]
    print('\033[1;32m' + 'official_gradients: ' + '\033[0m', official_gradients)

    
    my_gradients.mean().backward()
    print('\033[1;32m' + 'customed implementation succeed' + '\033[0m')
    official_gradients.mean().backward()
