import pefile
import torch
import warnings
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from multiprocessing import TimeoutError

from ..types import ByteBinarySample, IntBinarySample

from ..metadata import (Metadata, AddrRanges, AddrRangesUpdater, ExeSectionRanges,
                        ExeSectionRangesUpdater, HeaderSize, HeaderSizeUpdater)
from .util import temppath, run_in_process, clamp
from ..ghidra_interface import start_ghidra, ghidra_running

from typing import Optional, Sequence


def zero_pe_header(binary: bytes, metadata: Optional[Metadata] = None) -> ByteBinarySample:
    """Overwrite the PE header at the start of a binary with zero bytes.

    The header size is taken from ``metadata["header_size"]`` when that entry exists; otherwise the 
    binary is parsed with ``pefile`` and the length of ``pe.header`` is used. The returned binary always 
    has the same length as the input.

    Args:
        binary: Raw bytes of the binary.
        metadata: Optional metadata associated with the binary. When given (and truthy), it is notified 
            of the overwritten byte range through ``replace_bytes``.

    Returns:
        The binary with its header zeroed and the metadata object that was passed in. If ``pefile`` 
        cannot parse the binary, a warning is issued and the binary is returned unchanged.
    """
    try:
        if metadata and "header_size" in metadata:
            header_size = metadata["header_size"]
        else:
            pe = pefile.PE(data=binary)
            header_size = len(pe.header)
        # A recorded header size may exceed the current binary length; without this guard the 
        # concatenation below would make the binary longer than the input.
        header_size = min(int(header_size), len(binary))
        zeroes = bytes(header_size)
        binary = zeroes + binary[header_size:]

        if metadata:
            metadata.replace_bytes(range(0, header_size), zeroes)
    except pefile.PEFormatError as e:
        # Best effort: a binary that is not a valid PE file is passed through untouched
        warnings.warn(repr(e))

    return binary, metadata


def remove_pe_header(binary: bytes, metadata: Optional[Metadata] = None) -> ByteBinarySample:
    """Strip the PE header from the front of a binary.

    The number of bytes to drop comes from ``metadata["header_size"]`` if that entry is present; 
    otherwise the binary is parsed with ``pefile`` and the length of ``pe.header`` is used.

    Args:
        binary: Raw bytes of the binary.
        metadata: Optional metadata associated with the binary. When given (and truthy), it is told 
            about the removed bytes through ``delete_bytes``.

    Returns:
        The binary without its header and the metadata object that was passed in. If ``pefile`` cannot 
        parse the binary, a warning is issued and the binary is returned unchanged.
    """
    try:
        if not metadata or "header_size" not in metadata:
            # No recorded size available: derive it from the PE structure itself
            header_len = len(pefile.PE(data=binary).header)
        else:
            header_len = metadata["header_size"]

        binary = binary[header_len:]
        if metadata:
            metadata.delete_bytes(range(header_len))
    except pefile.PEFormatError as err:
        warnings.warn(repr(err))

    return binary, metadata


def _get_insn_addr_ranges_impl(bin_path: str, proj_dir: str, write_stdout: bool,
                               write_stderr: bool) -> AddrRanges:
    """Disassemble a binary with Ghidra and collect the file-offset range of each instruction.

    Starts Ghidra if it is not running yet, creates a Ghidra project under `proj_dir`, imports and 
    analyzes the binary, then walks every instruction in the program listing.

    Args:
        bin_path: Path of the binary to analyze.
        proj_dir: Directory in which the Ghidra project is created.
        write_stdout: If False, Java's `System.out` is pointed at a null stream for the duration of the 
            call.
        write_stderr: If False, Java's `System.err` is pointed at a null stream for the duration of the 
            call.

    Returns:
        Ranges of the form `(start, end)` in file offsets, one per instruction whose first byte is 
        backed by the imported file. `start` is included and `end` is excluded.
    """
    if not ghidra_running():
        start_ghidra()

    # Java packages can only be imported after Ghidra (and with it the JVM) has been started
    from java.lang import System
    from java.io import OutputStream, PrintStream

    # Remember the current streams so they can be put back once the analysis is over
    old_out = System.out
    old_err = System.err
    if not write_stdout:
        System.setOut(PrintStream(OutputStream.nullOutputStream()))
    if not write_stderr:
        System.setErr(PrintStream(OutputStream.nullOutputStream()))

    try:
        from ghidra.base.project import GhidraProject

        # Absolute path required when calling Ghidra below
        bin_path = os.path.abspath(bin_path)
        proj_dir = os.path.abspath(proj_dir)
        file_name = os.path.basename(bin_path)

        insn_addr_ranges = AddrRanges()

        gp = GhidraProject.createProject(proj_dir, file_name, True)
        try:
            prog = gp.importProgram(Path(bin_path))
            gp.analyze(prog)

            mmap = prog.getMemory()
            for insn in prog.getListing().getInstructions(True):
                source_info = mmap.getAddressSourceInfo(insn.getAddress())
                if source_info is None:
                    # Address is not part of program memory, so there is no file offset to report
                    continue
                file_start = int(source_info.getFileOffset())
                if file_start < 0:
                    # Ghidra reports a negative offset when the byte does not come from the file
                    continue
                # Derive the end from the instruction length: the address right after the 
                # instruction may be unmapped or belong to a different memory block.
                file_end = file_start + int(insn.getLength())
                insn_addr_ranges.append((file_start, file_end))
        finally:
            # Release the project even if import or analysis raised
            gp.close()

        return insn_addr_ranges
    finally:
        if not write_stdout:
            System.setOut(old_out)
        if not write_stderr:
            System.setErr(old_err)


