from torch import Tensor
from torch.nn import Module, Parameter
import torch
import torch.nn.functional as F

class UniformPrior(Module):
    """Prior that assigns the same probability to every entity.

    The probability ``1 / n_entities`` is stored as a scalar buffer named
    ``uniform_prob`` so that it moves with the module across devices/dtypes.
    """

    def __init__(self, n_entities: int):
        """Store the entity count and the per-entity probability.

        Args:
            n_entities: Number of entities the prior is defined over.
        """
        super().__init__()
        self.n_entities = n_entities
        per_entity_probability = torch.tensor(1.0 / n_entities)
        self.register_buffer('uniform_prob', per_entity_probability)

    def forward(self, batch_size: int) -> Tensor:
        """Return the prior broadcast to ``(batch_size, n_entities)``.

        The result is an expanded view of the scalar buffer, not a copy.
        """
        target_shape = (batch_size, self.n_entities)
        return self.uniform_prob.expand(target_shape)


class LearnableMarginalPrior(Module):
    """Prior whose per-entity probabilities are learned.

    The module holds one unnormalised score per entity in the parameter
    ``u`` (initialised to zeros); the probabilities are the softmax of
    those scores.
    """

    def __init__(self, n_entities: int):
        """Create the score parameter.

        Args:
            n_entities: Number of entities the prior is defined over.
        """
        super().__init__()
        initial_scores = torch.zeros(n_entities)
        self.u = Parameter(initial_scores)

    def forward(self, batch_size: int) -> Tensor:
        """Return the learned prior tiled to ``(batch_size, n_entities)``."""
        marginal = F.softmax(self.u, dim=0)
        # Add a leading batch axis, then copy the row once per batch item.
        row = marginal[None, :]
        return row.repeat(batch_size, 1)


class FrequencyPrior(Module):
    """Fixed prior proportional to observed entity frequencies.

    The raw frequencies are normalised by their sum once at construction
    time and kept in the buffer ``relative_frequencies``.
    """

    def __init__(self,
                 raw_frequencies: Tensor):
        """Normalise the frequencies and store them on the module.

        Args:
            raw_frequencies: Tensor of per-entity counts/weights; it is
                divided by its total sum to obtain the prior.
        """
        super(FrequencyPrior, self).__init__()
        relative_frequencies = raw_frequencies / torch.sum(raw_frequencies)
        # Register as a buffer (not a plain attribute) so the tensor follows
        # the module through .to()/.cuda()/.double(). It is non-persistent so
        # the state_dict keys stay exactly as before and previously saved
        # checkpoints keep loading with strict=True.
        self.register_buffer('relative_frequencies',
                             relative_frequencies,
                             persistent=False)

    def forward(self, batch_size: int) -> Tensor:
        """Return the prior with a leading batch axis, repeated ``batch_size`` times."""
        return self.relative_frequencies.unsqueeze(0).repeat(batch_size, 1)

