import numpy as np
import torch
from secml.array import CArray
from secml.settings import SECML_PYTORCH_USE_CUDA

from secml_malware.attack.whitebox import CEnd2EndMalwareEvasion
from secml_malware.models import CClassifierEnd2EndMalware

use_cuda = torch.cuda.is_available() and SECML_PYTORCH_USE_CUDA
use_mps = torch.backends.mps.is_available()


def gradient_search(
		start_byte: int,
		gradient: torch.Tensor,
		embedding_bytes: torch.Tensor,
		invalid_val=np.inf,
		invalid_pos: int = -1
):
	"""
	Given the starting byte, the gradient and the embedding map, it returns one distance
	per embedded byte.

	The search direction is the negative, normalized gradient. For every row of
	``embedding_bytes`` the function projects the offset from the starting embedding onto
	that direction and measures how far the row lies from the projected point. Rows whose
	projection is not strictly positive (i.e. they do not lie ahead along the search
	direction) are marked with ``invalid_val``, as is the row at ``invalid_pos``.

	Parameters
	----------
	start_byte : int
		the starting byte for the search (row index inside embedding_bytes)
	gradient : torch.Tensor
		the gradient, a 1-D tensor with the same size as one embedding row
	embedding_bytes : torch.Tensor
		the embedding matrix with all the byte embedded, one row per byte value
	invalid_val : optional, default np.inf
		the invalid value to use. A float or a single-element tensor. Default np.inf
	invalid_pos : int, optional, default -1
		the position of the padding value.

	Returns
	-------
	torch.Tensor
		a 1-D tensor with ``len(embedding_bytes)`` entries containing the distances, placed
		on the device selected by the module-level ``use_cuda`` / ``use_mps`` flags.
		If the gradient is all zeros, every entry is ``invalid_val``.
	"""
	n_candidates = len(embedding_bytes)
	# Accept both plain numbers and single-element tensors as the invalid marker.
	invalid = float(invalid_val)

	# Device-agnostic zero test: no CPU zeros tensor is built, so it also works for GPU gradients.
	if bool((gradient == 0).all()):
		distance = torch.full((n_candidates,), invalid)
	else:
		start_emb_byte = embedding_bytes[start_byte]
		direction = -gradient / torch.norm(gradient)
		# Projection of every (byte - start) offset onto the search direction, all rows at once.
		projections = torch.mv(embedding_bytes - start_emb_byte, direction)
		# Point on the search line that corresponds to each projection.
		closest_points = start_emb_byte + projections.unsqueeze(1) * direction
		distance = torch.norm(embedding_bytes - closest_points, dim=1)
		# Bytes that are not strictly ahead along the direction are not valid candidates.
		distance = distance.masked_fill(projections <= 0, invalid)
		distance[invalid_pos] = invalid
		# Keep the same dtype as the all-invalid result so both paths agree.
		distance = distance.to(torch.get_default_dtype())

	if use_cuda:
		distance = distance.cuda()
	elif use_mps:
		distance = distance.to(torch.device('mps'))
	return distance


