import struct

import torch
from secml.settings import SECML_PYTORCH_USE_CUDA

from secml_malware.attack.whitebox.c_discretized_bytes_evasion import CDiscreteBytesEvasion
from secml_malware.models import CClassifierEnd2EndMalware
from secml_malware.utils.pe_operations import create_indexes_for_header_fields_manipulations

# Module-level flag: truthy only when torch reports an available CUDA device
# AND the secml setting SECML_PYTORCH_USE_CUDA is truthy.
# NOTE(review): not referenced anywhere in the visible part of this file.
use_cuda = torch.cuda.is_available() and SECML_PYTORCH_USE_CUDA


class CHeaderFieldsEvasion(CDiscreteBytesEvasion):
	"""
	Creates a white-box attack (it extends CDiscreteBytesEvasion, from ``secml_malware.attack.whitebox``)
	that perturbs 18 bytes inside the COFF and Optional Header of an executable.

	The byte positions to perturb are recomputed for every sample at the beginning of ``_run``,
	starting from the PE header offset stored at bytes 60..63 of the sample.
	"""

	def __init__(
			self,
			end2end_model: CClassifierEnd2EndMalware,
			index_to_perturb: list = None,
			iterations: int = 100,
			is_debug: bool = False,
			random_init: bool = False,
			optimize_all_dos: bool = False,
			threshold: float = 0,
			penalty_regularizer: float = 0
	):
		"""
		Creates the evasion object

		Parameters
		----------
		end2end_model : CClassifierEnd2EndMalware
			the target end-to-end model
		index_to_perturb : list, optional, default None
			forwarded as-is to the parent constructor. Note that ``_run`` overwrites
			``self.indexes_to_perturb`` with the header-field indexes computed from the sample.
		iterations : int, optional, default 100
			the number of iterations of the optimizer
		is_debug : bool, optional, default False
			if True, prints debug information during the optimization
		random_init : bool, optional, default False
			if True, it randomizes the locations set by index_to_perturb before starting the optimization
		optimize_all_dos : bool, optional, default False
			stored as ``self.optimize_all_dos``.
			NOTE(review): this class never reads the flag in its own code - confirm whether it is still needed.
		threshold : float, optional, default 0
			the detection threshold to bypass. Default is 0
		penalty_regularizer : float, optional, default 0
			the regularization parameter. Default is 0
		"""

		super().__init__(
			end2end_model,
			index_to_perturb,
			iterations,
			is_debug,
			random_init,
			threshold,
			penalty_regularizer
		)
		self.optimize_all_dos = optimize_all_dos

	def _set_header_fields_indexes(self, x_init):
		"""
		Computes and stores the indexes of the header fields to perturb.

		Parameters
		----------
		x_init : array-like
			the sample whose header is inspected. It is indexed as ``x_init[0, 60:64]`` and must
			support ``astype`` and ``tolist`` (presumably a CArray with one sample per row - TODO confirm).

		Returns
		-------
		the indexes returned by create_indexes_for_header_fields_manipulations,
		also stored in ``self.indexes_to_perturb``
		"""
		# Bytes 60..63 (0x3C) of the DOS header hold the little-endian offset of the PE header.
		pe_offset_bytes = bytes(x_init[0, 60:64].astype(int).tolist()[0])
		(pe_index,) = struct.unpack('<I', pe_offset_bytes)
		self.indexes_to_perturb = create_indexes_for_header_fields_manipulations(pe_index)
		return self.indexes_to_perturb

	def _run(self, x0, y0, x_init=None):
		"""
		Locates the header fields of the sample, then delegates the optimization to the parent class.

		Parameters
		----------
		x0 : array-like
			the sample to perturb
		y0 : the label of x0, passed untouched to the parent ``_run``
		x_init : array-like, optional, default None
			the starting point of the optimization. If None, the header is read from x0.

		Returns
		-------
		whatever the parent ``_run`` returns
		"""
		# x_init is optional: without this fallback a call with the default value
		# would crash while subscripting None.
		header_source = x0 if x_init is None else x_init
		self._set_header_fields_indexes(header_source)
		# The parent still receives the original x_init, so its own handling of None is untouched.
		return super()._run(x0, y0, x_init)
