import torch

def logmeanexp(x, dim=0, eps=1e-7):
    """Return a numerically stable ``log(mean(exp(x)))`` reduced over ``dim``.

    Args:
        x: Input tensor.
        dim: Dimension to reduce; it is removed from the result.
        eps: Lower bound applied to the shifted mean before ``log``, so the
            result is never below ``shift + log(eps)``.

    Returns:
        Tensor with ``dim`` squeezed out.
    """
    xmax, _ = torch.max(x, dim=dim, keepdim=True)
    # Only shift by finite maxima: an all -inf slice (or one containing +inf)
    # would otherwise evaluate inf - inf = nan. torch.logsumexp does the same.
    shift = torch.where(torch.isfinite(xmax), xmax, torch.zeros_like(xmax))
    mean_exp = torch.mean((x - shift).exp(), dim=dim)
    return mean_exp.clamp(min=eps).log() + shift.squeeze(dim)

def scatter_normmeanexp(src, index, k=0, dim=-1):
    """Group ``src`` by ``index`` along ``dim``, log-mean-exp each group, then softmax.

    For every group id ``g`` in ``0..index.max()``, the group logit is
    ``log(mean(exp(src[... , i, ...])))`` over positions ``i`` where
    ``index == g`` (computed with a max-shift for stability). Groups that no
    position maps to get logit ``-inf``, i.e. probability 0.

    Args:
        src: Scores to aggregate.
        index: Integer group ids; must satisfy ``Tensor.scatter_reduce``'s
            requirements against ``src`` (same number of dimensions).
        k: If > 0, keep only groups whose logit is >= the k-th largest
            along ``dim`` (ties at the threshold are all kept); the rest are
            set to ``-inf`` before normalization.
        dim: Dimension that is scattered, filtered and normalized.

    Returns:
        Probabilities with ``src``'s shape except size ``index.max() + 1``
        along ``dim``, cast back to ``src.dtype``.
    """
    # Compute in at least float32 so low-precision inputs stay accurate and
    # the scatter buffer always matches the scattered values' dtype.
    work_dtype = torch.promote_types(src.dtype, torch.float32)
    values = src.to(work_dtype)

    # Target buffer: one slot per group id along `dim`, allocated directly on
    # the source device.
    target_shape = list(values.shape)
    target_shape[dim] = index.max().item() + 1
    target = torch.zeros(target_shape, dtype=work_dtype, device=src.device)

    # Mean of exp(src - max) per group; include_self=False leaves untouched
    # slots at 0, so empty groups become log(0) = -inf.
    src_max, _ = torch.max(values, dim=dim, keepdim=True)
    target_logits = target.scatter_reduce(
        dim, index, (values - src_max).exp(), reduce="mean", include_self=False
    ).log()
    target_logits = target_logits + src_max

    # Top-k filter, taking the k-th largest value along `dim` for any rank.
    if k > 0:
        k = min(target_logits.size(dim), k)
        topk_logits, _ = target_logits.topk(k=k, dim=dim)
        kth_logit = topk_logits.narrow(dim, k - 1, 1)
        target_logits = target_logits.masked_fill(
            target_logits < kth_logit, -float("inf")
        )

    return torch.softmax(target_logits, dim=dim).to(src.dtype)