class CDiscreteBytesEvasion(CEnd2EndMalwareEvasion):
	"""
	Evasion attack against an end-to-end malware classifier that rewrites the bytes at the
	given indexes, picking each replacement byte with a gradient-guided search in the
	embedding space.
	"""

	def __init__(
			self,
			end2end_model: CClassifierEnd2EndMalware,
			index_to_perturb: list,
			iterations: int = 100,
			is_debug: bool = False,
			random_init: bool = False,
			threshold: float = 0.5,
			penalty_regularizer: float = 0,
			chunk_hyper_parameter: int = 256
	):
		"""
		Builds the attack; everything except chunk_hyper_parameter is forwarded to the parent.

		Parameters
		----------
		end2end_model : CClassifierEnd2EndMalware
			the end-to-end model under attack
		index_to_perturb : list
			the indexes of the sample that the attack may modify
		iterations : int, optional, default 100
			how many optimizer iterations to run
		is_debug : bool, optional, default False
			when True, debug information is printed during the optimization
		random_init : bool, optional, default False
			when True, the locations in index_to_perturb are randomized before optimizing
		threshold : float, optional, default 0.5
			the detection threshold to bypass
		penalty_regularizer : float, optional, default 0
			the regularization parameter
		chunk_hyper_parameter : int, optional, default 256
			how many bytes are optimized at each round; a falsy value means all of them
		"""
		parent_settings = dict(
			end2end_model=end2end_model,
			indexes_to_perturb=index_to_perturb,
			iterations=iterations,
			is_debug=is_debug,
			random_init=random_init,
			threshold=threshold,
			penalty_regularizer=penalty_regularizer,
		)
		super().__init__(**parent_settings)
		self.chunk_hyper_parameter = chunk_hyper_parameter

	def compute_penalty_term(self, original_x: CArray, adv_x: CArray, par: float) -> torch.Tensor:
		"""
		Builds the penalty as a torch node: the norm of the difference between the embedded
		original sample and the embedded adversarial sample, scaled by ``par``.

		Parameters
		----------
		original_x : CArray
			the unmodified malware sample
		adv_x : CArray
			the adversarial version of the sample
		par : float
			the regularization parameter multiplying the norm

		Returns
		-------
		torch.Tensor
			the torch node holding the penalty
		"""
		embedded_original = self.classifier.embed(original_x.tondarray())
		embedded_adversarial = self.classifier.embed(adv_x.tondarray())
		weight = torch.autograd.Variable(torch.Tensor([par]), requires_grad=True)
		difference = torch.autograd.Variable(embedded_original - embedded_adversarial, requires_grad=True)
		# MPS is checked first, then CUDA; the two checks are independent of each other.
		for is_enabled, device_name in ((use_mps, 'mps'), (use_cuda, 'cuda')):
			if is_enabled:
				target_device = torch.device(device_name)
				weight = weight.to(target_device)
				difference = difference.to(target_device)
		return difference.norm() * weight

	def loss_function_gradient(self, original_x: CArray, adv_x: CArray, penalty_term: torch.Tensor) -> torch.Tensor:
		"""
		Differentiates the model output plus the penalty with respect to the embedded
		adversarial sample.

		Parameters
		----------
		original_x : CArray
			the unmodified malware sample (not read by this implementation)
		adv_x : CArray
			the adversarial malware sample
		penalty_term : torch.Tensor
			the penalty added to the model output

		Returns
		-------
		torch.Tensor
			the gradient w.r.t. the embedding, with dimensions 1 and 2 swapped and the
			first element along dimension 0 selected
		"""

		def on_active_device(tensor: torch.Tensor) -> torch.Tensor:
			# CUDA takes precedence over MPS; without either the tensor is left where it is.
			if use_cuda:
				return tensor.cuda()
			if use_mps:
				return tensor.to(torch.device('mps'))
			return tensor

		embedded_adversarial = self.classifier.embed(adv_x.tondarray())
		penalty_term = on_active_device(penalty_term)
		score = self.classifier.embedding_predict(embedded_adversarial)
		objective = on_active_device(score + penalty_term)
		gradients = torch.autograd.grad(objective, embedded_adversarial)
		return gradients[0].transpose(1, 2)[0]

	def optimization_solver(self, E: torch.Tensor, gradient_f: torch.Tensor, index_to_consider: list,
							x_init: CArray) -> CArray:
		"""
		Runs one round of byte replacement on the sample.

		When chunk_hyper_parameter is truthy, only that many indexes with the largest
		gradient norm are processed; otherwise every index is. All candidate bytes are
		computed first and only then written into the sample.

		Parameters
		----------
		E : torch.Tensor
			the embedding matrix, holding every embedded value
		gradient_f : torch.Tensor
			the gradient of the function w.r.t. the embedding
		index_to_consider : list
			the indexes that may be perturbed
		x_init : CArray
			the sample to manipulate; it is modified in place

		Returns
		-------
		CArray
			the adversarial malware (the same object received as x_init)
		"""
		chunk_size = self.chunk_hyper_parameter
		if not chunk_size:
			selected_indexes = index_to_consider
		else:
			gradient_norms = gradient_f[index_to_consider].norm(dim=1)
			ranking = gradient_norms.argsort(descending=True)[:chunk_size]
			selected_indexes = [index_to_consider[rank] for rank in ranking]
		candidates = [
			self._find_byte_using_gradient(E, gradient_f, position, x_init)
			for position in selected_indexes
		]
		for chosen_byte, best_distance, position in candidates:
			# A minimum equal to the invalid marker means no usable byte was found.
			if torch.equal(best_distance, self._invalid_value):
				continue
			x_init[position] = chosen_byte.item()
		return x_init

	def _find_byte_using_gradient(self, E: torch.Tensor, gradient_f: torch.Tensor, i: int, x_init: CArray) -> (
			int, torch.Tensor, int):
		# Returns (index of the closest byte, its distance, i); the first two are the
		# one-element tensors produced by torch.min with keepdim=True.
		gradient_at_position = gradient_f[i]
		current_byte = x_init[i].tondarray().astype(np.uint16).ravel().item()
		distances = gradient_search(
			current_byte,
			gradient_at_position,
			E,
			invalid_val=self._invalid_value,
			invalid_pos=self.invalid_pos,
		)
		closest = torch.min(distances, dim=0, keepdim=True)
		return closest.indices, closest.values, i

	def infer_step(self, x_init: CArray) -> float:
		"""
		Runs the classifier on the sample and reports the score of the malware class.

		Parameters
		----------
		x_init : CArray
			the sample fed to the forward step

		Returns
		-------
		float
			the malware score, read at index 1 of the decision function output
		"""
		_labels, scores = self.classifier.predict(x_init, return_decision_function=True)
		malware_score = scores[1]
		return malware_score.item()

	def invert_feature_mapping(self, x: CArray, x_adv: CArray) -> CArray:
		"""
		Maps the adversarial sample back to input space; here it is returned untouched.

		Parameters
		----------
		x : CArray
			the original sample (unused)
		x_adv : CArray
			the adversarial sample

		Returns
		-------
		CArray
			x_adv, unchanged
		"""
		return x_adv

	def apply_feature_mapping(self, x: CArray) -> CArray:
		"""
		Extracts the features of a sample; here the sample itself is the feature vector.

		Parameters
		----------
		x : CArray
			the input malware sample

		Returns
		-------
		CArray
			x, unchanged
		"""
		return x
