from typing import Union
import torch

from typing import Optional, Sequence
from ..metadata import Metadata
from ..types import ByteBinarySample, IntBinary, IntBinarySample

def _check_dimension(binary: torch.ByteTensor) -> None:
    if binary.ndim != 1:
        raise ValueError(f"expected 1-dimensional tensor, but got {binary.ndim}-dimensional tensor")


def apply_mask(binary: Union[torch.ByteTensor, IntBinary], index: Sequence[int], mask_value: int = 256,
               metadata: Optional[Metadata] = None) -> IntBinarySample:
    """Overwrite the elements of *binary* at *index* with *mask_value*.

    A ``uint8`` input is widened to ``int32`` first so that values outside
    the byte range (such as the default 256) can be stored.

    Args:
        binary: 1-dimensional tensor, either ``uint8`` or an integer dtype.
        index: positions to mask.
        mask_value: value written at every position in *index*.
        metadata: optional metadata; when given, ``mask_bytes`` is called
            once per entry of *index*, with a plain Python ``int``.

    Returns:
        ``(masked, metadata)``. *masked* is a new tensor; the caller's
        *binary* is not modified. *metadata* is the same object that was
        passed in.

    Raises:
        ValueError: if *binary* is not 1-dimensional.
    """
    _check_dimension(binary)
    # Compare dtype rather than isinstance(..., torch.ByteTensor): the legacy
    # type check only matches CPU tensors, so a CUDA uint8 tensor would not
    # be widened and mask_value=256 would not fit.
    if binary.dtype == torch.uint8:
        binary = binary.int()  # dtype conversion already produces a copy
    else:
        # Copy so the in-place assignment below never mutates the caller's tensor.
        binary = binary.clone()
    positions = torch.as_tensor(index, dtype=torch.long, device=binary.device)
    binary[positions] = mask_value
    if metadata is not None:
        # tolist() yields Python ints rather than 0-dim tensors.
        for i in positions.tolist():
            metadata.mask_bytes(i)
    return binary, metadata


def trim(binary: torch.ByteTensor, length: int, metadata: Optional[Metadata] = None) -> ByteBinarySample:
    """Truncate *binary* so that it holds at most *length* elements.

    Args:
        binary: 1-dimensional tensor to clip.
        length: maximum number of elements to keep; must be non-negative.
        metadata: optional metadata; when elements are dropped,
            ``delete_bytes`` is called with the range of removed positions
            ``range(length, original_size)``.

    Returns:
        ``(trimmed, metadata)``. *trimmed* is *binary* unchanged when it
        already fits; otherwise it is the slice ``binary[:length]``.

    Raises:
        ValueError: if *binary* is not 1-dimensional or *length* is negative.
    """
    _check_dimension(binary)
    if length < 0:
        # A negative length would make binary[:length] cut from the end instead.
        raise ValueError(f"length must be non-negative, but got {length}")
    original_size = binary.size(0)
    # Clip oversized binary: act only when there are elements past `length`.
    if original_size > length:
        binary = binary[:length]
        if metadata is not None:
            metadata.delete_bytes(range(length, original_size))
    return binary, metadata