# Transforms applied to binary samples in data-loading pipelines
from typing import Any, Generic, Optional, Sequence, Tuple, TypeVar, overload

import torch

from ..types import ByteBinary, IntBinarySample, ByteBinarySample, BinarySample, Binary, IntBinary
from ..metadata import Metadata
from . import functional as F

# Type variables for a transform's input (IT) and output (OT) types.
IT = TypeVar('IT')
OT = TypeVar('OT')


class Transform(Generic[IT, OT]):
    """Base class for callable transforms mapping an ``IT`` input to an ``OT`` output.

    Subclasses must override ``__call__``.
    """
    def __call__(self, input: IT) -> OT:
        # Fail loudly instead of silently returning ``None`` when a subclass
        # does not implement the transformation.
        raise NotImplementedError(f"{type(self).__name__} must implement __call__")


class Compose(Transform[Any, Any]):
    """Chain several transforms, feeding each one's output into the next.

    Args:
        transforms (sequence of ``Transform`` objects): transforms applied in order.
    """
    def __init__(self, transforms: Sequence[Transform[Any, Any]]):
        self.transforms = transforms

    def __call__(self, input: Any) -> Any:
        result = input
        for transform in self.transforms:
            result = transform(result)
        return result

    def __repr__(self):
        # One indented line per child transform, closing paren on its own line.
        body = "".join(f"\n    {transform}" for transform in self.transforms)
        return f"{type(self).__name__}({body}\n)"

    def __getitem__(self, idx):
        return self.transforms[idx]


class ZeroPEHeader:
    """Overwrite the header bytes of a PE binary with zeros.

    Note:
        A binary that cannot be interpreted as a PE file is left unaffected.
    """
    def __call__(self, input: ByteBinarySample) -> ByteBinarySample:
        raw, meta = input
        return F.zero_pe_header(raw, metadata=meta)

    def __repr__(self):
        return f"{type(self).__name__}()"


class RemovePEHeader:
    """Strip the header bytes from a PE binary.

    Note:
        A binary that cannot be interpreted as a PE file is left unaffected.
    """
    def __call__(self, input: ByteBinarySample) -> ByteBinarySample:
        raw, meta = input
        return F.remove_pe_header(raw, metadata=meta)

    def __repr__(self):
        return f"{type(self).__name__}()"


class RemovePESections:
    """Strip the bytes of the named sections from a PE binary.

    Note:
        A binary that cannot be interpreted as a PE file is left unaffected.

    Args:
        sec_names: Names of the sections to remove.
    """
    def __init__(self, sec_names: Sequence[str]) -> None:
        self.sec_names = sec_names

    def __call__(self, input: ByteBinarySample) -> ByteBinarySample:
        raw, meta = input
        return F.remove_pe_sections(raw, self.sec_names, metadata=meta)

    def __repr__(self):
        return f"{type(self).__name__}(sec_names={self.sec_names})"


class AddInsnAddrRanges:
    """Record in the binary's metadata the address ranges holding machine instructions.

    Args:
        write_stdout: Whether messages from the Ghidra analysis go to standard output.
        write_stderr: Whether error messages from the Ghidra analysis go to standard error.
    """
    def __init__(self, write_stdout: bool = True, write_stderr: bool = True) -> None:
        self.write_stdout = write_stdout
        self.write_stderr = write_stderr

    def __call__(self, input: ByteBinarySample) -> ByteBinarySample:
        raw, meta = input
        return F.add_insn_addr_ranges(
            raw,
            metadata=meta,
            write_stdout=self.write_stdout,
            write_stderr=self.write_stderr,
        )

    def __repr__(self):
        name = type(self).__name__
        return f"{name}(write_stdout={self.write_stdout}, write_stderr={self.write_stderr})"


class AddExeSectionRanges:
    """Record in the binary's metadata the address ranges of its executable sections.

    Note:
        A section counts as executable when it carries the `IMAGE_SCN_MEM_EXECUTE` flag.
        See https://reverseengineering.stackexchange.com/questions/11311/how-to-recognize-pe-sections-containing-code
    """
    def __call__(self, input: ByteBinarySample) -> ByteBinarySample:
        raw, meta = input
        return F.add_exe_section_ranges(raw, metadata=meta)

    def __repr__(self):
        return f"{type(self).__name__}()"


class AddHeaderSize:
    """Record the size of the PE header in the binary's metadata."""
    def __call__(self, input: ByteBinarySample) -> ByteBinarySample:
        raw, meta = input
        return F.add_header_size(raw, metadata=meta)

    def __repr__(self):
        return f"{type(self).__name__}()"


class DropMetadata:
    """Drop metadata from a sample.

    Args:
        keys: Metadata keys to remove. When ``None`` (the default), the whole
            metadata object is discarded and only the binary is returned. When
            given, each listed key is popped from the sample's metadata (missing
            keys are ignored) and the original ``(binary, metadata)`` tuple is
            returned. A single string is treated as one key.

    Inputs that are not tuples (bare binaries) are returned unchanged.
    """
    def __init__(self, keys=None) -> None:
        self.keys = keys

    @overload
    def __call__(self, input: BinarySample) -> Binary:
        ...
    @overload
    def __call__(self, input: Binary) -> Binary:
        ...
    def __call__(self, input):
        if not isinstance(input, tuple):
            # A bare binary has no metadata attached; pass it through.
            return input
        binary, metadata = input
        if self.keys is None:
            return binary
        # A lone string would otherwise be iterated character by character.
        keys = [self.keys] if isinstance(self.keys, str) else self.keys
        for key in keys:
            # Mutates the sample's metadata object in place.
            metadata.pop(key, None)
        return input

    def __repr__(self):
        # ``name=value`` format, matching the other transforms' reprs.
        return self.__class__.__name__ + f"(keys={self.keys})"


