from typing import List, Sequence, Union

import numpy as np
import torch


def _channel_broadcast_shape(ndim: int) -> List[int]:
    """Return a length-``ndim`` shape of ones with ``-1`` at dim -3 (the channel axis).

    Raises:
        IndexError: If ``ndim`` < 3, since there is no dim -3 to put channels on.
    """
    shape = [1] * ndim
    shape[-3] = -1
    return shape


def inv_normalize(
    x: Union[torch.Tensor, np.ndarray],
    mean: Union[float, Sequence[float], torch.Tensor, np.ndarray] = (0.5, 0.5, 0.5),
    std: Union[float, Sequence[float], torch.Tensor, np.ndarray] = (0.5, 0.5, 0.5),
    scale_back: bool = False,
) -> Union[torch.Tensor, np.ndarray]:
    """
    Recover an image tensor or array that was normalized by (x - mean) / std.

    Computes ``clip(x * std + mean, 0, 1)`` with channels on dim -3, so inputs
    such as (B, T, C, H, W), (B, C, H, W) or (C, H, W) are all supported.

    Args:
        x: Normalized input, torch.Tensor or numpy.ndarray, with at least 3
            dimensions.
        mean: Per-channel mean used in normalization: a scalar, or a
            sequence/tensor/array holding either 1 or C values. All of its
            elements are laid out along dim -3 of ``x`` and then broadcast.
        std: Per-channel std used in normalization. Same rules as ``mean``.
        scale_back: If True, multiply the clipped [0, 1] result by 255.
            Otherwise return values in [0, 1].

    Returns:
        Denormalized image of the same type and shape as ``x``, clipped to
        [0, 1] (or [0, 255] when ``scale_back`` is True). For torch inputs the
        result stays on ``x``'s device and dtype.

    Raises:
        IndexError: If ``x`` has fewer than 3 dimensions.
    """
    # All-ones shape except -1 on the channel axis, so mean/std broadcast per channel.
    shape = _channel_broadcast_shape(x.ndim)

    if isinstance(x, torch.Tensor):
        # torch.as_tensor (unlike torch.tensor) does not warn or copy when
        # mean/std are already tensors. reshape (unlike view) also handles
        # non-contiguous inputs.
        mean_t = torch.as_tensor(mean, dtype=x.dtype, device=x.device).reshape(shape)
        std_t = torch.as_tensor(std, dtype=x.dtype, device=x.device).reshape(shape)
        x_denorm = torch.clamp(x * std_t + mean_t, 0, 1)
    else:
        mean_a = np.asarray(mean, dtype=x.dtype).reshape(shape)
        std_a = np.asarray(std, dtype=x.dtype).reshape(shape)
        x_denorm = np.clip(x * std_a + mean_a, 0, 1)

    if scale_back:
        x_denorm = x_denorm * 255.0
    return x_denorm
