from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar, Union

import numpy as np
import torch


# Type of a single metadata entry; `EntryUpdater` is generic over it.
Entry = TypeVar('Entry')


class EntryUpdater(Generic[Entry]):
    """Updater for a metadata entry

    Provides functions for updating an entry when the corresponding binary is transformed by byte insertion,
    deletion, replacement or masking operations.

    This base class returns the metadata entry without updating for all transformation operations. Subclasses
    override the operations they care about and can use `_normalize_pos` to turn a position argument into a
    contiguous `range`.
    """
    def insert(self, entry: Entry, pos: int, value: bytes) -> Optional[Entry]:
        """Update metadata entry after inserting bytes in the binary

        Args:
            entry: Metadata entry before update
            pos: Position in the original binary before which additional bytes were inserted
            value: Inserted bytes

        Return:
            Updated entry. A value of `None` means the metadata entry is to be removed.
        """
        return entry

    def delete(self, entry: Entry, pos: Union[int, range]) -> Optional[Entry]:
        """Update metadata entry after deleting bytes in the binary

        Args:
            entry: Metadata entry before update
            pos: Positions in the original binary where bytes were deleted

        Return:
            Updated entry. A value of `None` means the metadata entry is to be removed.

        Raises:
            ValueError: If `pos` is a range whose step is not 1.
        """
        # The base class ignores the position but still validates it, so invalid arguments (e.g. a range with a
        # step other than 1) are rejected consistently by every updater.
        self._normalize_pos(pos)
        return entry

    def replace(self, entry: Entry, pos: Union[int, range], value: bytes) -> Optional[Entry]:
        """Update metadata entry after replacing bytes in the binary

        Args:
            entry: Metadata entry before update
            pos: Positions in the original binary where bytes were replaced
            value: Values of the replaced bytes referenced in `pos`

        Return:
            Updated entry. A value of `None` means the metadata entry is to be removed.

        Raises:
            ValueError: If `pos` is a range whose step is not 1.
        """
        self._normalize_pos(pos)  # validation only, see `delete`
        return entry

    def mask(self, entry: Entry, pos: Union[int, range]) -> Optional[Entry]:
        """Update metadata entry after masking bytes in the binary

        Args:
            entry: Metadata entry before update
            pos: Positions in the original binary where bytes were masked

        Return:
            Updated entry. A value of `None` means the metadata entry is to be removed.

        Raises:
            ValueError: If `pos` is a range whose step is not 1.
        """
        self._normalize_pos(pos)  # validation only, see `delete`
        return entry

    @staticmethod
    def _normalize_pos(pos: Union[int, range]) -> range:
        """Convert a position argument into a contiguous `range`

        Args:
            pos: A single position (Python int, NumPy integer scalar, or a one-element tensor/ndarray), or a
                `range` of positions with step 1.

        Return:
            `range(pos, pos + 1)` for a single position, or the range itself. Values of any other type are
            returned unchanged.

        Raises:
            ValueError: If `pos` is a range whose step is not 1.
        """
        if isinstance(pos, (torch.Tensor, np.ndarray)):
            # Assume tensor/ndarray with one element; `.item()` fails otherwise. The scalar is cast to a range below.
            pos = pos.item()
        if isinstance(pos, range):
            # A range's step can never be 0, so anything other than 1 means non-contiguous positions.
            if pos.step != 1:
                raise ValueError("`pos` must be a slice with step 1")
            return pos
        if isinstance(pos, (int, np.integer)):
            # NumPy integer scalars (e.g. obtained by indexing an ndarray) are not `int` instances.
            start = int(pos)
            return range(start, start + 1)
        return pos


