import torch


def batched_pooling(blocks, verts_pos, image_size=224):
    """Pool per-vertex features from feature maps by bilinear interpolation.

    Each vertex position is rescaled from image coordinates to the resolution
    of every feature map in ``blocks``; the features of the four surrounding
    grid cells are blended with bilinear weights
    (https://en.wikipedia.org/wiki/Bilinear_interpolation), and the results
    for all maps are concatenated along the feature axis.

    Args:
        blocks: iterable of 4-D feature maps indexed as
            ``(batch, channels, dim, dim)``. The flattened indexing below uses
            ``dim = block.shape[-1]`` for both spatial axes, so each map must
            be square.
        verts_pos: tensor indexed as ``(batch, num_verts, >=2)``; component 0
            is x and component 1 is y, in the same units as ``image_size``.
        image_size: side length of the image the positions refer to.

    Returns:
        Tensor of shape ``(batch, num_verts, total_channels)`` where
        ``total_channels`` is the sum of the channel counts of all blocks,
        or ``None`` when ``blocks`` is empty.

    Note:
        x selects the second-to-last axis of a block (rows) and y selects the
        last axis (columns); this matches the original indexing convention.
    """
    batch_size, num_verts = verts_pos.shape[0], verts_pos.shape[1]

    # Positions as fractions of the image size; identical for every block.
    xs = verts_pos[:, :, 0] / image_size
    ys = verts_pos[:, :, 1] / image_size

    pooled = []
    for block in blocks:
        # Scale coordinates to this block's resolution and keep them on-grid.
        dim = block.shape[-1]
        cur_xs = torch.clamp(xs * dim, 0, dim - 1)
        cur_ys = torch.clamp(ys * dim, 0, dim - 1)

        x_floor = torch.floor(cur_xs)
        y_floor = torch.floor(cur_ys)

        # Weights derived from the fractional part always sum to 1 per axis.
        # (Using ceil(x) - x and x - floor(x) gives 0 + 0 whenever x is an
        # exact integer -- including every clamped vertex -- which would zero
        # out the pooled feature.)
        w_x2 = cur_xs - x_floor
        w_x1 = 1 - w_x2
        w_y2 = cur_ys - y_floor
        w_y1 = 1 - w_y2

        x1s = x_floor.long()
        y1s = y_floor.long()
        # The far corner is one cell over, clamped to the last valid index;
        # at the border its weight is 0, so the duplicate index is harmless.
        x2s = torch.clamp(x1s + 1, max=dim - 1)
        y2s = torch.clamp(y1s + 1, max=dim - 1)

        # (channels, batch * dim * dim): one column per (batch, row, col) cell.
        flat_block = block.permute(1, 0, 2, 3).contiguous().view(block.shape[1], -1)
        # Offset of each batch element's first cell in the flattened layout.
        batch_offset = (
            torch.arange(batch_size, device=block.device)
            .unsqueeze(-1)
            .expand(batch_size, num_verts)
            * dim
            * dim
        )

        corners = (
            (x1s, y1s, w_x1 * w_y1),
            (x1s, y2s, w_x1 * w_y2),
            (x2s, y1s, w_x2 * w_y1),
            (x2s, y2s, w_x2 * w_y2),
        )
        features = None
        for rows, cols, weight in corners:
            selection = (batch_offset + rows * dim + cols).view(-1)
            corner = torch.index_select(flat_block, 1, selection)
            # (channels, batch * verts) -> (batch, channels, verts)
            corner = corner.view(-1, batch_size, num_verts).permute(1, 0, 2)
            term = weight.unsqueeze(1) * corner
            features = term if features is None else features + term

        # (batch, channels, verts) -> (batch, verts, channels)
        pooled.append(features.permute(0, 2, 1))

    if not pooled:
        return None
    # Single concatenation instead of re-allocating after every block.
    return torch.cat(pooled, dim=2)
