import torch
import warnings

from typing import Tuple, Union, List, Optional
from .metadata import EntryUpdater
from ..utils import vrange


class AddrRanges(List[Tuple[int, int]]):
    """A list of ``(start, end)`` address ranges, one per instruction."""

    def to_tensor(self, num_bytes: int, sparse: bool = True) -> torch.Tensor:
        """Generate a tensor representation of the address ranges for a binary.

        The ranges stored in this list are expected to be disjoint and of the
        form ``(start, end)``: each range includes ``start`` and excludes
        ``end``. Addresses are integer byte offsets into the binary.

        Args:
            num_bytes: Number of bytes in the binary file (length of the output).
            sparse: Whether the output should be a sparse COO tensor (``True``)
                or a dense tensor (``False``).

        Returns:
            A 1d tensor of length ``num_bytes`` in which each entry (address)
            holds an integer label in the set ``{0, 1, ..., len(self)}``. An
            address that does not correspond to an instruction is labeled 0;
            otherwise it is labeled with the unique instruction id, i.e.
            addresses in the ``k``-th range (0-based) are labeled ``k + 1``.
        """
        # torch.tensor([]) has shape (0,); the reshape gives the empty case the
        # same (num_ranges, 2) layout as the non-empty case.
        addr_ranges = torch.tensor(self, dtype=torch.long).reshape(-1, 2)
        starts, ends = addr_ranges[:, 0], addr_ranges[:, 1]
        num_insn = addr_ranges.size(0)
        insn_ids = torch.arange(1, num_insn + 1)
        # Length of each range. This stays 1d for 0, 1 or many ranges, unlike
        # squeezing the (num_ranges, 1) result of torch.diff.
        repeats = ends - starts
        # One label per covered byte: range k contributes (end - start) copies of k + 1.
        v = torch.repeat_interleave(insn_ids, repeats)
        # NOTE(review): assumes vrange concatenates arange(start, end) for each
        # range in order, so that `i` lines up element-wise with `v`.
        i = vrange(starts, ends)[None, :]
        if sparse:
            return torch.sparse_coo_tensor(i, v, (num_bytes,))
        out = torch.zeros(num_bytes, dtype=v.dtype)
        out[i[0]] = v
        return out


