import numpy as np
from secml.array import CArray

from secml_malware.attack.blackbox.c_blackbox_problem import CBlackBoxProblem
from secml_malware.attack.blackbox.c_wrapper_phi import CWrapperPhi
from secml_malware.utils.pe_operations import shift_section_by, shift_pe_header_by, create_int_list_from_x_adv


class CBlackBoxFormatExploitEvasionProblem(CBlackBoxProblem):
	"""
	Black-box evasion problem whose perturbable bytes are created by format-level shifts.

	The input sample is enlarged by shifting section content (``shift_section_by``) and then the
	PE header region (``shift_pe_header_by``). Only the byte positions produced by those two shifts
	are exposed to the optimizer; every other byte of the sample is left as crafted.
	"""

	def __init__(
			self,
			model_wrapper: CWrapperPhi,
			population_size: int,
			preferable_extension_amount: int = 0x200,
			pe_header_extension: int = 0x200,
			iterations: int = 100,
			is_debug: bool = False,
			penalty_regularizer: float = 0,
			invalid_value: int = 256
	):
		"""
		Create the problem.

		Parameters
		----------
		model_wrapper : CWrapperPhi
			wrapper of the target model, forwarded to CBlackBoxProblem
		population_size : int
			population size, forwarded to CBlackBoxProblem
		preferable_extension_amount : int, optional, default 0x200
			amount requested from ``shift_section_by`` (the helper may adjust it -- TODO confirm)
		pe_header_extension : int, optional, default 0x200
			amount requested from ``shift_pe_header_by``
		iterations : int, optional, default 100
			forwarded to CBlackBoxProblem
		is_debug : bool, optional, default False
			forwarded to CBlackBoxProblem
		penalty_regularizer : float, optional, default 0
			forwarded to CBlackBoxProblem
		invalid_value : int, optional, default 256
			value handed to ``create_int_list_from_x_adv``; presumably the padding marker that is
			not a real byte (256 lies outside 0..255) -- verify against that helper
		"""
		# The latent size passed here is only an initial estimate: init_starting_point replaces it
		# with the number of indexes actually produced for the concrete sample.
		super().__init__(model_wrapper,
						 latent_space_size=preferable_extension_amount + pe_header_extension,
						 iterations=iterations,
						 population_size=population_size, is_debug=is_debug,
						 penalty_regularizer=penalty_regularizer)
		self.preferable_extension_amount = preferable_extension_amount
		self.pe_header_extension = pe_header_extension
		self.invalid_value = invalid_value
		# Filled by init_starting_point; while empty, apply_feasible_manipulations perturbs nothing.
		self.indexes_to_perturb = []

	def init_starting_point(self, x: CArray) -> CArray:
		"""
		Initialize the problem, by setting the starting point.

		Computes the byte positions that the format shifts create for ``x``, stores them in
		``self.indexes_to_perturb`` and resizes the latent space to match, then delegates to the
		parent implementation.

		Parameters
		----------
		x : CArray
			the initial point

		Returns
		-------
		CArray
			whatever CBlackBoxProblem.init_starting_point returns for ``x``
		"""
		_, indexes_to_perturb = self._craft_perturbed_c_array(x)
		self.indexes_to_perturb = indexes_to_perturb
		self.latent_space_size = len(indexes_to_perturb)
		return super().init_starting_point(x)

	def _craft_perturbed_c_array(self, x0):
		"""
		Apply the section shift and then the PE-header shift to ``x0``.

		Parameters
		----------
		x0 : CArray
			the input sample

		Returns
		-------
		tuple(list, list)
			the shifted sample as an integer list, and the list of positions that may be perturbed
			(header positions first, then section positions)
		"""
		x_init = create_int_list_from_x_adv(x0, self.invalid_value, False)
		x_init, index_to_perturb_sections = shift_section_by(
			x_init, preferable_extension_amount=self.preferable_extension_amount)
		x_init, index_to_perturb_pe = shift_pe_header_by(
			x_init, preferable_extension_amount=self.pe_header_extension)
		# Section indexes were computed before the header shift; they are offset by the number of
		# header positions, presumably because the header extension inserts exactly that many
		# bytes ahead of the section content -- confirm against shift_pe_header_by.
		offset = len(index_to_perturb_pe)
		indexes_to_perturb = index_to_perturb_pe + [i + offset for i in index_to_perturb_sections]
		return x_init, indexes_to_perturb

	def apply_feasible_manipulations(self, t, x: CArray) -> CArray:
		"""
		Apply the format exploit practical manipulation on the input sample

		Must be called after ``init_starting_point``; before that ``self.indexes_to_perturb`` is
		empty and the shifted sample is returned without any byte written into it.

		Parameters
		----------
		t : CArray
			the vector of manipulations in [0,1], one entry per index in ``self.indexes_to_perturb``
		x : CArray
			the input space sample to perturb

		Returns
		-------
		CArray:
			the adversarial malware, reshaped as a single row
		"""
		# `np.int` was only a deprecated alias of the builtin `int` and was removed in NumPy 1.24
		# (AttributeError); the builtin performs the identical truncating conversion.
		byte_values = (t * 255).astype(int)
		x_adv, _ = self._craft_perturbed_c_array(x)
		for i, index in enumerate(self.indexes_to_perturb):
			x_adv[index] = byte_values[i]
		x_adv = CArray(x_adv)
		return x_adv.reshape((1, x_adv.shape[-1]))


class CBlackBoxContentShiftingEvasionProblem(CBlackBoxFormatExploitEvasionProblem):
	"""
	Format-exploit problem restricted to section-content shifting.

	``bytes_to_inject`` is requested from the section shift, while the PE-header extension is
	fixed to zero, so only positions created by shifting section content are optimized.
	"""

	def __init__(
			self,
			model_wrapper: CWrapperPhi,
			population_size: int,
			bytes_to_inject: int = 0x200,
			iterations: int = 100,
			is_debug: bool = False,
			penalty_regularizer: float = 0,
			invalid_value: int = 256
	):
		super().__init__(
			model_wrapper,
			population_size,
			preferable_extension_amount=bytes_to_inject,
			pe_header_extension=0,
			iterations=iterations,
			is_debug=is_debug,
			penalty_regularizer=penalty_regularizer,
			invalid_value=invalid_value,
		)


class CBlackBoxContentDOSExtensionProblem(CBlackBoxFormatExploitEvasionProblem):
	"""
	Format-exploit problem restricted to extending the header region.

	``bytes_to_inject`` is requested from the PE-header shift, while the section shift is fixed to
	zero, so only positions created by ``shift_pe_header_by`` are optimized.
	"""

	def __init__(
			self,
			model_wrapper: CWrapperPhi,
			population_size: int,
			bytes_to_inject: int = 0x200,
			iterations: int = 100,
			is_debug: bool = False,
			penalty_regularizer: float = 0,
			invalid_value: int = 256
	):
		super().__init__(
			model_wrapper,
			population_size,
			preferable_extension_amount=0,
			pe_header_extension=bytes_to_inject,
			iterations=iterations,
			is_debug=is_debug,
			penalty_regularizer=penalty_regularizer,
			invalid_value=invalid_value,
		)
