import torch
import warnings

from typing import Union, Optional
from .metadata import EntryUpdater


class HeaderSize(int):
    """An ``int`` holding a header size, exportable as a torch tensor."""

    def to_tensor(self) -> torch.Tensor:
        """Return the header size as a 0-dim 64-bit integer tensor."""
        size = int(self)
        # torch maps the Python type ``int`` to ``torch.int64``; spell it out.
        return torch.tensor(size, dtype=torch.int64)


class HeaderSizeUpdater(EntryUpdater[int]):
    """Keep a header-size entry in sync with byte-level edits.

    The entry is treated as the number of leading bytes that belong to the
    header: positions strictly below ``entry`` are inside the header.
    """

    def __init__(self, warn_replace: bool = False, warn_insert: bool = False, warn_delete: bool = False) -> None:
        """
        Args:
            warn_replace: Whether to emit a warning if bytes are replaced within the header.
            warn_insert: Whether to emit a warning if bytes are inserted into the header.
            warn_delete: Whether to emit a warning if bytes are deleted from the header.
        """
        self.warn_replace = warn_replace
        self.warn_insert = warn_insert
        self.warn_delete = warn_delete

    def insert(self, entry: int, pos: int, value: bytes) -> Optional[int]:
        """Grow the header size when bytes are inserted inside the header.

        Args:
            entry: Current header size (a plain int or a 0-dim array-like).
            pos: Position at which ``value`` is inserted.
            value: The inserted bytes.

        Returns:
            The header size, increased by ``len(value)`` when ``pos < entry``.

        Raises:
            ValueError: If ``entry`` exposes an ``ndim`` attribute that is not 0.
        """
        # Plain ints have no `ndim`; treat them as scalars instead of failing
        # with AttributeError. Array-like entries are still validated.
        if getattr(entry, "ndim", 0) != 0:
            raise ValueError("`metadata` must be a scalar")

        if pos < entry:
            if self.warn_insert:
                warnings.warn("Bytes were inserted into header. Header size may no longer be valid.", RuntimeWarning)
            entry += len(value)
        return entry

    def delete(self, entry: int, pos: Union[int, range]) -> Optional[int]:
        """Shrink the header size by the number of deleted header bytes.

        Args:
            entry: Current header size.
            pos: Position or range of positions being deleted.

        Returns:
            The header size minus the count of deleted positions below ``entry``.
        """
        pos = self._normalize_pos(pos)

        # Only a non-empty range starting inside the header removes header bytes.
        if pos.stop > pos.start and pos.start < entry:
            if self.warn_delete:
                warnings.warn("Bytes were deleted from header. Header may no longer be valid.", RuntimeWarning)
            # Bytes at or beyond `entry` are not part of the header; clip the range.
            n_del = min(pos.stop, entry) - pos.start
            entry -= n_del
        return entry

    def replace(self, entry: int, pos: Union[int, range], value: bytes) -> Optional[int]:
        """Return the header size unchanged; replacing bytes does not alter it.

        Args:
            entry: Current header size.
            pos: Position or range of positions being replaced.
            value: The replacement bytes (unused here).

        Returns:
            ``entry`` unchanged.
        """
        if self.warn_replace:
            # Normalize like `delete` does so that a `range` position is
            # supported instead of failing on `range < int`.
            pos = self._normalize_pos(pos)
            if pos.stop > pos.start and pos.start < entry:
                warnings.warn("Bytes were replaced in header. Header may no longer be valid.", RuntimeWarning)
        return entry
