
import torch
import numpy as np

from typing import Optional, Sequence

from ..types import ByteBinary, ByteBinarySample, IntBinarySample, Binary
from ..metadata import Metadata
from . import functional_bytes as Fb
from . import functional_tensor as Ft


def to_tensor(binary: ByteBinary, dtype=torch.uint8) -> torch.Tensor:
    """Convert ``binary`` to a tensor of ``dtype``.

    Tensor inputs are cast with ``Tensor.to(dtype)``, which returns the input
    itself when it already has ``dtype``. Any other input is read as a buffer
    of unsigned bytes (one element per byte) and then cast to ``dtype``.

    Args:
        binary: a tensor, or an object supporting the buffer protocol
            (e.g. ``bytes`` / ``bytearray``).
        dtype: target tensor dtype; defaults to ``torch.uint8``.

    Returns:
        A tensor of ``dtype``. For buffer inputs the tensor owns its memory
        and does not alias ``binary``. An empty buffer yields an empty tensor.
    """
    if isinstance(binary, torch.Tensor):
        return binary.to(dtype)
    # Interpret the buffer as uint8 first so every byte maps to exactly one
    # element, whatever the requested dtype. `np.frombuffer` accepts empty
    # buffers (`torch.frombuffer` raises on them), and the copy detaches the
    # result from the caller's buffer, which may be immutable `bytes`.
    byte_view = np.frombuffer(binary, dtype=np.uint8)
    return torch.from_numpy(byte_view.copy()).to(dtype)

def to_numpy(binary: ByteBinary, dtype=np.uint8) -> np.ndarray:
    """Convert ``binary`` to a NumPy array of ``dtype``.

    Tensors are moved to CPU, detached and exported via ``Tensor.numpy()``.
    Any other input is read as a buffer of unsigned bytes, one element per
    byte. Either way the result comes from ``astype(dtype)``, which returns a
    new array.
    """
    if not isinstance(binary, torch.Tensor):
        as_array = np.frombuffer(binary, dtype=np.uint8)
    else:
        as_array = binary.cpu().detach().numpy()
    return as_array.astype(dtype)


def to_bytes(binary: torch.ByteTensor) -> bytes:
    """Serialize ``binary`` to raw bytes, one byte per element.

    ``bytes`` and ``bytearray`` inputs are returned unchanged (not copied).
    Tensors must be 1-dimensional. A tensor of any dtype other than
    ``torch.uint8`` is first narrowed to ``uint8``, so the result always has
    exactly one byte per element.

    Raises:
        ValueError: if the tensor is not 1-dimensional, or if a non-``uint8``
            tensor holds values outside ``[0, 255]`` (they would wrap around).
    """
    if isinstance(binary, (bytes, bytearray)):
        return binary

    if binary.ndim != 1:
        raise ValueError("`binary` must be a 1-dimensional tensor")

    tensor = binary.detach().cpu()
    if tensor.dtype != torch.uint8:
        # `.numpy().tobytes()` dumps the raw element storage, so an int64
        # tensor, for example, would yield 8 bytes per element. Narrow to
        # uint8 instead, but refuse values that do not fit in a single byte.
        if tensor.numel() > 0 and (tensor.min() < 0 or tensor.max() > 255):
            raise ValueError("`binary` values must be in the range [0, 255]")
        tensor = tensor.to(torch.uint8)

    return tensor.numpy().tobytes()


def apply_mask(binary: Binary, index: Sequence[int], mask_value: int = 256,
               metadata: Optional[Metadata] = None) -> IntBinarySample:
    """Dispatch masking to the tensor or bytes implementation.

    Tensor inputs go to ``functional_tensor.apply_mask``; every other input
    goes to ``functional_bytes.apply_mask``. ``index``, ``mask_value`` and
    ``metadata`` are forwarded unchanged, and the backend's result is returned
    as-is.
    """
    backend = Ft if isinstance(binary, torch.Tensor) else Fb
    return backend.apply_mask(binary, index, mask_value=mask_value, metadata=metadata)


def zero_pe_header(binary: ByteBinary, metadata: Optional[Metadata] = None) -> ByteBinarySample:
    """Normalize ``binary`` with :func:`to_bytes` and delegate to ``Fb.zero_pe_header``.

    The byte-level implementation's result is returned as-is.
    """
    return Fb.zero_pe_header(to_bytes(binary), metadata=metadata)


def remove_pe_header(binary: ByteBinary, metadata: Optional[Metadata] = None) -> ByteBinarySample:
    """Normalize ``binary`` with :func:`to_bytes` and delegate to ``Fb.remove_pe_header``.

    The byte-level implementation's result is returned as-is.
    """
    return Fb.remove_pe_header(to_bytes(binary), metadata=metadata)