class Metadata(Dict[str, Any]):
    """Metadata for a binary

    This class is used to store metadata extracted while opening and/or analysing a binary file. It offers a familiar
    interface, since it derives from a Python dictionary. Entries in the metadata correspond to key-value pairs in
    the underlying dictionary. The class deviates from a standard Python dictionary, as it provides functionality for
    updating metadata entries when the corresponding binary is transformed by byte insertion, deletion, replacement
    or masking operations.

    Updaters are registered per entry with `set_updater`; entries without an updater are left unchanged by the
    byte-transformation methods. Each instance owns its own updater mapping.
    """
    # Maps entry keys to their updaters. Assigned per instance (in `__init__` / `__setstate__`) rather than as a
    # class-level default, so registering or removing an updater on one Metadata object does not affect others.
    updaters: Dict[str, 'EntryUpdater']

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Create metadata; accepts the same arguments as `dict`."""
        super().__init__(*args, **kwargs)
        self.updaters = {}

    def set_updater(self, key: str, updater: 'EntryUpdater') -> None:
        """Register the updater for an existing metadata entry

        Args:
            key: Key of the metadata entry
            updater: Updater applied to the entry by the byte-transformation methods

        Raises:
            KeyError: If there is no entry with key `key`.
        """
        if key not in self:
            raise KeyError(key)
        self.updaters[key] = updater

    def _apply(self, update: Callable[['EntryUpdater', Any], Optional[Any]]) -> None:
        """Replace each entry that has an updater by `update(updater, entry)`; remove it if that returns `None`."""
        # Snapshot the keys: entries (and their updaters) may be deleted while looping.
        for key in list(self):
            updater = self.updaters.get(key)
            if updater is None:
                continue
            new_entry = update(updater, self[key])
            if new_entry is None:
                del self[key]
            else:
                self[key] = new_entry

    def insert_bytes(self, pos: int, value: bytes) -> None:
        """Update metadata after inserting bytes in the binary

        Args:
            pos: Position in the original binary before which additional bytes were inserted
            value: Inserted bytes
        """
        self._apply(lambda updater, entry: updater.insert(entry, pos, value))

    def delete_bytes(self, pos: Union[int, range]) -> None:
        """Update metadata after deleting bytes in the binary

        Args:
            pos: Positions in the original binary where bytes were deleted
        """
        self._apply(lambda updater, entry: updater.delete(entry, pos))

    def replace_bytes(self, pos: Union[int, range], value: bytes) -> None:
        """Update metadata after replacing bytes in the binary

        Args:
            pos: Positions in the original binary where bytes were replaced
            value: Values of the replaced bytes referenced in `pos`
        """
        self._apply(lambda updater, entry: updater.replace(entry, pos, value))

    def mask_bytes(self, pos: Union[int, range]) -> None:
        """Update metadata after masking bytes in the binary

        Args:
            pos: Positions in the original binary where bytes were masked
        """
        self._apply(lambda updater, entry: updater.mask(entry, pos))

    def __delitem__(self, key: str) -> None:
        # Remove the entry first, so a missing key raises KeyError without touching the updaters.
        super().__delitem__(key)
        self.updaters.pop(key, None)

    def pop(self, key: str, *default: Any) -> Any:
        """Remove `key` and return its entry (see `dict.pop`), also discarding its updater."""
        value = super().pop(key, *default)
        self.updaters.pop(key, None)
        return value

    def popitem(self) -> Tuple[str, Any]:
        """Remove and return a `(key, entry)` pair (see `dict.popitem`), also discarding its updater."""
        key, value = super().popitem()
        self.updaters.pop(key, None)
        return key, value

    def clear(self) -> None:
        """Remove all entries and all updaters."""
        super().clear()
        self.updaters.clear()

    def __getstate__(self) -> Tuple[Dict[str, 'EntryUpdater'], Dict[str, Any]]:
        return self.updaters, dict(self)

    def __setstate__(self, state: Tuple[Dict[str, 'EntryUpdater'], Dict[str, Any]]) -> None:
        updaters, data = state
        # Copy the mapping: `copy.copy` passes the original's state object through unchanged, and sharing it
        # would let updater changes on the copy leak into the original.
        self.updaters = dict(updaters)
        self.update(data)

    def __reduce__(self) -> Union[str, Tuple[Any, ...]]:
        # Rebuild via `Metadata()` followed by `__setstate__`, so both entries and updaters are restored.
        return Metadata, (), self.__getstate__()

    def to(self, device: Union[str, torch.device]) -> 'Metadata':
        """Move all tensor-valued entries to `device` in place

        Non-tensor entries are left unchanged.

        Args:
            device: Target device, passed to `torch.Tensor.to`

        Return:
            This metadata object, to allow chaining.
        """
        for key, value in list(self.items()):
            if isinstance(value, torch.Tensor):
                self[key] = value.to(device)
        return self