class AddrRangesUpdater(EntryUpdater[torch.IntTensor]):
    """Keep a per-byte address-range label tensor in sync with byte edits.

    The entry is a 1d tensor indexed by byte offset, either dense or sparse COO
    (presumably as produced by ``AddrRanges.to_tensor``; nonzero values mark
    bytes that belong to an address range).
    """

    # Human-readable name used in warning messages.
    meta_name = "Address ranges"

    def __init__(self, warn_replace: bool = False, warn_insert: bool = False, warn_delete: bool = False) -> None:
        """
        Args:
            warn_replace: Whether to emit a warning if bytes are replaced within any of the address ranges.
            warn_insert: Whether to emit a warning if bytes are inserted.
            warn_delete: Whether to emit a warning if bytes within any of the address ranges are deleted.
        """
        self.warn_replace = warn_replace
        self.warn_insert = warn_insert
        self.warn_delete = warn_delete

    def insert(self, entry: torch.IntTensor, pos: int, value: bytes) -> Optional[torch.IntTensor]:
        """Account for ``value`` being inserted at byte offset ``pos``.

        Inserted bytes are labeled 0, and every label at or after ``pos`` is
        shifted right by ``len(value)``.

        Args:
            entry: 1d label tensor (dense or sparse COO).
            pos: Byte offset at which the insertion happens.
            value: The inserted bytes; only their count is used.

        Returns:
            The updated tensor, or ``entry`` itself if ``value`` is empty.

        Raises:
            ValueError: If ``entry`` is not 1d.
        """
        if entry.ndim != 1:
            raise ValueError("`entry` must be a 1d tensor")

        if not value:
            return entry

        if self.warn_insert:
            warnings.warn(f"Bytes were inserted. {self.meta_name} may no longer be valid.", RuntimeWarning)

        size = len(value)
        if entry.layout == torch.sparse_coo:
            indices, values = entry._indices(), entry._values()
            new_indices = torch.where(indices >= pos, indices + size, indices)
            new_size = (entry.size()[0] + size,) + entry.size()[1:]
            return torch.sparse_coo_tensor(new_indices, values, new_size)

        # Match the entry's dtype/device so torch.cat neither promotes the
        # dtype nor fails on a device mismatch.
        fill = torch.zeros(size, dtype=entry.dtype, device=entry.device)
        return torch.cat((entry[:pos], fill, entry[pos:]))

    def delete(self, entry: torch.IntTensor, pos: Union[int, range]) -> Optional[torch.IntTensor]:
        """Account for the bytes at ``pos`` being deleted.

        Labels inside the deleted range are dropped, and labels after it are
        shifted left by the range length.

        Args:
            entry: 1d label tensor (dense or sparse COO).
            pos: Offset or range of deleted bytes (normalized via ``_normalize_pos``).

        Returns:
            The updated tensor, or ``entry`` itself if the range is empty.

        Raises:
            ValueError: If ``entry`` is not 1d.
        """
        pos = self._normalize_pos(pos)

        if entry.ndim != 1:
            raise ValueError("`entry` must be a 1d tensor")

        if pos.stop <= pos.start:
            # Empty range: nothing to delete.
            return entry

        if entry.layout == torch.sparse_coo:
            size = pos.stop - pos.start
            indices, values = entry._indices(), entry._values()
            mask_keep = (indices[0] >= pos.stop) | (indices[0] < pos.start)
            # Warn when at least one stored (labeled) byte falls in the deleted
            # range, mirroring the dense branch's nonzero check.
            if self.warn_delete and not torch.all(mask_keep):
                warnings.warn(f"Bytes were deleted. {self.meta_name} may no longer be valid.", RuntimeWarning)
            # Boolean-mask indexing returns a copy, so the in-place shift below
            # does not touch `entry`.
            new_indices = indices[:, mask_keep]
            new_values = values[mask_keep]
            new_indices[0] = torch.where(new_indices[0] >= pos.stop, new_indices[0] - size, new_indices[0])
            new_size = (entry.size()[0] - size,) + entry.size()[1:]
            return torch.sparse_coo_tensor(new_indices, new_values, new_size)

        if self.warn_delete and torch.count_nonzero(entry[pos.start:pos.stop]) > 0:
            warnings.warn(f"Bytes were deleted. {self.meta_name} may no longer be valid.", RuntimeWarning)
        return torch.cat((entry[:pos.start], entry[pos.stop:]))

    def replace(self, entry: torch.IntTensor, pos: Union[int, range], value: bytes) -> Optional[torch.IntTensor]:
        """Account for the bytes at ``pos`` being overwritten with ``value``.

        Replacement does not change any offsets, so ``entry`` is returned as is;
        optionally a warning is emitted when labeled bytes were overwritten.

        Args:
            entry: 1d label tensor (dense or sparse COO).
            pos: Offset or range of replaced bytes (normalized via ``_normalize_pos``).
            value: Replacement bytes; must have the same length as ``pos``.

        Returns:
            ``entry`` unchanged.

        Raises:
            ValueError: If ``len(value)`` differs from the length of ``pos``,
                or if ``entry`` is not 1d.
        """
        pos = self._normalize_pos(pos)

        pos_len = pos.stop - pos.start
        if len(value) != pos_len:
            raise ValueError(f"`pos` to replace covers {pos_len} bytes, but `value` contains {len(value)} bytes")

        if entry.ndim != 1:
            raise ValueError("`entry` must be a 1d tensor")

        if pos.stop > pos.start and self.warn_replace:
            # Range is non-empty: count labeled bytes inside it.
            if entry.layout == torch.sparse_coo:
                indices = entry._indices()[0]
                num_included = torch.sum((indices >= pos.start) & (indices < pos.stop))
            else:
                num_included = torch.sum(entry[pos.start:pos.stop] != 0)

            if num_included != 0:
                warnings.warn(f"Bytes were replaced within address ranges. {self.meta_name} may no longer be valid.", RuntimeWarning)
        return entry