from typing import Union, Tuple

import torch


def maybe_broadcast(
        value: Union[float, int, torch.Tensor],
        target_shape: Tuple[int, ...],
        device: torch.device,
        dtype: torch.dtype,
) -> torch.Tensor:
    """
    Return a tensor of shape `target_shape` on (device, dtype).

    - If `value` is a scalar (a Python number, or a tensor with exactly one
      element, including 0-dim tensors), broadcast it to `target_shape`. The
      result is always a fresh, contiguous, writeable tensor, whether the
      scalar came in as a Python number or as a tensor.
    - If `value` is a tensor with shape exactly `target_shape`, convert it to
      (device, dtype) and return it contiguous. Tensors with a zero stride
      (e.g. expanded views) are copied so the result is writeable. Otherwise
      no copy is made when `value` is already contiguous on the requested
      device/dtype, so the result may alias `value`.
    - Otherwise, raise.

    Args:
        value: Scalar or tensor to broadcast / convert.
        target_shape: Shape of the returned tensor.
        device: Device of the returned tensor.
        dtype: Dtype of the returned tensor.

    Returns:
        A tensor with shape `target_shape` on `device` with `dtype`.

    Raises:
        ValueError: If `value` is a multi-element tensor whose shape is not
            exactly `target_shape`.
    """
    if not isinstance(value, torch.Tensor):
        return torch.full(target_shape, value, device=device, dtype=dtype)

    v = value.to(device=device, dtype=dtype)

    # numel() == 1 also covers every 0-dim tensor.
    if v.numel() == 1:
        # expand() alone would give a stride-0 view that may alias `value` and
        # rejects in-place writes. Clone it into its own contiguous storage so
        # the result matches the Python-scalar (torch.full) path.
        return v.reshape(()).expand(target_shape).clone(memory_format=torch.contiguous_format)

    if tuple(v.shape) != tuple(target_shape):
        raise ValueError(f"Tensor must have shape {target_shape}, got {tuple(v.shape)}")

    # A zero stride means several indices share one storage element (e.g. an
    # expanded view); copy so the returned tensor is writeable.
    if 0 in v.stride():
        v = v.clone()
    return v.contiguous()


def broadcast_bound(bound: Union[float, torch.Tensor], reference_tensor: torch.Tensor) -> torch.Tensor:
    """
    Shape `bound` so it broadcasts against `reference_tensor` along its last dimension.

    - A Python number, or a tensor with exactly one element, is broadcast to
      shape ``reference_tensor.shape[:-1] + (1,)``. That is one value per
      position of the leading dimensions, shared across the last dimension
      (the vocab dimension, per the original intent of this helper).
    - Any other tensor whose leading dimensions (all but its last) already
      equal ``reference_tensor.shape[:-1]`` keeps its shape. Otherwise it is
      reshaped to ``reference_tensor.shape[:-1] + (1,)``.

    Tensor bounds are first moved to the device and dtype of
    `reference_tensor`, so the result can be combined with it directly.

    Args:
        bound: Scalar or tensor bound.
        reference_tensor: Tensor whose leading dimensions the bound must match.

    Returns:
        The bound as a tensor ready to broadcast against `reference_tensor`.

    Raises:
        RuntimeError: From ``torch.Tensor.reshape`` if a multi-element bound
            has a number of elements incompatible with the target shape.
    """
    if isinstance(bound, torch.Tensor):
        # Align tensor bounds with the reference so both branches below yield
        # the same device/dtype the Python-scalar branch does.
        bound = bound.to(device=reference_tensor.device, dtype=reference_tensor.dtype)

    if not isinstance(bound, torch.Tensor) or bound.numel() == 1:
        bound_shape = reference_tensor.shape[:-1] + (1,)  # (..., 1)
        if isinstance(bound, torch.Tensor):
            # Collapse to 0-dim so extra size-1 dims (e.g. shape (1, 1, 1))
            # do not add rank to the result when broadcasting below.
            bound = bound.reshape(())
        bound = bound * torch.ones(bound_shape, device=reference_tensor.device, dtype=reference_tensor.dtype)
    else:
        expected_shape = reference_tensor.shape[:-1]
        if bound.shape[:-1] != expected_shape:
            # reshape (unlike view) also handles non-contiguous bounds.
            bound = bound.reshape(expected_shape + (1,))
    return bound
