import torch
import torch.nn.functional as F


def sample_grid(grid, low_coords, high_coords, *, random_flip=True):
    """Average ``grid`` over axis-aligned boxes using integral images.

    Each box ``[low, high]`` is given in ``F.grid_sample``'s normalized
    ``[-1, 1]`` coordinates (``align_corners=True``), with the last axis
    ordered ``(x, y)``: x indexes the width, y the height. The box sum comes
    from four bilinear lookups into cumulative-sum images, so the cost does not
    depend on the box size. It is divided by the matching pixel count, giving
    the mean feature vector of the pixels covered by the box (edge pixels
    included, fractional edges weighted linearly). A zero-size box therefore
    returns the bilinear sample at that point.

    Args:
        grid: feature grid of shape ``(B, C, H, W)``.
        low_coords: lower box corners, ``(N, 2)``, ``(H', W', 2)`` or
            ``(B, H', W', 2)``; lower-rank inputs get leading singleton dims.
        high_coords: upper box corners, same shape as ``low_coords`` and
            elementwise ``>= low_coords``.
        random_flip: if True (default), randomly mirror the grid along each
            spatial axis (moving the boxes with it) before integrating. Pass
            False for a deterministic evaluation path.

    Returns:
        Box means shaped like ``F.grid_sample``'s output ``(B, C, H', W')``
        with singleton dimensions squeezed, e.g. ``(C, N)`` for ``(N, 2)``
        inputs and a single-batch grid. The norm of each mean is capped at the
        largest per-pixel norm of ``grid``. In the backward pass the
        normalization is treated as 1, i.e. gradients are those of the
        unnormalized box sum.

    Raises:
        AssertionError: if any entry of ``low_coords`` exceeds ``high_coords``.
    """
    assert (low_coords > high_coords).sum() == 0, 'low_coords must not exceed high_coords'
    # grid_sample expects sampling locations shaped (B, H', W', 2); reshape
    # (unlike view) also accepts non-contiguous inputs.
    if low_coords.ndim == 2:
        low_coords = low_coords.reshape(1, 1, *low_coords.shape)
        high_coords = high_coords.reshape(1, 1, *high_coords.shape)
    elif low_coords.ndim == 3:
        low_coords = low_coords.reshape(1, *low_coords.shape)
        high_coords = high_coords.reshape(1, *high_coords.shape)

    # Optionally mirror the grid and negate the matching coordinate so every
    # box still covers the same pixels.
    axis_sign = torch.tensor([1, 1])
    if random_flip:
        if torch.rand(()) >= 0.5:
            grid = grid.flip(dims=[-2])  # rows (H) -> negate y
            axis_sign = axis_sign * torch.tensor([1, -1])
        if torch.rand(()) >= 0.5:
            grid = grid.flip(dims=[-1])  # columns (W) -> negate x
            axis_sign = axis_sign * torch.tensor([-1, 1])
    axis_sign = axis_sign.to(grid)
    # Negating an axis swaps which corner is the lower one, so re-sort.
    low_coords, high_coords = (
        torch.stack([low_coords, high_coords], 0) * axis_sign
    ).sort(dim=0)[0]

    height, width = grid.shape[-2:]
    cum_y = grid.cumsum(dim=-2)    # running sum down each column
    cum_x = grid.cumsum(dim=-1)    # running sum along each row
    cum_xy = cum_y.cumsum(dim=-1)  # 2-D integral image

    # Box corners in (x, y) order.
    corner_a = low_coords                                                  # (x0, y0)
    corner_b = torch.stack([low_coords[..., 0], high_coords[..., 1]], -1)  # (x0, y1)
    corner_c = torch.stack([high_coords[..., 0], low_coords[..., 1]], -1)  # (x1, y0)
    corner_d = high_coords                                                 # (x1, y1)

    # Inclusion-exclusion on the integral image. Subtracting the row/column
    # running sums at the low corners steps the lookup back by one pixel,
    # so the lower box edges are included in the sum as well.
    box_sum = (
        F.grid_sample(grid - cum_y - cum_x + cum_xy, corner_a, align_corners=True)
        + F.grid_sample(cum_y - cum_xy, corner_b, align_corners=True)
        + F.grid_sample(cum_x - cum_xy, corner_c, align_corners=True)
        + F.grid_sample(cum_xy, corner_d, align_corners=True)
    )  # (B, C, H', W')

    # Matching (fractional) pixel count. align_corners=True maps [-1, 1] onto
    # [0, size - 1]; x spans the width and y the height; +1 counts both edge
    # pixels of the inclusive sum, so the count is always >= 1.
    extent = high_coords - low_coords
    n_cols = extent[..., 0] * (width - 1) / 2 + 1
    n_rows = extent[..., 1] * (height - 1) / 2 + 1
    pixel_count = (n_cols * n_rows).unsqueeze(1)  # (B, 1, H', W')

    # Safety cap: a mean of grid vectors is never longer than the longest grid
    # vector. clamp_min avoids 0/0 (NaN) when a box sum is exactly zero.
    sum_norm = box_sum.norm(dim=1, keepdim=True)
    bound = grid.norm(dim=1).max() / sum_norm.clamp_min(torch.finfo(sum_norm.dtype).tiny)
    scale = (1 / pixel_count).clamp(max=bound)

    # Forward value is box_sum * scale; detaching the correction makes the
    # backward pass treat the normalization as 1.
    box_mean = box_sum + (box_sum * (scale - 1)).detach()
    return box_mean.squeeze()


