import torch, torch.nn as nn

class DynamicDrop(nn.Module):
    """Zero out the highest-scoring part of a feature map.

    A score map ``cam`` decides what is removed; the drop is deterministic
    (quantile based), not random:

    * ``'spatial'``: per sample, locations whose ``cam`` value is at or above
      the ``1 - p`` quantile of ``cam`` are zeroed in ``feat_map``.
    * ``'channel'``: per sample, each channel is scored by the spatial mean
      of ``cam * feat_map``; channels whose score is at or above the
      ``1 - p`` quantile are zeroed.

    ``p`` is stored as a buffer, so it travels with ``state_dict()`` and
    device moves, and can be rescheduled with :meth:`update_p`.

    NOTE(review): the mask is applied regardless of ``self.training``;
    confirm that dropping in eval mode is intended.
    """

    _MODES = ('spatial', 'channel')

    def __init__(self, mode: str = 'spatial', p: float = 0.25) -> None:
        """Configure the drop mode and the initial drop fraction.

        Args:
            mode: ``'spatial'`` or ``'channel'``.
            p: Fraction (in ``[0, 1]``) of top-scoring locations/channels
                to drop.

        Raises:
            ValueError: If ``mode`` is unknown or ``p`` is outside ``[0, 1]``.
        """
        super().__init__()
        # Explicit raise rather than ``assert``: asserts vanish under
        # ``python -O`` and a bad mode would silently act as 'channel'.
        if mode not in self._MODES:
            raise ValueError(
                f"mode must be one of {self._MODES}, got {mode!r}")
        self.mode = mode
        self.register_buffer(
            'p', torch.tensor(self._checked_p(p), dtype=torch.float))

    @staticmethod
    def _checked_p(p) -> float:
        """Return ``p`` as a float, rejecting values outside ``[0, 1]``."""
        value = float(p)
        # The chained comparison is also False for NaN, so NaN is rejected.
        if not 0.0 <= value <= 1.0:
            raise ValueError(f"p must lie in [0, 1], got {value}")
        return value

    @torch.no_grad()
    def update_p(self, new_p: float) -> None:
        """Overwrite the drop fraction in place.

        Raises:
            ValueError: If ``new_p`` is outside ``[0, 1]``.
        """
        self.p.fill_(self._checked_p(new_p))

    def forward(self, feat_map: torch.Tensor, cam: torch.Tensor) -> torch.Tensor:
        """Apply the score-guided drop to ``feat_map``.

        Args:
            feat_map: Feature tensor; assumed ``(B, C, H, W)`` -- TODO confirm
                against callers.
            cam: Score map; assumed ``(B, 1, H, W)`` so that it broadcasts
                against ``feat_map`` -- TODO confirm against callers.

        Returns:
            ``feat_map`` with the selected locations/channels multiplied by
            zero. When the drop fraction is 0 the input tensor itself is
            returned (no copy).
        """
        # Reading the buffer as a Python float synchronises with the device.
        p = float(self.p)
        if p <= 0.0:
            # Without this guard the threshold below would be the maximum
            # score and the strict ``<`` would still zero the top entry.
            return feat_map

        q = 1 - p
        if self.mode == "spatial":
            # One threshold per (sample, cam-channel) over flattened H*W.
            thresh = torch.quantile(cam.flatten(2), q, dim=2, keepdim=True)
            # Extra trailing dim so the threshold broadcasts over (H, W).
            mask = (cam < thresh.unsqueeze(-1)).float()
            return feat_map * mask

        # 'channel': score each channel by its cam-weighted spatial mean.
        impor = (cam * feat_map).mean(dim=(2, 3))
        thresh = torch.quantile(impor, q, dim=1, keepdim=True)
        mask = (impor < thresh).float().unsqueeze(-1).unsqueeze(-1)
        return feat_map * mask

