import numpy
from secml.array import CArray

from secml_malware.attack.whitebox.c_discretized_bytes_evasion import CDiscreteBytesEvasion
from secml_malware.models import CClassifierEnd2EndMalware, End2EndModel
from secml_malware.utils.pe_operations import shift_section_by, shift_pe_header_by, create_int_list_from_x_adv


class CFormatExploitEvasion(CDiscreteBytesEvasion):
	"""
	Evasion attack that makes room inside the PE file format (by shifting the sections
	and extending the header, see `shift_section_by` / `shift_pe_header_by`) and then
	optimizes only the injected bytes with the discrete-bytes optimizer of the parent class.
	"""

	def __init__(self,
				 end2end_model: CClassifierEnd2EndMalware,
				 preferable_extension_amount: int = 0x200,
				 pe_header_extension: int = 0x200,
				 iterations: int = 100,
				 is_debug: bool = False,
				 random_init: bool = False,
				 threshold: float = 0.5,
				 penalty_regularizer: float = 0,
				 chunk_hyper_parameter: int = None,
				 ):
		"""
		Creates the evasion object

		Parameters
		----------
		end2end_model : CClassifierEnd2EndMalware
			the target end-to-end model
		preferable_extension_amount : int, optional, default 512
			the number of bytes to inject before the first section, modulo file alignment. Default 512.
		pe_header_extension : int, optional, default 512
			the number of bytes to inject as new DOS header, modulo file alignment. Default 512.
		iterations : int, optional, default 100
			the number of iterations of the optimizer
		is_debug : bool, optional, default False
			if True, prints debug information during the optimization
		random_init : bool, optional, default False
			if True, it randomizes the locations set by index_to_perturb before starting the optimization
		threshold : float, optional, default 0.5
			the detection threshold to bypass. Default is 0.5
		penalty_regularizer : float, optional, default 0
			the regularization parameter. Default is 0
		chunk_hyper_parameter : int, optional, default None
			how many bytes to optimize at each round. None is forwarded unchanged to
			CDiscreteBytesEvasion, which decides how to interpret it.
		"""
		# The bytes to perturb depend on each sample's layout, so they are computed
		# per-sample in `_craft_perturbed_c_array` rather than fixed here.
		super(CFormatExploitEvasion, self).__init__(
			end2end_model=end2end_model,
			index_to_perturb=[],
			iterations=iterations,
			is_debug=is_debug,
			random_init=random_init,
			threshold=threshold,
			penalty_regularizer=penalty_regularizer,
			chunk_hyper_parameter=chunk_hyper_parameter
		)
		self.preferable_extension_amount = preferable_extension_amount
		self.pe_header_extension = pe_header_extension

	def _run(self, x0, y0, x_init=None):
		"""
		Run the attack on one sample.

		Parameters
		----------
		x0 : CArray or list
			the sample to perturb, or a ``[sample, file_name]`` pair
		y0
			the label, forwarded unchanged to the parent optimizer
		x_init : ignored
			the starting point is always rebuilt from `x0` by injecting the new bytes

		Returns
		-------
		the result of ``CDiscreteBytesEvasion._run`` on the reshaped sample
		"""
		file_name = None
		# Default to the sample itself; previously x_temp was only bound for list input,
		# so passing a plain array raised NameError.
		x_temp = x0
		if isinstance(x0, list):
			x_temp = x0[0]
			file_name = x0[1]
			print(file_name)
		# Also sets self.indexes_to_perturb for the parent optimizer.
		x_init, _ = self._craft_perturbed_c_array(x_temp, file_name)
		x_temp = x_temp.reshape((1, -1))
		return super(CFormatExploitEvasion, self)._run(x_temp, y0, x_init)

	def _craft_perturbed_c_array(self, x0: CArray, file_name=None):
		"""
		Build the enlarged sample as a model-ready CArray and record which bytes may be perturbed.

		Parameters
		----------
		x0 : CArray
			the original sample
		file_name : str, optional
			forwarded to `_generate_list_adv_example`

		Returns
		-------
		(CArray, list)
			the enlarged sample, encoded by `End2EndModel.list_to_numpy` with the classifier's
			max length / embedding settings, and the list of indexes to perturb
		"""
		x_init, indexes_to_perturb = self._generate_list_adv_example(x0, file_name)
		self.indexes_to_perturb = indexes_to_perturb
		x_init = CArray(
			[
				End2EndModel.list_to_numpy(
					x_init,
					self.classifier.get_input_max_length(),
					self.classifier.get_embedding_value(),
					self.classifier.get_is_shifting_values())
			]
		)
		return x_init, indexes_to_perturb

	def _generate_list_adv_example(self, x0, file_name=None):
		"""
		Shift the sections and the PE header of `x0` to open space for new bytes.

		Parameters
		----------
		x0 : CArray
			the original sample, in the classifier's (possibly shifted) value encoding
		file_name : str, optional
			currently unused

		Returns
		-------
		(list, list)
			the enlarged sample as an int list, and the indexes of the injected bytes
		"""
		x_init = create_int_list_from_x_adv(x0, self.classifier.get_embedding_value(),
											self.classifier.get_is_shifting_values())
		x_init, index_to_perturb_sections = shift_section_by(x_init,
															 preferable_extension_amount=self.preferable_extension_amount)
		x_init, index_to_perturb_pe = shift_pe_header_by(x_init, preferable_extension_amount=self.pe_header_extension)
		# The header shift happens after the section shift, so section indexes are offset
		# by the number of header bytes inserted before them.
		# NOTE(review): assumes shift_pe_header_by inserts exactly len(index_to_perturb_pe) bytes ahead of the sections.
		indexes_to_perturb = index_to_perturb_pe + [i + len(index_to_perturb_pe) for i in
													index_to_perturb_sections]
		return x_init, indexes_to_perturb

	def create_real_sample_from_adv(self, original_file_path: str, x_adv: CArray,
									new_file_path: str = None) -> bytes:
		"""
		Create a real adversarial example

		Parameters
		----------
		original_file_path : str
			the original malware sample
		x_adv : CArray
			the perturbed malware sample, as created by the optimizer
		new_file_path : str, optional, default None
			the path where to save the adversarial malware. Leave None to not save the result to disk

		Returns
		-------
		bytes
			the adversarial malware, as string of bytes
		"""
		with open(original_file_path, 'rb') as f:
			code = bytearray(f.read())
		original_x = CArray([numpy.frombuffer(code, dtype=numpy.uint8).astype(numpy.uint16)])
		# Bring the raw bytes into the classifier's encoding so the same manipulation
		# pipeline used during the attack can be replayed on the full file.
		if self.classifier.get_is_shifting_values():
			original_x += self.classifier.get_is_shifting_values()
		x_init, index_to_perturb = self._generate_list_adv_example(original_x, original_file_path)
		x_init = CArray([x_init]).astype(numpy.uint8)
		# Copy the optimized bytes back, undoing the value shift.
		x_init[0, index_to_perturb] = x_adv[0, index_to_perturb] - self.classifier.get_is_shifting_values()
		x_real = x_init[0, :].tolist()[0]
		x_real_adv = bytes(x_real)
		if new_file_path:
			with open(new_file_path, 'wb') as f:
				f.write(x_real_adv)
		return x_real_adv