class ToTensor:
    """Convert a binary, alone or paired with its metadata, to a ``torch.Tensor``.

    Args:
        dtype: dtype requested from ``F.to_tensor`` for the output tensor.
    """
    def __init__(self, dtype=torch.uint8) -> None:
        self.dtype = dtype

    @overload
    def __call__(self, input: Binary) -> torch.Tensor:
        ...
    @overload
    def __call__(self, input: BinarySample) -> Tuple[torch.Tensor, Metadata]:
        ...
    def __call__(self, input):
        if not isinstance(input, tuple):
            return F.to_tensor(input, dtype=self.dtype)
        raw, meta = input
        tensor = F.to_tensor(raw, dtype=self.dtype)
        return tensor, meta

    def __repr__(self):
        return f"{type(self).__name__}(dtype={self.dtype})"


class ToBytes:
    """Convert a binary, alone or paired with its metadata, to ``bytes``.

    A bare binary yields ``bytes``; a ``(binary, metadata)`` tuple yields
    ``(bytes, metadata)`` with the metadata passed through untouched.
    """
    @overload
    def __call__(self, input: ByteBinary) -> bytes:
        ...
    @overload
    def __call__(self, input: ByteBinarySample) -> Tuple[bytes, Metadata]:
        ...
    def __call__(self, input):
        # Left unannotated so the implementation accepts both overload
        # signatures above (a ``ByteBinary -> bytes`` annotation excludes the
        # tuple overload).
        if isinstance(input, tuple):
            binary, metadata = input
            return F.to_bytes(binary), metadata
        return F.to_bytes(input)

    def __repr__(self):
        return self.__class__.__name__ + "()"


class Trim(Transform[ByteBinarySample, ByteBinarySample]):
    """Trim a binary to a given length if it is longer than that.

    Args:
        length: Maximum length of the transformed binary.
    """
    def __init__(self, length: int) -> None:
        self.length = length

    @overload
    def __call__(self, input: Binary) -> Binary:
        ...
    @overload
    def __call__(self, input: BinarySample) -> BinarySample:
        ...
    def __call__(self, input):
        if not isinstance(input, tuple):
            # Only the first element of F.trim's result (presumably the
            # trimmed binary) is returned for a bare binary.
            return F.trim(input, self.length, metadata=None)[0]
        raw, meta = input
        return F.trim(raw, self.length, metadata=meta)

    def __repr__(self):
        return f"{type(self).__name__}(length={self.length})"

class MaskNonInstruction:
    """Masks the non-instruction part of a binary file.

    Args:
        mask_value: Integer value used to represent a masked byte. If ``None``,
            no masking is applied and the raw bytes are only converted to an
            ``int32`` tensor.
    """
    def __init__(self, mask_value: Optional[int] = 257) -> None:
        self.mask_value = mask_value

    def __call__(self, input: BinarySample) -> IntBinarySample:
        binary, metadata = input
        if self.mask_value is not None:
            return F.mask_non_instruction(binary=binary, metadata=metadata, mask_value=self.mask_value)
        return self._to_int_tensor(binary), metadata

    @staticmethod
    def _to_int_tensor(binary) -> torch.Tensor:
        """Convert raw bytes (or an existing tensor) to an ``int32`` tensor without masking."""
        if isinstance(binary, torch.Tensor):
            return binary.int()
        if len(binary) == 0:
            # torch.frombuffer rejects zero-length buffers.
            return torch.empty(0, dtype=torch.int32)
        # Copy into a writable bytearray: torch.frombuffer warns on read-only
        # buffers such as ``bytes``. ``.int()`` makes a fresh int32 copy, so
        # the result shares no memory with the caller's object.
        return torch.frombuffer(bytearray(binary), dtype=torch.uint8).int()

    def __repr__(self) -> str:
        return self.__class__.__name__ + f"(mask_value={self.mask_value})"


class ShiftByConstant:
    """Add a constant offset to every byte value of a binary.

    The binary is first converted with ``F.to_tensor(..., dtype=torch.int32)``,
    then the shift is added elementwise.

    Args:
        shift: Integer value added to each byte.
    """
    def __init__(self, shift: int) -> None:
        self.shift = shift

    @overload
    def __call__(self, input: Binary) -> IntBinary:
        ...
    @overload
    def __call__(self, input: BinarySample) -> IntBinarySample:
        ...
    def __call__(self, input):
        if not isinstance(input, tuple):
            return F.to_tensor(input, dtype=torch.int32) + self.shift
        raw, meta = input
        shifted = F.to_tensor(raw, dtype=torch.int32) + self.shift
        return shifted, meta

    def __repr__(self):
        return f"{type(self).__name__}(shift={self.shift})"