def remove_pe_sections(binary: ByteBinary, sec_names: Sequence[str],
                       metadata: Optional[Metadata] = None) -> ByteBinarySample:
    """Normalize ``binary`` with :func:`to_bytes` and delegate to ``Fb.remove_pe_sections``.

    The parameter is annotated ``ByteBinary`` like the sibling wrappers,
    because tensors are accepted here: they are serialized by
    :func:`to_bytes` before delegation.

    Args:
        binary: the sample, as bytes-like data or a 1-D tensor.
        sec_names: section names, forwarded unchanged to the byte-level
            implementation.
        metadata: optional metadata, forwarded unchanged.

    Returns:
        The result of ``Fb.remove_pe_sections``, unchanged.
    """
    binary = to_bytes(binary)
    return Fb.remove_pe_sections(binary, sec_names, metadata=metadata)


def add_insn_addr_ranges(binary: ByteBinary, metadata: Optional[Metadata] = None,
                         write_stdout: bool = True, write_stderr: bool = True) -> ByteBinarySample:
    """Normalize ``binary`` with :func:`to_bytes` and delegate to ``Fb.add_insn_addr_ranges``.

    ``metadata``, ``write_stdout`` and ``write_stderr`` are forwarded
    unchanged, and the byte-level implementation's result is returned as-is.
    """
    raw = to_bytes(binary)
    return Fb.add_insn_addr_ranges(
        raw,
        metadata=metadata,
        write_stdout=write_stdout,
        write_stderr=write_stderr,
    )


def add_exe_section_ranges(binary: ByteBinary, metadata: Optional[Metadata] = None) -> ByteBinarySample:
    """Normalize ``binary`` with :func:`to_bytes` and delegate to ``Fb.add_exe_section_ranges``.

    The byte-level implementation's result is returned as-is.
    """
    return Fb.add_exe_section_ranges(to_bytes(binary), metadata=metadata)


def add_header_size(binary: ByteBinary, metadata: Optional[Metadata] = None) -> ByteBinarySample:
    """Normalize ``binary`` with :func:`to_bytes` and delegate to ``Fb.add_header_size``.

    The byte-level implementation's result is returned as-is.
    """
    return Fb.add_header_size(to_bytes(binary), metadata=metadata)


def trim(binary: ByteBinary, length: int, metadata: Optional[Metadata] = None) -> ByteBinarySample:
    """Dispatch ``trim`` to the tensor or bytes implementation.

    Tensor inputs go to ``functional_tensor.trim``; every other input goes to
    ``functional_bytes.trim``. ``length`` and ``metadata`` are forwarded
    unchanged.
    """
    impl = Ft.trim if isinstance(binary, torch.Tensor) else Fb.trim
    return impl(binary, length, metadata=metadata)


def mask_non_instruction(binary: Binary, metadata: Metadata,
                         mask_value: int = 257) -> IntBinarySample:
    """Mask every position that is not marked as an instruction byte.

    Instruction positions come from ``metadata['insn_addr']``, a 1-D tensor
    that is either dense or sparse COO. In the dense case its nonzero entries
    mark instruction bytes; in the sparse case its stored indices do. Every
    other position in ``range(metadata['insn_addr'].size(0))`` is passed to
    :func:`apply_mask` and replaced with ``mask_value``.

    Args:
        binary: the sample to mask; forwarded to :func:`apply_mask`.
        metadata: must contain the ``'insn_addr'`` tensor; returned as-is.
        mask_value: value written at masked positions (default 257).

    Returns:
        ``(masked, metadata)``, where ``masked`` is the first element of
        :func:`apply_mask`'s result.

    Raises:
        KeyError: if ``metadata`` has no ``'insn_addr'`` entry.
        ValueError: if ``metadata['insn_addr']`` is not 1-dimensional.
    """
    insn_addr_ranges = metadata['insn_addr']
    if insn_addr_ranges.ndim != 1:
        raise ValueError("`insn_addr_ranges` must be a 1-dimensional tensor")

    if insn_addr_ranges.layout == torch.sparse_coo:
        # Positions explicitly stored in the sparse tensor are instructions.
        # `coalesce()` is required before the public `indices()` accessor.
        insn_bytes_idx = insn_addr_ranges.coalesce().indices()[0]
    else:
        insn_bytes_idx = insn_addr_ranges.nonzero().squeeze(1)

    # Vectorized complement of the instruction positions. This replaces a
    # per-byte Python loop, and it always yields an int64 index tensor:
    # `torch.tensor([])` would be float32, which is not a valid index, when
    # nothing needs masking. The index is built on CPU, as before.
    keep = torch.ones(insn_addr_ranges.size(0), dtype=torch.bool)
    keep[insn_bytes_idx.cpu()] = False
    index = keep.nonzero().squeeze(1)

    # NOTE(review): `metadata` is not forwarded to `apply_mask` and is
    # returned as-is; confirm this is intended.
    return apply_mask(binary=binary, index=index, mask_value=mask_value)[0], metadata