def get_insn_addr_ranges(bin_path: str, write_stdout: bool = True,
                         write_stderr: bool = True) -> AddrRanges:
    """Compute instruction address ranges for a binary, tolerating analysis failures.

    The analysis is delegated to `_get_insn_addr_ranges_impl` through `run_in_process`, using a 
    temporary directory as Ghidra project directory. The environment variable 
    `GHIDRA_ANALYSIS_TIMEOUT` sets the timeout handed to `run_in_process`; when it is unset or 
    negative, `None` is passed instead.

    Args:
        bin_path: Path of the binary to analyze.
        write_stdout: Whether to write messages from Ghidra analysis to standard output.
        write_stderr: Whether to write error messages from Ghidra analysis to standard error.

    Returns:
        The address ranges produced by the analysis, or empty address ranges (plus a warning) if the 
        analysis timed out or raised.
    """
    configured = int(os.environ.get('GHIDRA_ANALYSIS_TIMEOUT', -1))
    timeout = None if configured < 0 else configured

    # Temporary project directory for this binary; removed again when the block exits
    with TemporaryDirectory() as proj_dir:
        impl_args = (bin_path, proj_dir, write_stdout, write_stderr)
        try:
            return run_in_process(_get_insn_addr_ranges_impl, timeout, args=impl_args)
        except TimeoutError:
            warnings.warn(f"Ghidra analysis for {bin_path} timed out after {timeout} seconds")
        except Exception as ex:
            warnings.warn(f"Ghidra analysis failed for {bin_path}\n {type(ex)}: {str(ex)}")

    # Only reached when the analysis failed
    return AddrRanges()


def add_insn_addr_ranges(
    binary: bytes,
    metadata: Optional[Metadata] = None,
    write_stdout: bool = True,
    write_stderr: bool = True
) -> ByteBinarySample:
    """Transformation that records the address ranges of instructions in the metadata.

    Note:
    The binary is disassembled through `get_insn_addr_ranges`, which relies on Ghidra.

    Args:
        binary: A binary.
        metadata: Metadata associated with the binary. A new `Metadata` object is created when none 
            (or a falsy one) is given. If it holds a non-empty 'binary_path' entry, that file is 
            analyzed; otherwise the bytes are made available through `temppath` for the duration of the 
            analysis.
        write_stdout: Whether to write messages from Ghidra analysis to standard output.
        write_stderr: Whether to write error messages from Ghidra analysis to standard error.

    Returns:
        Original binary and the metadata with a new 'insn_addr' entry and its updater registered.
    """
    if not metadata:
        metadata = Metadata()

    def analyze(path):
        # Single place that forwards the output switches to the analysis
        return get_insn_addr_ranges(path, write_stdout=write_stdout, write_stderr=write_stderr)

    known_path = metadata.get('binary_path', None)
    if known_path:
        insn_addr_ranges = analyze(known_path)
    else:
        with temppath(binary) as tmp_path:
            insn_addr_ranges = analyze(tmp_path)

    metadata['insn_addr'] = insn_addr_ranges.to_tensor(len(binary), sparse=True)
    metadata.set_updater('insn_addr', AddrRangesUpdater())

    return binary, metadata


