import torch
import numpy as np
from secml.array import CArray
from secml.settings import SECML_PYTORCH_USE_CUDA
from secml_malware.attack.whitebox import CEnd2EndMalwareEvasion
from secml_malware.models import CClassifierEnd2EndMalware

# Module-level device flags, evaluated once at import time.
# CUDA is used only when a CUDA device is available AND the secml setting allows it.
use_cuda = torch.cuda.is_available() and SECML_PYTORCH_USE_CUDA
# Apple Metal (MPS) backend is used whenever torch reports it as available.
use_mps = torch.backends.mps.is_available()

class CFastGradientSignMethodEvasion(CEnd2EndMalwareEvasion):
	"""Creates the basic attack that implements the Fast Gradient Sign Method for the Windows malware domain.
	The original attack has been proposed by Goodfellow et al. (https://arxiv.org/abs/1412.6572)
	"""

	def __init__(
		self,
		end2end_model: CClassifierEnd2EndMalware,
		indexes_to_perturb: list,
		epsilon: float,
		iterations: int = 100,
		is_debug: bool = False,
		random_init: bool = False,
		threshold: float = 0.5,
		penalty_regularizer: float = 0,
		p_norm: float = np.inf,
		store_checkpoints: int = None
	):
		"""Build the attack.

		Every argument except ``epsilon`` and ``p_norm`` is forwarded unchanged
		to ``CEnd2EndMalwareEvasion.__init__``.

		Parameters
		----------
		end2end_model : CClassifierEnd2EndMalware
			The classifier under attack.
		indexes_to_perturb : list
			Positions of the input that the attack is allowed to modify.
		epsilon : float
			Multiplier applied to the (normalized or signed) gradient at each step.
		iterations : int
			Forwarded to the base class.
		is_debug : bool
			Forwarded to the base class.
		random_init : bool
			Forwarded to the base class.
		threshold : float
			Forwarded to the base class.
		penalty_regularizer : float
			Forwarded to the base class.
		p_norm : float
			Norm used by the step: ``2`` or ``np.inf`` (default). Any other
			value makes the solver raise ``NotImplementedError``.
		store_checkpoints : int, optional
			Forwarded to the base class.
		"""
		super().__init__(
			end2end_model=end2end_model,
			indexes_to_perturb=indexes_to_perturb,
			iterations=iterations,
			is_debug=is_debug,
			random_init=random_init,
			threshold=threshold,
			penalty_regularizer=penalty_regularizer,
			store_checkpoints=store_checkpoints
		)
		self.epsilon = epsilon
		self.p_norm = p_norm

	def compute_penalty_term(self, original_x: CArray, adv_x: CArray, par: float):
		"""Return a constant zero penalty, placed on the active accelerator.

		The arguments are accepted for interface compatibility but are not
		used: this attack applies no penalty.

		Returns
		-------
		torch.Tensor
			One-element integer tensor holding ``0``.
		"""
		# torch.autograd.Variable is a deprecated no-op wrapper; a plain tensor is equivalent.
		penalty_term = torch.tensor([0])
		if use_cuda:
			penalty_term = penalty_term.cuda()
		if use_mps:
			penalty_term = penalty_term.to(torch.device('mps'))
		return penalty_term

	def loss_function_gradient(self, original_x: CArray, adv_x: CArray, penalty_term: torch.Tensor):
		"""Compute the gradient of the malware-class BCE loss w.r.t. ``adv_x``.

		Parameters
		----------
		original_x : CArray
			Unused here.
		adv_x
			Current adversarial point. Despite the annotation, it is passed
			directly to ``torch.autograd.grad`` and must therefore be a tensor
			that takes part in the autograd graph.
		penalty_term : torch.Tensor
			Unused here.

		Returns
		-------
		torch.Tensor
			The gradient with dimensions 1 and 2 swapped, first batch element only.
		"""
		y = self.classifier.embedding_predict(adv_x)
		# ones_like keeps the target on the same device and dtype as the prediction,
		# so no manual device handling is required.
		malware_class = torch.ones_like(y)
		loss = torch.nn.functional.binary_cross_entropy(y, malware_class)
		g = torch.autograd.grad(loss, adv_x)[0]
		g = torch.transpose(g, 1, 2)[0]
		return g

	def optimization_solver(self, E, gradient_f, index_to_consider, x_init):
		"""Apply one FGSM step to the perturbable positions of ``x_init``.

		``E`` is accepted for interface compatibility and is not used.
		``x_init`` is modified in place and also returned.
		"""
		gradient_result = self._internal_fsgm_solver(gradient_f).transpose(0, 1)
		x_init[0, :, index_to_consider] = x_init[0, :, index_to_consider] + gradient_result[:, index_to_consider]
		return x_init

	def _internal_fsgm_solver(self, gradient_f):
		"""Turn a raw gradient into the step to add, according to ``self.p_norm``.

		For ``p_norm == 2`` the step is ``epsilon * gradient / ||gradient||``;
		for ``p_norm == np.inf`` it is ``epsilon * sign(gradient)``. A gradient
		with zero norm yields an all-zero step.

		Raises
		------
		NotImplementedError
			If ``self.p_norm`` is neither ``2`` nor ``np.inf``.
		"""
		# Work on the gradient's own device instead of changing the process-wide
		# default device, which would silently affect unrelated code.
		norm = torch.norm(gradient_f)
		if norm == 0:
			# Avoid dividing by zero when the gradient vanishes.
			g = torch.zeros_like(gradient_f)
		else:
			g = gradient_f / norm
		if self.p_norm == 2:
			return self.epsilon * g
		elif self.p_norm == np.inf:
			return self.epsilon * torch.sign(g)
		raise NotImplementedError(f"{self.p_norm}-norm not yet implemented")

	def infer_step(self, x_init):
		"""Return the classifier output for ``x_init`` as a Python scalar."""
		confidence = self.classifier.embedding_predict(x_init)
		return confidence.item()

	def invert_feature_mapping(self, x, x_adv):
		"""Map the adversarial embedding back to input values.

		For every index in ``self.indexes_to_perturb`` the row of the embedding
		matrix closest (L2 distance) to the adversarial vector at that position
		is selected, and its row index is written into a copy of ``x``.

		Parameters
		----------
		x : CArray
			Original sample; ``x.tondarray()`` supplies the array that is patched.
		x_adv
			Adversarial point in embedding space, indexed as ``x_adv[0, :, i]``.

		Returns
		-------
		CArray
			The reconstructed sample.
		"""
		E = self._get_embedded_byte_matrix()
		# Push the invalid rows infinitely far away so they can never be the closest.
		# A scalar fill works on whatever device E lives on.
		E[self.invalid_pos] = float("inf")
		reconstructed_x = x.tondarray()

		closest_bytes = []
		with torch.no_grad():
			for i in self.indexes_to_perturb:
				x_i = x_adv[0, :, i]
				# torch.stack keeps the distances on the device of E / x_adv,
				# independent of the global default device.
				distances = torch.stack([torch.norm(e - x_i, p=2) for e in E])
				closest_bytes.append(distances.argmin().item())

		reconstructed_x[0, self.indexes_to_perturb] = np.asarray(closest_bytes, dtype=reconstructed_x.dtype)
		return CArray(reconstructed_x)

	def apply_feature_mapping(self, x: CArray):
		"""Embed ``x`` with the classifier's ``embed`` method."""
		return self.classifier.embed(x.tondarray())
