import struct
from copy import deepcopy

import numpy as np
from secml.array import CArray

from secml_malware.attack.blackbox.c_blackbox_problem import CBlackBoxProblem
from secml_malware.attack.blackbox.c_wrapper_phi import CWrapperPhi
from secml_malware.utils.pe_operations import create_indexes_for_header_fields_manipulations


class CBlackBoxHeaderFieldsEvasionProblem(CBlackBoxProblem):
	"""
	Creates a black-box attack that perturbs bytes inside the COFF and Optional Header of an executable.

	The positions of the bytes to perturb are computed from the PE header offset of
	the sample by `create_indexes_for_header_fields_manipulations` (see
	`init_starting_point`), and the latent space size is set to the number of those
	positions, one gene per perturbed byte.
	"""

	def __init__(
			self,
			model_wrapper: CWrapperPhi,
			population_size: int,
			iterations: int = 100,
			is_debug: bool = False,
			penalty_regularizer: float = 0,
			invalid_value: int = 256
	):
		"""
		Creates the attack.

		Parameters
		----------
		model_wrapper : CWrapperPhi
			the target model, wrapped inside a CWrapperPhi
		population_size : int
			the population size generated at each round by the genetic algorithm
		iterations : int, optional, default 100
			the total number of iterations, default 100
		is_debug : bool, optional, default False
			if True, it prints messages while optimizing. Default is False
		penalty_regularizer : float, optional, default 0
			the penalty regularizer for the file size constraint. Default is 0
		invalid_value : int, optional, default 256
			specifies which is the invalid value used as separator. Default is 256.
			It is only stored here; the methods of this class do not read it
			(presumably consumed elsewhere — verify).
		"""
		# NOTE: latent_space_size=26 is only a placeholder; the real size is
		# recomputed from the header field indexes in init_starting_point.
		super(CBlackBoxHeaderFieldsEvasionProblem, self).__init__(
			model_wrapper,
			latent_space_size=26,
			iterations=iterations,
			population_size=population_size,
			is_debug=is_debug,
			penalty_regularizer=penalty_regularizer,
		)
		self.invalid_value = invalid_value
		# Filled by init_starting_point; required by apply_feasible_manipulations.
		self.indexes_to_perturb = None

	def init_starting_point(self, x: CArray, file_name=None) -> CArray:
		"""
		Initialize the problem, by setting the starting point.

		Reads the little-endian 32-bit PE header offset stored at bytes 60-63 of
		the sample (the DOS header ``e_lfanew`` field). It computes the byte
		positions of the header fields to perturb from that offset, and resizes
		the latent space to match.

		Parameters
		----------
		x : CArray
			the initial point; presumably a 1 x N row of byte values
		file_name : str, optional
			forwarded unchanged to the base class implementation

		Returns
		-------
		CArray
			the starting point as returned by the base class init_starting_point
		"""
		# Cast each element to int: bytes() rejects float values, and the sample
		# may be stored with a floating-point dtype.
		header_offset_bytes = bytes(int(value) for value in x[0, 60:64].tolist()[0])
		pe_index = struct.unpack('<I', header_offset_bytes)[0]
		self.indexes_to_perturb = create_indexes_for_header_fields_manipulations(pe_index)
		# One gene per perturbed byte; overrides the constructor placeholder.
		self.latent_space_size = len(self.indexes_to_perturb)
		return super(CBlackBoxHeaderFieldsEvasionProblem, self).init_starting_point(x, file_name)

	def apply_feasible_manipulations(self, t: np.ndarray, x: CArray, file_name=None) -> CArray:
		"""
		Apply the header fields practical manipulation.

		Each gene of `t` is scaled to a byte value and written at the matching
		position in `self.indexes_to_perturb`, on a copy of `x`.

		Parameters
		----------
		t : numpy array
			the vector of manipulations in [0,1], one entry per index to perturb
		x : CArray
			the input space sample to perturb (not modified)
		file_name : optional
			unused by this implementation

		Returns
		-------
		CArray:
			the adversarial malware

		Raises
		------
		RuntimeError
			if init_starting_point has not been called yet
		"""
		if self.indexes_to_perturb is None:
			raise RuntimeError(
				"init_starting_point must be called before apply_feasible_manipulations"
			)
		# Map [0, 1] genes onto [0, 255] and truncate to int. Clipping is a no-op
		# for valid genes, but keeps out-of-range candidates from writing
		# values that are not bytes.
		byte_values = np.clip(np.asarray(t, dtype=float) * 255, 0, 255).astype(int)
		x_adv = deepcopy(x)
		x_adv[0, self.indexes_to_perturb] = CArray(byte_values)
		return CArray(x_adv)
