import torch
import torch.nn as nn
from torch.distributions.multinomial import Multinomial

# Names exported by `from <module> import *`.
__all__ = ['RandomSelection']

class RandomSelection(nn.Module):
    """Selection module that picks sequence positions uniformly at random.

    The module has no learnable parameters. ``x_dim`` and ``y_dim`` are
    stored on the instance but are not used by ``forward``.
    """

    def __init__(self, x_dim=1, y_dim=1):
        """Store the configuration values.

        Args:
            x_dim: Stored as ``self.x_dim``; not read by this class.
            y_dim: Stored as ``self.y_dim``; not read by this class.
        """
        super(RandomSelection, self).__init__()
        self.name = 'random'
        self.x_dim = x_dim
        self.y_dim = y_dim

    def forward(self, D, k=20):
        """Draw ``k`` uniform random picks over dimension 1 for each row of ``D``.

        Only the shape and device of ``D`` are used; its values are ignored.

        Args:
            D: 3-D tensor; its first two sizes ``(B, S)`` define the output
                layout (unpacking fails for any other rank).
            k: Number of draws per row (``total_count`` of the multinomial).

        Returns:
            Tensor of shape ``(B, S, 1)`` on ``D.device`` holding multinomial
            counts. Draws are made with replacement, so an entry can exceed 1;
            the counts sum to ``k`` along dimension 1.
        """
        with torch.no_grad():
            B, S, _ = D.size()
            # Build the uniform weights directly on the target device rather
            # than allocating on the CPU and copying them over.
            uniform_weights = torch.ones(B, S, device=D.device)
            m = Multinomial(total_count=k, probs=uniform_weights)
            # Trailing singleton axis lets the result broadcast against
            # (B, S, H)-shaped tensors.
            random_mask = m.sample().unsqueeze(2)
            return random_mask
