import copy
from abc import abstractmethod

import numpy as np
import torch
from secml.adv.attacks import CAttackEvasion
from secml.array import CArray
from secml.array.c_dense import CDense
from secml.settings import SECML_PYTORCH_USE_CUDA

from secml_malware.models import CClassifierEnd2EndMalware

use_cuda = torch.cuda.is_available() and SECML_PYTORCH_USE_CUDA
# `torch.backends.mps` only exists on newer torch releases; treat MPS as
# unavailable instead of failing at import time on older versions.
use_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()


class CEnd2EndMalwareEvasion(CAttackEvasion):
	"""
	Base abstract class for implementing end-to-end evasion attacks against malware detectors.

	Subclasses supply the feature mapping (`apply_feature_mapping` /
	`invert_feature_mapping`), the loss gradient, the penalty term, the
	optimization step and the inference step. This class drives the iterative
	optimization loop in `_run` and can rebuild a real file from an optimized
	sample with `create_real_sample_from_adv`.
	"""

	def _objective_function_gradient(self, x):
		# No-op: the loop in `_run` uses `loss_function_gradient` instead.
		# NOTE(review): presumably kept only to satisfy the CAttackEvasion interface.
		pass

	def _objective_function(self, x):
		# No-op, see `_objective_function_gradient`.
		pass

	def __init__(
			self,
			end2end_model: CClassifierEnd2EndMalware,
			indexes_to_perturb: list,
			iterations: int = 100,
			is_debug: bool = False,
			random_init: bool = False,
			threshold: float = 0.5,
			penalty_regularizer: float = 0,
			store_checkpoints: int = None
	):
		"""
		Parameters
		----------
		end2end_model : CClassifierEnd2EndMalware
			the end-to-end model under attack (passed twice to CAttackEvasion.__init__)
		indexes_to_perturb : list
			positions of the input that the attack is allowed to modify
		iterations : int, optional, default 100
			maximum number of optimization iterations
		is_debug : bool, optional, default False
			if True, print progress information while attacking
		random_init : bool, optional, default False
			if True, fill the perturbed positions with random byte values before optimizing
		threshold : float, optional, default 0.5
			the attack stops as soon as the confidence drops below this value
		penalty_regularizer : float, optional, default 0
			weight forwarded to `compute_penalty_term`
		store_checkpoints : int or None, optional, default None
			if set, record the confidence of the inverted (real) sample every
			`store_checkpoints` iterations instead of the confidence of every iteration
		"""
		CAttackEvasion.__init__(
			self,
			end2end_model,
			end2end_model,
		)
		self.iterations = iterations
		self.is_debug = is_debug
		self.indexes_to_perturb = indexes_to_perturb
		self.confidences_ = []
		self.changes_per_iterations_ = []

		self.random_init = random_init

		self.embedding_size = end2end_model.get_embedding_size()
		self.max_input_length = end2end_model.get_input_max_length()
		# Same value as `embedding_value`; kept under both names for subclasses.
		self.invalid_pos = end2end_model.get_embedding_value()
		self.embedding_value = end2end_model.get_embedding_value()
		self.shift_values = end2end_model.get_is_shifting_values()

		# `np.infty` was removed in NumPy 2.0; `np.inf` works on every version.
		self._invalid_value = torch.tensor([np.inf])

		self.threshold = threshold
		self.penalty_regularizer = penalty_regularizer

		self.store_checkpoints = store_checkpoints

	def f_eval(self):
		# No-op stub; presumably required by the base-class interface.
		pass

	def grad_eval(self):
		# No-op stub; presumably required by the base-class interface.
		pass

	def objective_function_gradient(self, x):
		# No-op stub; the attack loop uses `loss_function_gradient`.
		pass

	def objective_function(self, x):
		# No-op stub; the attack loop tracks confidence via `infer_step`.
		pass

	@abstractmethod
	def loss_function_gradient(self, original_x: CArray, adv_x: CArray, penalty_term: torch.Tensor):
		"""Return the gradient fed to `optimization_solver` for the current adversarial sample."""
		raise NotImplementedError("This is only an abstract method")

	def _malware_confidence(self, x) -> float:
		"""
		Return the classifier's decision score at index 1 for `x`, as a Python float.

		This is the score the attack tries to push below `threshold`.
		"""
		_, scores = self.classifier.predict(x, return_decision_function=True)
		return scores[1].item()

	def _run(self, x0, y0, x_init=None):
		"""
		Tries to achieve evasion against an end-to-end classifier.

		Parameters
		----------
		x0 : CArray
			Initial sample.
		y0 : int or CArray
			The true label of x0.
		x_init : CArray or None, optional
			Initialization point. If None, it is set to x0. It is copied, never
			modified in place.
		Returns
		-------
		x_opt : CArray
			Evasion sample
		f_opt : float
			Value of objective function on x_opt.
		Notes
		-----
		Internally, it stores the confidences at each round in `confidences_`.
		"""
		# Always work on a private copy: the random initialization below writes
		# into the array in place and must not touch the caller's data.
		x_init = copy.copy(x0 if x_init is None else x_init)

		current_conf = self._malware_confidence(x_init)
		self.confidences_ = [current_conf]

		if self.is_debug:
			print(f'> Original Confidence: {current_conf}')

		if use_cuda:
			self._invalid_value = self._invalid_value.cuda()
		elif use_mps:
			self._invalid_value = self._invalid_value.to(torch.device('mps'))

		E = self._get_embedded_byte_matrix()
		if self.random_init:
			x_init = self._randomize_values_for_attack(x_init)

		if self.is_debug:
			print("> Beginning new sample evasion...")

		index_to_consider = np.array(self.indexes_to_perturb)
		x_init = self.apply_feature_mapping(x_init)
		for t in range(self.iterations):
			if current_conf < self.threshold:
				if self.is_debug:
					print(f"Stopped at confidence below threshold: {current_conf}/{self.threshold}")
				break

			penalty_term = self.compute_penalty_term(x0, x_init, self.penalty_regularizer)
			gradient_f = self.loss_function_gradient(x0, x_init, penalty_term)
			x_init = self.optimization_solver(E, gradient_f, index_to_consider, x_init)
			current_conf = self.infer_step(x_init)
			if self.store_checkpoints:
				# Checkpoints score the real (inverted) sample, not the mapped one.
				if not t % self.store_checkpoints:
					check_conf = self._malware_confidence(self.invert_feature_mapping(x0, x_init))
					if self.is_debug:
						print(f">{t}/{self.iterations} storing checkpoint: {check_conf}")
					self.confidences_.append(check_conf)
			else:
				self.confidences_.append(current_conf)
			if self.is_debug:
				print(f">{t}/{self.iterations} Shifted confidence:\t{current_conf}")

		x_init = self.invert_feature_mapping(x0, x_init)
		current_conf = self._malware_confidence(x_init)
		if self.is_debug:
			print(f'>AFTER INVERSION, CONFIDENCE SCORE: {current_conf}')
		return x_init, current_conf

	def _get_embedded_byte_matrix(self):
		"""
		Return the embeddings of the input values 0..256.

		The model embeds one full-length input whose first 257 positions hold
		0..256, with every later position clamped to 256. Only the first 257
		embedded rows are kept.
		"""
		# Vectorized equivalent of [i if i < 256 else 256 for i in range(length)].
		byte_ids = np.minimum(np.arange(self.max_input_length), 256)[np.newaxis, :]
		return self.classifier.embed(
			byte_ids,
			transpose=False,
		)[0][:257]

	@abstractmethod
	def apply_feature_mapping(self, x) -> CArray:
		"""Map the sample into the space the optimizer works in (applied once before the loop)."""
		raise NotImplementedError("This method is abstract, you should implement it somewhere else!")

	@abstractmethod
	def invert_feature_mapping(self, x, x_adv) -> CArray:
		"""Map an optimized sample back to an input the classifier can score, given the original `x`."""
		raise NotImplementedError("This method is abstract, you should implement it somewhere else!")

	@abstractmethod
	def infer_step(self, x_init) -> CArray:
		"""Return the confidence compared against `threshold` after each optimization step."""
		raise NotImplementedError("This method is abstract, you should implement it somewhere else!")

	def _randomize_values_for_attack(self, x_init):
		"""
		Overwrite the positions in `indexes_to_perturb` with random byte values.

		Values are drawn uniformly from 0..255, offset by `shift_values`.
		`x_init` is modified in place and also returned.
		"""
		min_val = 0 + self.shift_values
		max_val = 255 + self.shift_values
		x_init[self.indexes_to_perturb] = CDense(
			torch.randint(
				# `high` is exclusive in torch.randint: +1 so that max_val can be drawn.
				low=min_val, high=max_val + 1, size=(1, len(self.indexes_to_perturb))
			)
		)
		return x_init

	@abstractmethod
	def compute_penalty_term(self, original_x: CArray, adv_x: CArray, par: float) -> float:
		"""Return the penalty forwarded to `loss_function_gradient`; `par` is `penalty_regularizer`."""
		raise NotImplementedError("This method is abstract, you should implement it somewhere else!")

	@abstractmethod
	def optimization_solver(self, E, gradient_f, index_to_consider, x_init) -> CArray:
		"""Perform one update of `x_init` on `index_to_consider`, given the byte embeddings `E` and the gradient."""
		raise NotImplementedError("This method is abstract, you should implement it somewhere else!")

	def create_real_sample_from_adv(self, original_file_path: str, x_adv: CArray,
									new_file_path: str = None) -> bytearray:
		"""
		Create a real adversarial example

		Parameters
		----------
		original_file_path : str
			the original malware sample
		x_adv : CArray
			the perturbed malware sample, as created by the optimizer
		new_file_path : str, optional, default None
			the path where to save the adversarial malware. Leave None to not save the result to disk

		Returns
		-------
		bytearray
			the adversarial malware, as string of bytes
		"""
		with open(original_file_path, 'rb') as f:
			code = bytearray(f.read())
		# Truncate at the first position holding the embedding (padding) value.
		padding_index = x_adv.find(x_adv == self.classifier.get_embedding_value())
		padded_x_adv = copy.copy(x_adv)
		if padding_index:
			padded_x_adv = padded_x_adv[0, :padding_index[0]]
		# Undo the byte-value shift used by shifting models.
		if self.shift_values:
			padded_x_adv = padded_x_adv - 1
		adv_bytes = bytes(padded_x_adv.astype(np.uint8).flatten().tolist())
		# Overwrite the head of the original file with the adversarial bytes.
		code[:len(adv_bytes)] = adv_bytes
		if new_file_path:
			with open(new_file_path, 'wb') as f:
				f.write(code)
		return code
