from secml_malware.attack.whitebox.c_discretized_bytes_evasion import CDiscreteBytesEvasion
from secml_malware.models import CClassifierEnd2EndMalware


class CPaddingEvasion(CDiscreteBytesEvasion):
	"""
	Padding attack: optimizes bytes placed right after the real content of a sample.

	Unlike the generic discrete-bytes attack, the positions to perturb are not
	supplied by the caller. They are recomputed on every run as the first
	``how_many`` padding slots of the input (see ``_run``).
	"""

	def __init__(
			self,
			end2end_model: CClassifierEnd2EndMalware,
			how_many: int,
			iterations: int = 100,
			is_debug: bool = False,
			random_init: bool = False,
			threshold: float = 0,
			penalty_regularizer: float = 0
	):
		"""
		Create the padding attack.

		Parameters
		----------
		end2end_model : CClassifierEnd2EndMalware
			the target end-to-end model
		how_many : int
			how many padding bytes to inject; the actual number is clipped so
			that it never runs past the end of the input
		iterations : int, optional, default 100
			the number of iterations of the optimizer
		is_debug : bool, optional, default False
			if True, prints debug information during the optimization
		random_init : bool, optional, default False
			if True, it randomizes the padding locations selected by ``_run``
			before starting the optimization
		threshold : float, optional, default 0
			the detection threshold to bypass. Default is 0
		penalty_regularizer : float, optional, default 0
			the regularization parameter. Default is 0
		"""
		# The positions to perturb depend on each sample's length, so start
		# with an empty list; ``_run`` fills it in before each optimization.
		super().__init__(
			end2end_model,
			index_to_perturb=[],
			iterations=iterations,
			is_debug=is_debug,
			random_init=random_init,
			threshold=threshold,
			penalty_regularizer=penalty_regularizer
		)
		self.how_many_padding_bytes = how_many

	def _run(self, x0, y0, x_init=None):
		"""
		Select the padding positions of ``x0``, then run the parent optimization.

		The first occurrence of the padding value in ``x0`` is taken as the end
		of the real content. Starting there, up to ``how_many_padding_bytes``
		consecutive positions (never past ``x0.size``) are stored in
		``self.indexes_to_perturb``. Then the parent ``_run`` is invoked.

		Parameters
		----------
		x0 : CArray
			the sample to perturb (assumed to be a secml ``CArray``, since
			``find`` and ``size`` are used on it -- TODO confirm)
		y0
			forwarded unchanged to the parent ``_run``
		x_init : optional
			forwarded unchanged to the parent ``_run``

		Returns
		-------
		whatever ``CDiscreteBytesEvasion._run`` returns for the same arguments
		"""
		# 256 is used as the padding marker unless ``invalid_pos`` holds a
		# value other than -1. Presumably 256 is chosen because it lies just
		# outside the 0-255 byte range (verify against the parent class).
		padding_value = 256 if self.invalid_pos == -1 else self.invalid_pos
		padding_positions = x0.find(x0 == padding_value)
		if padding_positions:
			first_pad = padding_positions[0]
			# Clip to the input length so no index falls outside ``x0``.
			stop = min(x0.size, first_pad + self.how_many_padding_bytes)
			self.indexes_to_perturb = list(range(first_pad, stop))
		else:
			# No padding slot is left, so there is nothing to perturb.
			# NOTE(review): an empty index list is still handed to the parent
			# ``_run``; confirm it handles that case gracefully.
			self.indexes_to_perturb = []
		return super()._run(x0, y0, x_init=x_init)
