import struct
import numpy as np
import torch
from secml.settings import SECML_PYTORCH_USE_CUDA

from secml_malware.attack.whitebox.c_discretized_bytes_evasion import CDiscreteBytesEvasion
from secml_malware.models import CClassifierEnd2EndMalware

# Use CUDA only if a device is actually available AND secml's settings enable it;
# when no device is available this is plainly False.
use_cuda = SECML_PYTORCH_USE_CUDA if torch.cuda.is_available() else False


class CHeaderEvasion(CDiscreteBytesEvasion):
	"""Creates the attack that perturbs the header of a Windows PE malware.

	By default only the DOS header bytes between the ``MZ`` magic number (offsets 0-1)
	and the ``e_lfanew`` field (offsets 0x3C-0x3F) are perturbed. With
	``optimize_all_dos=True`` the bytes between the end of ``e_lfanew`` and the offset
	it points to are made editable as well.
	"""

	# First editable offset: bytes 0-1 hold the "MZ" magic number and are never touched.
	_DOS_FIRST_EDITABLE_OFFSET = 0x02
	# Offset of the 4-byte little-endian e_lfanew field (position of the PE header).
	_E_LFANEW_OFFSET = 0x3C
	# First offset after the e_lfanew field.
	_E_LFANEW_END = 0x40

	def __init__(
			self,
			end2end_model: CClassifierEnd2EndMalware,
			index_to_perturb: list = None,
			iterations: int = 100,
			is_debug: bool = False,
			random_init: bool = False,
			optimize_all_dos: bool = False,
			threshold: float = 0,
			penalty_regularizer: float = 0
	):
		"""
		Creates the evasion object

		Parameters
		----------
		end2end_model : CClassifierEnd2EndMalware
			the target end-to-end model
		index_to_perturb : list, optional, default None
			the list of byte offsets to perturb inside the samples. If None, the DOS header
			bytes between the MZ magic number and the e_lfanew field (offsets 2 to 0x3B) are used
		iterations : int, optional, default 100
			the number of iterations of the optimizer
		is_debug : bool, optional, default False
			if True, prints debug information during the optimization
		random_init : bool, optional, default False
			if True, it randomizes the locations set by index_to_perturb before starting the optimization
		optimize_all_dos : bool, optional, default False
			if True, at every run the editable offsets are replaced by the DOS header bytes
			(offsets 2 to 0x3B) plus every byte between the end of the e_lfanew field and the
			PE header it points to; index_to_perturb is then ignored
		threshold : float, optional, default 0
			the detection threshold to bypass
		penalty_regularizer : float, optional, default 0
			the regularization parameter
		"""
		if index_to_perturb is None:
			index_to_perturb = list(range(self._DOS_FIRST_EDITABLE_OFFSET, self._E_LFANEW_OFFSET))
		super(CHeaderEvasion, self).__init__(
			end2end_model,
			index_to_perturb,
			iterations,
			is_debug,
			random_init,
			threshold,
			penalty_regularizer
		)
		self.optimize_all_dos = optimize_all_dos

	def _set_dos_indexes(self, x_init):
		"""Make the whole DOS region editable, if ``optimize_all_dos`` is set.

		Reads the e_lfanew field of ``x_init`` and sets ``self.indexes_to_perturb`` (and
		``self._how_many``) to every DOS header byte after the MZ magic number, excluding
		e_lfanew itself, plus every byte from the end of e_lfanew up to the PE header.
		Does nothing when ``optimize_all_dos`` is False.

		Parameters
		----------
		x_init : CArray
			the sample whose header is parsed
		"""
		if not self.optimize_all_dos:
			return
		# ravel() yields the 4 header bytes whether the slice comes back as (1, 4) or (4,).
		e_lfanew_bytes = x_init[self._E_LFANEW_OFFSET:self._E_LFANEW_END].tondarray().ravel()
		# Signed integers: undoing the shift can never wrap around an unsigned type.
		e_lfanew_bytes = e_lfanew_bytes.astype(np.int64)
		if self.shift_values:
			# When shift_values is set, byte values are stored as value + 1; undo that here.
			e_lfanew_bytes = e_lfanew_bytes - 1
		pe_position = struct.unpack("<I", bytes(e_lfanew_bytes.astype(np.uint8)))[0]
		# A malformed e_lfanew may point past the end of the sample: never emit offsets
		# outside the array. NOTE(review): assumes the base class cannot grow the sample.
		dos_end = min(pe_position, x_init.shape[-1])
		self.indexes_to_perturb = (
			list(range(self._DOS_FIRST_EDITABLE_OFFSET, self._E_LFANEW_OFFSET))
			+ list(range(self._E_LFANEW_END, dos_end))
		)
		self._how_many = len(self.indexes_to_perturb)
		if self.is_debug:
			print(f"PE POSITION: {pe_position}, perturbing {self._how_many}")

	def _run(self, x0, y0, x_init=None):
		"""Refresh the editable offsets for this sample, then run the base attack.

		Parameters
		----------
		x0 : CArray
			the original sample
		y0
			the label of x0, forwarded unchanged to the base class
		x_init : CArray, optional, default None
			the starting point of the optimization, forwarded unchanged to the base class.
			When None, x0 is used to locate the PE header (the e_lfanew bytes are never
			perturbed, so both carry the same value)

		Returns
		-------
		whatever the base class ``_run`` returns
		"""
		self._set_dos_indexes(x0 if x_init is None else x_init)
		return super(CHeaderEvasion, self)._run(x0, y0, x_init)
