import torch


@torch.no_grad()
def compute_ll_from_transition_scores(transition_scores, batch_lens, length_normalize=False):
    """Sum per-token scores of each sequence up to and including a given index.

    Args:
        transition_scores: 2-D tensor ``(batch, seq_len)`` of per-token scores
            (typically log-probabilities). Entries after the selected index
            are ignored, so trailing padding does not affect the result.
        batch_lens: 1-D tensor ``(batch,)`` holding, for each row, the index of
            the last token to include; ``batch_lens[i] + 1`` tokens are summed.
        length_normalize: If true, divide each sum by the number of tokens
            summed (``batch_lens + 1``).

    Returns:
        1-D tensor ``(batch,)`` with one (optionally length-normalized) score
        sum per row, in row order.

    Raises:
        ValueError: If any entry of ``batch_lens`` is outside
            ``[0, seq_len - 1]``. Such rows would otherwise have no matching
            position and silently vanish from the output.
    """
    seq_len = transition_scores.shape[-1]
    cumsum = torch.cumsum(transition_scores, -1)

    # Use integer indices rather than a position mask built in the scores'
    # dtype: in low-precision floats (e.g. bfloat16) large positions are not
    # exactly representable, which makes an equality-based mask unreliable.
    last_idx = batch_lens.long().unsqueeze(1)
    if ((last_idx < 0) | (last_idx >= seq_len)).any():
        raise ValueError(
            "batch_lens entries must lie in [0, %d]" % (seq_len - 1)
        )

    value = cumsum.gather(-1, last_idx).squeeze(-1)
    if length_normalize:
        value = value / (batch_lens + 1)
    return value


# now try the semantic entropy
def compute_clustered_entropy(
    clusters,
    norm_log_scores
):
    """Compute the entropy of the distribution over clusters.

    Each cluster's log-probability is the log-sum-exp of the log scores of its
    members; the entropy is ``-sum_c p_c * log(p_c)`` over the distinct
    cluster ids.

    Args:
        clusters: Tensor of cluster ids, one per sample; used to build a
            boolean mask over the first dimension of ``norm_log_scores``.
        norm_log_scores: Tensor of log scores aligned with ``clusters``.
            Presumably normalized so that the exponentiated scores sum to one
            (as the name suggests); this function does not renormalize.

    Returns:
        The entropy as a tensor, or the integer ``0`` when ``clusters`` is
        empty. Clusters with zero probability contribute nothing.
    """
    ent = 0
    for c in clusters.unique():
        clust_mask = clusters == c
        logp = torch.logsumexp(norm_log_scores[clust_mask], dim=0)
        # The estimator is -p * log(p). A zero-probability cluster has
        # logp == -inf, and -inf * exp(-inf) evaluates to NaN; apply the
        # usual convention 0 * log(0) = 0 instead.
        term = torch.where(
            torch.isneginf(logp),
            torch.zeros_like(logp),
            logp * torch.exp(logp),
        )
        ent = ent - term
    return ent