from secml_malware.attack.whitebox import CFormatExploitEvasion
from secml_malware.models import CClassifierEnd2EndMalware
from secml_malware.utils.pe_operations import create_int_list_from_x_adv, shift_pe_header_by, shift_pe_header_by_for_DOS_extend


class CExtendDOSEvasion(CFormatExploitEvasion):
	"""
	DOS header extension attack.

	The input sample is converted to an integer byte list and its DOS header
	region is enlarged through ``shift_pe_header_by_for_DOS_extend``; the
	indexes returned by that helper are the locations the attack perturbs.
	"""
	def __init__(self,
			end2end_model : CClassifierEnd2EndMalware,
			pe_header_extension : int = 0x200,
			iterations : int = 100,
			is_debug: bool = False,
			random_init: bool = False,
			threshold: float = 0.5,
			penalty_regularizer: float = 0,
			chunk_hyper_parameter : int = None,
	):
		"""
		Create the DOS header extension attack

		Parameters
		----------
		end2end_model : CClassifierEnd2EndMalware
			the target end-to-end model
		pe_header_extension : int, optional, default 512 (0x200)
			how many bytes to inject, modulo the file alignment
		iterations : int, optional, default 100
			the number of iterations of the optimizer
		is_debug : bool, optional, default False
			if True, prints debug information during the optimization
		random_init : bool, optional, default False
			if True, it randomizes the locations set by index_to_perturb before starting the optimization
		threshold : float, optional, default 0.5
			the detection threshold to bypass
		penalty_regularizer : float, optional, default 0
			the regularization parameter
		chunk_hyper_parameter : int or None, optional, default None
			how many bytes to optimize at each round; forwarded unchanged to
			the parent attack, which decides how None is handled
		"""
		super().__init__(
			end2end_model=end2end_model,
			pe_header_extension=pe_header_extension,
			# The parent's own extension amount is disabled; this attack instead
			# forwards self.pe_header_extension to the DOS-extend helper in
			# _generate_list_adv_example.
			preferable_extension_amount=0,
			iterations=iterations,
			is_debug=is_debug,
			random_init=random_init,
			threshold=threshold,
			penalty_regularizer=penalty_regularizer,
			chunk_hyper_parameter=chunk_hyper_parameter
		)
		# NOTE(review): presumably an upper bound (4 KiB) on the header size
		# consumed by the parent attack — confirm in CFormatExploitEvasion.
		self.max_header_size = 0x1000

	def _generate_list_adv_example(self, x0, file_name):
		"""
		Build the initial adversarial byte list and the indexes to perturb.

		Parameters
		----------
		x0 : object
			the original sample, as accepted by ``create_int_list_from_x_adv``
		file_name : str
			path of the original PE file, forwarded to
			``shift_pe_header_by_for_DOS_extend``

		Returns
		-------
		tuple
			``(x_init, index_to_perturb_pe)`` as returned by
			``shift_pe_header_by_for_DOS_extend``
		"""
		# Undo the model's embedding/shifting so we operate on raw byte values.
		x_init = create_int_list_from_x_adv(
			x0,
			self.classifier.get_embedding_value(),
			self.classifier.get_is_shifting_values(),
		)
		x_init, index_to_perturb_pe = shift_pe_header_by_for_DOS_extend(
			x_init,
			file_name,
			preferable_extension_amount=self.pe_header_extension,
		)
		return x_init, index_to_perturb_pe

	def run(self, x0, y0, x_init=None):
		"""
		Run the DOS header extension attack.

		Parameters
		----------
		x0 : object
			the original sample
		y0 : object
			the label of the original sample
		x_init : object, optional, default None
			optional starting point, forwarded to the parent ``_run``

		Returns
		-------
		object
			whatever the parent attack's ``_run`` returns
		"""
		# Only announce in debug mode, matching the documented is_debug contract.
		# getattr guards against the parent storing the flag under another name.
		if getattr(self, "is_debug", False):
			print("Running from extend_dos_evasion!")
		return super()._run(x0, y0, x_init)

