import torch
import torch.nn as nn

class SelectionLayer(nn.Module):
    """Learnable weighted sum over the seed axis of a batch of latent vectors.

    Holds one scalar weight per seed, initialised uniformly to ``1 / num_seeds``,
    and reduces inputs of shape ``(batch, num_seeds, latent_dim)`` to
    ``(batch, latent_dim)``. The weights are applied as-is: they are not
    normalised (e.g. by softmax), so nothing forces them to keep summing to 1
    during training.
    """

    def __init__(self, num_seeds, device='cpu'):
        """Create the per-seed weights on ``device``.

        Args:
            num_seeds: number of seeds, i.e. the expected size of axis 1 of
                the input passed to ``forward``.
            device: device the parameters are created on. Stored as
                ``self.device`` for backward compatibility; ``forward`` uses
                the parameter's *current* device, so later ``.to()`` calls on
                the module are honoured.
        """
        super().__init__()
        self.device = device
        self.weights = nn.Parameter(
            torch.tensor([1 / num_seeds for _ in range(num_seeds)]), True
        )
        self.to(device)

    def get_params(self):
        """Return the per-seed weight parameter (shape ``(num_seeds,)``)."""
        return self.weights

    def forward(self, inputs):
        """Return the weighted sum of ``inputs`` over the seed axis.

        Args:
            inputs: tensor of shape ``(batch, num_seeds, latent_dim)``. It is
                not modified.

        Returns:
            Tensor of shape ``(batch, latent_dim)``.

        Raises:
            ValueError: if ``inputs`` is not 3-D or its seed axis does not
                match the number of weights.
        """
        weights = self.weights
        # Explicit check: with an out-of-place multiply, a size-1 seed axis
        # would otherwise broadcast silently and give a wrong result.
        if inputs.dim() != 3 or inputs.shape[1] != weights.shape[0]:
            raise ValueError(
                f"expected inputs of shape (batch, {weights.shape[0]}, latent_dim), "
                f"got {tuple(inputs.shape)}"
            )
        # Follow the parameter's actual device rather than the one given at
        # construction, so the layer keeps working after module.to(...).
        inputs = inputs.to(weights.device)
        # Out-of-place broadcast multiply: (S,) -> (1, S, 1) replaces the
        # explicit expand, never mutates the caller's tensor (so no clone is
        # needed), and promotes integer inputs to floating point instead of
        # failing on an in-place dtype cast.
        return (inputs * weights[None, :, None]).sum(dim=1)