import torch
from torch import Tensor


def map2list(x: Tensor) -> Tensor:
    """
    Converts a tensor of shape (B, C, H, W) or (B, H, W) to (B, N, C) or (B, N),
    where N = H * W.

    The two spatial dimensions are flattened in row-major order, so spatial
    position (h, w) ends up at index n = h * W + w.

    Args:
        x (Tensor): Input tensor of shape (B, C, H, W) or (B, H, W). It does not
            need to be contiguous, and zero-size dimensions are allowed.
    Returns:
        Tensor: Reshaped tensor of shape (B, N, C) or (B, N), where N = H * W.
            The (B, N, C) result is always contiguous.
    Raises:
        ValueError: If ``x`` does not have 3 or 4 dimensions.
    """
    if x.dim() not in (3, 4):
        raise ValueError(
            f"Expected input to have 3 or 4 dimensions, but got {len(x.shape)}."
        )
    # Use flatten() instead of view(..., -1). It accepts non-contiguous inputs
    # and copies only when a view is impossible. It also handles zero-size
    # tensors, where an inferred -1 dimension is ambiguous and view() raises.
    if x.dim() == 3:
        return x.flatten(1)  # (B, H, W) -> (B, N)
    x = x.flatten(2)  # (B, C, H, W) -> (B, C, N)
    return torch.swapaxes(x, 1, 2).contiguous()  # (B, C, N) -> (B, N, C)


def list2map(x: Tensor, size_h: int, size_w: int) -> Tensor:
    """
    Converts a tensor of shape (B, N, C) or (B, N) to (B, C, H, W) or (B, H, W).

    Index n along the N axis is mapped to spatial position
    (n // size_w, n % size_w), i.e. row-major order.

    Args:
        x (Tensor): Input tensor with shape (B, N, C) or (B, N), where N = H * W.
        size_h (int): Height (H) of the output.
        size_w (int): Width (W) of the output.
    Returns:
        Tensor: Reshaped tensor of shape (B, C, H, W) or (B, H, W). The result is
            a view sharing storage with ``x``. In the (B, C, H, W) case it is
            generally not contiguous, so call ``.contiguous()`` if required.
    Raises:
        ValueError: If ``x`` does not have 2 or 3 dimensions, or if
            ``x.shape[1] != size_h * size_w``.
    """
    # Check the rank before reading x.shape[1]. On a 0-D or 1-D tensor that
    # index would raise IndexError instead of the documented ValueError.
    if len(x.shape) not in (2, 3):
        raise ValueError(
            f"Expected input to have 2 or 3 dimensions but got {len(x.shape)}."
        )
    if x.shape[1] != size_h * size_w:
        raise ValueError(f"Expected N = H * W, but got N = {x.shape[1]}.")
    if len(x.shape) == 2:
        return x.view(x.shape[0], size_h, size_w)
    # (B, N, C) -> (B, C, N). Splitting the last dim N into (H, W) only
    # subdivides one dimension, so view() succeeds despite the swap.
    x = torch.swapaxes(x, 1, 2)
    return x.view(x.shape[0], x.shape[1], size_h, size_w)