def sample_grid_old(grid, low_coords, high_coords, eps=1e-8):
    """Legacy box average of ``grid`` from a single 2-D integral image.

    Earlier variant of ``sample_grid``: no random flips and no gradient trick.
    The four-corner inclusion-exclusion on ``cumsum(-2).cumsum(-1)`` sums the
    grid over each box excluding its lower row/column edge. The sum is then
    divided by the matching box area in pixel units (Δx·Δy, which is exactly
    the weight total of that half-open sum).

    Args:
        grid: feature grid of shape ``(B, C, H, W)``.
        low_coords: lower box corners in ``F.grid_sample``'s normalized
            ``[-1, 1]`` ``(x, y)`` coordinates (``align_corners=True``),
            ``(N, 2)``, ``(H', W', 2)`` or ``(B, H', W', 2)``.
        high_coords: upper box corners, same shape, elementwise >= low.
        eps: small offset added to the box extent so that zero-size boxes do
            not divide by zero. It is also subtracted from some corners (see
            note below).

    Returns:
        Box means shaped like grid_sample's output ``(B, C, H', W')`` with
        singleton dimensions squeezed (``(C, N)`` for ``(N, 2)`` inputs). The
        norm of each mean is capped at the largest per-pixel norm of ``grid``.
        Zero-size boxes have a near-zero sum and a near-zero area, so their
        output is dominated by that cap and is not a meaningful point sample.

    Raises:
        AssertionError: if any entry of ``low_coords`` exceeds ``high_coords``.
    """
    assert (low_coords > high_coords).sum() == 0, 'low_coords must not exceed high_coords'
    # grid_sample expects sampling locations shaped (B, H', W', 2).
    if low_coords.ndim == 2:
        low_coords = low_coords.reshape(1, 1, *low_coords.shape)
        high_coords = high_coords.reshape(1, 1, *high_coords.shape)
    elif low_coords.ndim == 3:
        low_coords = low_coords.reshape(1, *low_coords.shape)
        high_coords = high_coords.reshape(1, *high_coords.shape)

    integral = grid.cumsum(dim=-2).cumsum(dim=-1)
    *_, height, width = grid.shape

    # NOTE(review): eps is subtracted from corners b, c, d but not from a, and
    # for |coord| near 1 it is below float32 resolution; intent unclear.
    a = low_coords
    b = torch.stack([low_coords[..., 0] - eps, high_coords[..., 1]], -1)
    c = torch.stack([high_coords[..., 0], low_coords[..., 1] - eps], -1)
    d = high_coords - eps

    corner_sums = [F.grid_sample(integral, p, align_corners=True) for p in (a, b, c, d)]
    box_sum = corner_sums[0] - corner_sums[1] - corner_sums[2] + corner_sums[3]

    # Box area in pixels: align_corners=True maps [-1, 1] onto [0, size - 1];
    # x spans the width, y the height.
    extent = high_coords - low_coords + eps
    area = (extent[..., 0] * (width - 1) / 2) * (extent[..., 1] * (height - 1) / 2)
    area = area.unsqueeze(1)  # (B, 1, H', W'): broadcast over channels

    # Cap the mean's norm at the longest grid vector; clamp_min avoids 0/0.
    sum_norm = box_sum.norm(dim=1, keepdim=True)
    bound = grid.norm(dim=1).max() / sum_norm.clamp_min(torch.finfo(sum_norm.dtype).tiny)
    scale = (1 / area).clamp(max=bound)
    return (box_sum * scale).squeeze()


if __name__ == '__main__':
    # Sanity check: compare sample_grid against a Monte Carlo estimate of the
    # same box averages (mean of uniform bilinear samples inside each box).
    num_boxes = 2048
    channels = 20
    side = 256
    n_mc = 8192
    features = torch.randn(1, channels, side, side)

    # Random boxes in grid_sample's [-1, 1] coordinates, shrunk by factors
    # from e^-2 to 1 so that both small and large boxes are represented.
    shrink = torch.linspace(-2, 0, num_boxes).exp().reshape(-1, 1, 1, 1)
    corners = (torch.rand(1, num_boxes, 1, 2, 2) * 2 - 1) * shrink
    # Sort over the corner axis so index 0 holds the per-axis minimum.
    corners = corners.sort(dim=-2)[0]  # [1, num_boxes, 1, 2 corners, 2 axes]
    low, high = corners[..., 0, :], corners[..., 1, :]

    # Monte Carlo reference. NOTE: the grid_sample output below is
    # [1, channels, num_boxes, n_mc] (~1.3 GB in float32 at these sizes).
    mc_points = torch.rand(1, num_boxes, n_mc, 2) * (high - low) + low
    mc_mean = F.grid_sample(features, mc_points, align_corners=True)
    mc_mean = mc_mean.mean(dim=-1).squeeze()

    fast_mean = sample_grid(features, low.reshape(-1, 2), high.reshape(-1, 2))

    print(mc_mean.norm(), fast_mean.norm())
    print(mc_mean.std(), fast_mean.std())
    residual = mc_mean - fast_mean
    print(residual.norm(), residual.std())
    print(features.norm(dim=1).max(), fast_mean.norm(dim=0).max())

    print(torch.corrcoef(torch.stack([mc_mean.flatten(), fast_mean.flatten()], 0)))