def get_exe_section_ranges(binary: bytes) -> ExeSectionRanges:
    """Get the raw-data ranges of executable sections in a PE file.

    A section counts as executable when bit 0x20000000 (IMAGE_SCN_MEM_EXECUTE) is set in its 
    `Characteristics` field.

    Args:
        binary: binary file to analyze

    Returns:
        Ranges of the form `(start, end)`, one per executable section. Each range includes `start` and 
        excludes `end`. Both values are offsets into `binary`, clamped to `[0, len(binary)]`.
    """
    # https://reverseengineering.stackexchange.com/questions/11311/how-to-recognize-pe-sections-containing-code
    executable_flag = 0x20000000
    binary_size = len(binary)

    pe = pefile.PE(data=binary)
    ranges = ExeSectionRanges()
    for section in pe.sections:
        if not (section.Characteristics & executable_flag):
            continue
        # Guard against malformed PE file: keep both ends inside the binary
        range_start = clamp(section.PointerToRawData, 0, binary_size)
        range_end = clamp(range_start + section.SizeOfRawData, 0, binary_size)
        ranges.append((range_start, range_end))
    return ranges


def add_exe_section_ranges(binary: bytes, metadata: Optional[Metadata] = None) -> ByteBinarySample:
    """Transformation that records the ranges of executable sections in the metadata.

    Args:
        binary: A binary.
        metadata: Metadata associated with the binary. A new `Metadata` object is created when none 
            (or a falsy one) is given.

    Returns:
        Original binary and the metadata with a new 'exe_section' entry and its updater registered. If 
        `pefile` cannot parse the binary, a warning is issued and no entry is added.
    """
    try:
        if not metadata:
            metadata = Metadata()

        section_ranges = get_exe_section_ranges(binary)
        section_tensor = section_ranges.to_tensor(len(binary), sparse=True)

        metadata['exe_section'] = section_tensor
        metadata.set_updater('exe_section', ExeSectionRangesUpdater())
    except pefile.PEFormatError as err:
        warnings.warn(repr(err))

    return binary, metadata


def get_header_size(binary: bytes) -> HeaderSize:
    """Parse a binary with `pefile` and return the length of its header.

    Args:
        binary: binary file to analyze

    Returns:
        Length of `pe.header`, wrapped in `HeaderSize`. Errors raised by `pefile` while parsing are not 
        handled here.
    """
    return HeaderSize(len(pefile.PE(data=binary).header))


def add_header_size(binary: bytes, metadata: Optional[Metadata] = None) -> ByteBinarySample:
    """Transformation that records the PE header size in the metadata.

    Args:
        binary: A binary.
        metadata: Metadata associated with the binary. A new `Metadata` object is created when none 
            (or a falsy one) is given.

    Returns:
        Original binary and the metadata with a new 'header_size' entry and its updater registered. If 
        `pefile` cannot parse the binary, a warning is issued and no entry is added.
    """
    try:
        if not metadata:
            metadata = Metadata()

        metadata['header_size'] = get_header_size(binary)
        metadata.set_updater('header_size', HeaderSizeUpdater())
    except pefile.PEFormatError as err:
        warnings.warn(repr(err))

    return binary, metadata


def apply_mask(binary: bytes, index: Sequence[int], mask_value: int = 256,
               metadata: Optional[Metadata] = None) -> IntBinarySample:
    """Convert a binary to an integer tensor and overwrite selected positions with a mask value.

    Args:
        binary: Raw bytes of the binary. May be empty.
        index: Positions in the binary to mask.
        mask_value: Value written at the masked positions. The default of 256 lies outside the byte 
            range 0-255, so it cannot collide with real byte values.
        metadata: Optional metadata associated with the binary. When given (and truthy), 
            ``mask_bytes`` is called once per masked position.

    Returns:
        A 1-D ``torch.int32`` tensor with one element per byte, and the metadata object that was passed 
        in.
    """
    if len(binary) == 0:
        # torch.frombuffer raises ValueError on a zero-length buffer
        binary = torch.empty(0, dtype=torch.int32)
    else:
        # Widen from uint8 so that values above 255 fit; the conversion also copies the data, 
        # so the read-only bytes object is never written to
        binary = torch.frombuffer(binary, dtype=torch.uint8).int()
    index = torch.as_tensor(index, dtype=torch.long)
    binary[index] = mask_value
    if metadata:
        for i in index:
            metadata.mask_bytes(i)
    return binary, metadata


def trim(binary: bytes, length: int, metadata: Optional[Metadata] = None) -> ByteBinarySample:
    """Cut a binary down to at most `length` bytes.

    Args:
        binary: Raw bytes of the binary.
        length: Maximum number of bytes to keep.
        metadata: Optional metadata associated with the binary. When given (and truthy) and bytes are 
            actually cut off, it is told about the removed tail through ``delete_bytes``.

    Returns:
        The (possibly shortened) binary and the metadata object that was passed in.
    """
    original_size = len(binary)
    if original_size <= length:
        # Nothing to clip
        return binary, metadata

    binary = binary[:length]
    if metadata:
        metadata.delete_bytes(range(length, original_size))
    return binary, metadata
