import lief
import numpy as np

from secml_malware.attack.whitebox.c_fast_gradient_sign_evasion import CFastGradientSignMethodEvasion
from secml_malware.models import CClassifierEnd2EndMalware
from secml_malware.utils.pe_operations import create_int_list_from_x_adv


class CKreukEvasion(CFastGradientSignMethodEvasion):
    """
    Padding (and optionally slack-space) attack by Kreuk et al.
    https://arxiv.org/abs/1802.04528, driven by the FGSM evasion optimizer.
    """

    def __init__(
            self,
            end2end_model: CClassifierEnd2EndMalware,
            how_many_padding_bytes: int,
            epsilon: float,
            iterations: int = 100,
            is_debug: bool = False,
            threshold: float = 0.5,
            p_norm: float = np.inf,
            compute_slack: bool = True,
            store_checkpoints: int = None
    ):
        """
        Create the padding attack by Kreuk et al. https://arxiv.org/abs/1802.04528

        Parameters
        ----------
        end2end_model : CClassifierEnd2EndMalware
            the target end-to-end model
        how_many_padding_bytes : int
            how many padding bytes to perturb, starting at the first padding
            position of each sample (clipped to the length of the sample)
        epsilon : float
            the distortion amount
        iterations : int, optional, default 100
            the number of iterations of the optimizer
        is_debug : bool, optional, default False
            if True, prints debug information during the optimization
        threshold : float, optional, default 0.5
            the detection threshold to bypass
        p_norm : float, optional, default np.inf
            the norm used to compute the attack
        compute_slack : bool, optional, default True
            if True, the padding indexes of each sample are appended to the
            indexes already stored in ``indexes_to_perturb`` (for instance
            slack offsets assigned from ``create_slack_indexes2``); if False,
            only the padding indexes are perturbed
        store_checkpoints : int, optional, default None
            if set, the samples are reconstructed after the specified number
            of iterations
        """
        super(CKreukEvasion, self).__init__(
            end2end_model=end2end_model,
            indexes_to_perturb=[],
            epsilon=epsilon,
            iterations=iterations,
            is_debug=is_debug,
            threshold=threshold,
            penalty_regularizer=0,
            p_norm=p_norm,
            store_checkpoints=store_checkpoints
        )
        self.how_many_padding_bytes = how_many_padding_bytes
        self.compute_slack = compute_slack

    def _run(self, x0, y0, x_init=None):
        """
        Select the byte positions to perturb for ``x0`` and run the attack.

        With ``compute_slack`` enabled, the padding positions of ``x0`` are
        appended to the caller-supplied ``indexes_to_perturb``; otherwise only
        the padding positions are used. The combined list is left in
        ``indexes_to_perturb`` after the call.
        """
        padding = self._create_pading_indexes(x0)
        base = []
        if self.compute_slack:
            base = self.indexes_to_perturb
            # If the attribute still holds the list built by the previous call,
            # start again from that call's base. Otherwise the previous
            # sample's padding would be stacked on top, and those offsets can
            # land inside the real content of the current sample.
            if base is getattr(self, "_last_combined_indexes", None):
                base = getattr(self, "_last_base_indexes", [])
            base = [] if base is None else list(base)
        combined = base + padding
        self._last_base_indexes = base
        self._last_combined_indexes = combined
        self.indexes_to_perturb = combined
        # Deliberately skips CFastGradientSignMethodEvasion._run and dispatches
        # to the next class in the MRO.
        # NOTE(review): presumably the FGSM-specific _run must not be used
        # here; confirm against the parent classes.
        return super(CFastGradientSignMethodEvasion, self)._run(x0, y0, x_init=x_init)

    def _slack_indexes_from_pe(self, liefpe):
        """
        Return the file offsets of the slack bytes of every section of ``liefpe``.

        A section has slack when its raw size exceeds its virtual size; the
        bytes between ``offset + virtual_size`` and ``offset + size`` are
        returned, clipped to the classifier's maximum input length.
        """
        window_input_length = self.classifier.get_input_max_length()
        all_slack_space = []
        for s in liefpe.sections:
            if s.size > s.virtual_size:
                start = min(window_input_length, s.offset + s.virtual_size)
                end = min(window_input_length, s.offset + s.size)
                all_slack_space.extend(range(start, end))
        return all_slack_space

    def create_slack_indexes2(self, file_path):
        """
        Compute the slack-space offsets of the PE file at ``file_path``.

        Best-effort: if LIEF cannot parse the file (or the computation fails),
        a message is printed and an empty list is returned.
        """
        try:
            liefpe = lief.PE.parse(file_path)
            # A None result (parse failure without an exception) raises
            # AttributeError in the helper and ends up in the handler below.
            return self._slack_indexes_from_pe(liefpe)
        except Exception:
            print("Lief couldn't parse the file!")
            return []

    def _create_slack_indexes(self, x0):
        """
        Compute the slack-space offsets of the PE encoded in the sample ``x0``.

        ``x0`` is converted back to a list of ints before parsing. Returns an
        empty list when LIEF cannot parse the result.
        """
        x_bytes = create_int_list_from_x_adv(x0, self.classifier.get_embedding_value(),
                                             self.classifier.get_is_shifting_values())
        try:
            liefpe = lief.PE.parse(x_bytes)
        except Exception:
            return []
        # Some LIEF versions return None instead of raising on failure.
        if liefpe is None:
            return []
        return self._slack_indexes_from_pe(liefpe)

    def _create_pading_indexes(self, x0):
        """
        Return up to ``how_many_padding_bytes`` positions starting at the first
        padding position of ``x0``, or an empty list if ``x0`` has no padding.
        """
        # -1 is mapped to 256, a value outside the 0-255 byte range.
        invalid_value = 256 if self.invalid_pos == -1 else self.invalid_pos
        padding_positions = x0.find(x0 == invalid_value)
        if not padding_positions:
            return []
        first = padding_positions[0]
        return list(range(first, min(x0.size, first + self.how_many_padding_bytes)))
