import torch
from torch import Tensor


def logsumexp(
    x: Tensor,
    dim: int,
    weights: Tensor,
    epsilon: float,
    keepdim: bool = False,
) -> Tensor:
    """
    Returns the log of the weighted sum of exponentials of x along a given dimension.

    Computes ``log(sum_i (w_i + epsilon) * exp(x_i))`` over ``dim``. It folds
    ``log(w_i + epsilon)`` into the exponent and delegates to ``torch.logsumexp``,
    which performs the numerically stable reduction.

    Arguments:
        x (Tensor): Input tensor.
        dim (int): Dimension to reduce.
        weights (Tensor): Non-negative weights for the elements of x. Must be
            expandable to the shape of x (via ``expand_as``).
        epsilon (float): Small value added to every weight to avoid log(0). It is
            added to all weights, so non-zero weights are also shifted slightly.
            With ``epsilon=0`` a zero weight contributes a ``-inf`` term, which
            drops that element from the sum.
        keepdim (bool, optional): Whether the output tensor has dim retained or not. Default: False.

    Returns:
        Tensor: ``x`` reduced along ``dim`` (``dim`` kept with size 1 if ``keepdim``).

    Raises:
        ValueError: If any entry of ``weights`` is negative.
        RuntimeError: If ``weights`` cannot be expanded to the shape of ``x``
            (raised by ``expand_as``).
    """
    # Evaluating a tensor in a Python `if` reads its value back to the host.
    if torch.any(weights < 0):
        raise ValueError("Weights must be non-negative.")
    # Take the log on the (possibly smaller) weights tensor, then broadcast the
    # result as a view. Expanding first would materialize `weights + epsilon` and
    # its log at the full size of x. expand_as still pins the shape to x's, so
    # weights can never enlarge the result through broadcasting.
    log_w = torch.log(weights + epsilon).expand_as(x)
    return torch.logsumexp(x + log_w, dim=dim, keepdim=keepdim)


def interpolate_tensor(t: Tensor, h: int, w: int, mode="bilinear"):
    """
    Resize a [B, C, H0, W0] tensor so that its last two dimensions become (h, w).

    If the tensor already has height h and width w, it is returned as-is
    (the same object, not a copy). Otherwise it is resampled with
    ``torch.nn.functional.interpolate`` using the given ``mode``, with
    anti-aliasing disabled.
    """
    # Width is checked before height; the `and` short-circuits on the width test.
    needs_resize = not (t.shape[-1] == w and t.shape[-2] == h)
    if needs_resize:
        t = torch.nn.functional.interpolate(t, size=(h, w), mode=mode, antialias=False)
    return t
