import torch
import torch.nn.functional as F
import torch.nn as nn

class GeMPool(nn.Module):
    """Generalized-mean (GeM) pooling followed by flattening and L2 normalisation.

    Adapted from https://github.com/filipradenovic/cnnimageretrieval-pytorch.
    Flatten and L2 norm are added so the module can be used directly as an
    aggregation layer that turns a feature map into one descriptor per sample.

    Args:
        p: Initial value of the learnable pooling exponent. ``p=1`` is plain
            average pooling.
        eps: Lower bound applied to activations before raising them to the
            power ``p``.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        # Learnable exponent, stored as a 1-element tensor parameter.
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        """Pool a feature map into an L2-normalised descriptor.

        Args:
            x: Either a ``(features, token)`` pair, where ``features`` is a
                ``[B, C, H, W]`` tensor and ``token`` is ignored, or the
                ``[B, C, H, W]`` feature tensor on its own.

        Returns:
            A ``[B, C]`` tensor whose rows have unit L2 norm.
        """
        if not torch.is_tensor(x):
            # Backbone output is (feature map, token); only the feature map is pooled.
            # Unpacking a bare tensor here would split it along the batch axis.
            x, _ = x
        # (mean over H, W of clamp(x, eps)^p)^(1/p)  -> [B, C, 1, 1]
        x = F.avg_pool2d(
            x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))
        ).pow(1.0 / self.p)
        x = x.flatten(1)  # [B, C]
        return F.normalize(x, p=2, dim=1)

    def extra_repr(self):
        # Shown by print(model) / repr(module) to aid debugging.
        return f"p={self.p.item():.4f}, eps={self.eps}"
    
