from abc import abstractmethod

from secml.array import CArray
from secml.ml.classifiers import CClassifier

from secml_malware.models import CClassifierEnd2EndMalware
#from secml_malware.models import CClassifierEnd2EndMalware, CClassifierEmber
#from secml_malware.models.c_classifier_sorel_net import CClassifierSorel


class CWrapperPhi:
	"""
	Base wrapper exposing a classifier as a black box that consumes input-space samples.

	Subclasses provide ``extract_features``, which maps input-space samples into the
	feature space of the wrapped classifier. ``predict`` chains that mapping with the
	wrapped classifier's own ``predict``.
	"""

	def __init__(self, model: CClassifier):
		"""
		Store the classifier that this wrapper delegates to.

		Parameters
		----------
		model : CClassifier
			The model to wrap; it is kept as ``self.classifier``.
		"""
		self.classifier = model

	# NOTE(review): CWrapperPhi does not derive from abc.ABC (or use ABCMeta), so this
	# decorator is not enforced when instantiating; the NotImplementedError raised in
	# the body is what actually signals a missing override.
	@abstractmethod
	def extract_features(self, x: CArray):
		"""
		Map input-space samples into the feature space of the wrapped model.

		Parameters
		----------
		x : CArray
			The sample(s) in the input space.

		Returns
		-------
		CArray
			The feature-space representation of the input sample(s).

		Raises
		------
		NotImplementedError
			Always, in this base class; subclasses must override this method.
		"""
		raise NotImplementedError("This method is abstract, you should implement it somewhere else!")

	def predict(self, x: CArray, return_decision_function: bool = True):
		"""
		Classify input-space sample(s) through the wrapped model.

		The input is promoted to 2-D and handed to ``extract_features`` as one batch.
		Any per-row handling, such as cropping or padding, is the subclass's job. The
		resulting features are then passed to the wrapped classifier's ``predict``.

		Parameters
		----------
		x : CArray
			The sample(s) in the input space.
		return_decision_function : bool, default True
			Forwarded to the wrapped classifier. When True, the decision-function
			output is returned together with the labels.

		Returns
		-------
		CArray, (CArray)
			The predicted label(s). If ``return_decision_function`` is True, the output
			of the decision function is also returned.
		"""
		batch = x.atleast_2d()
		features = self.extract_features(batch)
		return self.classifier.predict(
			features,
			return_decision_function=return_decision_function,
		)

class CEnd2EndWrapperPhi(CWrapperPhi):
	"""
	Wrapper around an end-to-end malware classifier (CClassifierEnd2EndMalware).

	For this kind of model, the feature space is the input sequence itself. Feature
	extraction therefore only crops or pads each row to the network's input length.
	"""

	def __init__(self, model: CClassifierEnd2EndMalware):
		"""
		Wrap a CClassifierEnd2EndMalware after checking its type.

		Parameters
		----------
		model : CClassifierEnd2EndMalware
			The end-to-end model to wrap.

		Raises
		------
		ValueError
			If ``model`` is not an instance of CClassifierEnd2EndMalware.
		"""
		is_end2end = isinstance(model, CClassifierEnd2EndMalware)
		if not is_end2end:
			raise ValueError(f"Input model is {type(model)} and not CClassifierEnd2EndMalware")
		super().__init__(model)

	def extract_features(self, x):
		"""
		Crop and pad the input sample(s) to the fixed input length of the network.

		Parameters
		----------
		x : CArray
			The sample(s) in the input space; promoted to 2-D.

		Returns
		-------
		CArray
			A matrix with one row per sample and ``get_input_max_length()`` columns.
			Positions past the copied content hold ``get_embedding_value()``.
		"""
		clf: CClassifierEnd2EndMalware = self.classifier
		batch = x.atleast_2d()
		max_len = clf.get_input_max_length()
		n_rows = batch.shape[0]
		# Start from a matrix filled entirely with the model's padding value.
		padded_x = CArray.zeros((n_rows, max_len)) + clf.get_embedding_value()
		for row in range(n_rows):
			sample = batch[row, :]
			# Inputs longer than the network accepts are cropped; shorter ones keep the
			# padding-filled tail created above.
			kept = min(sample.shape[-1], max_len)
			# Each kept value is offset by get_is_shifting_values() (presumably a 0/1 or
			# bool flag — confirm in CClassifierEnd2EndMalware).
			# NOTE(review): padding tokens already present inside ``x`` are copied (and
			# shifted) as-is; they are not stripped here.
			padded_x[row, :kept] = sample[0, :kept] + clf.get_is_shifting_values()
		return padded_x


# class CEmberWrapperPhi(CWrapperPhi):
# 	"""
# 	Class that wraps a GBDT classifier with EMBER feature set.
# 	"""

# 	def __init__(self, model: CClassifierEmber):
# 		"""
# 		Creates the wrapper of a CClassifierEmber.
#
# 		Parameters
# 		----------
# 		model : CClassifierEmber
# 		The GBDT model to wrap
# 		"""
# 		if not isinstance(model, CClassifierEmber):
# 			raise ValueError(f"Input model is {type(model)} and not CClassifierEmber")
# 		super().__init__(model)
#
# 	def extract_features(self, x):
# 		"""
# 		It extracts the EMBER hand-crafted features
#
# 		Parameters
# 		----------
# 		x : CArray
# 			The sample in the input space.
# 		Returns
# 		-------
# 		CArray
# 			The feature space representation of the input sample.
# 		"""
# 		x = x.atleast_2d()
# 		clf: CClassifierEmber = self.classifier
# 		feature_vectors = CArray.zeros((x.shape[0], 2381))
# 		for i in range(x.shape[0]):
# 			x_i = x[i, :]
# 			padding_positions = x_i.find(x_i == 256)
# 			if padding_positions:
# 				feature_vectors[i, :] = clf.extract_features(x_i[0, :padding_positions[0]])
# 			else:
# 				feature_vectors[i, :] = clf.extract_features(x_i)
# 		return feature_vectors
#
#
# class CSorelWrapperPhi(CWrapperPhi):
# 	"""
# 	Class that wraps a deep neural network with SOREL feature set.
# 	"""
#
# 	def __init__(self, model: CClassifierSorel):
# 		"""
# 		Creates the wrapper of a CClassifierSorel.
#
# 		Parameters
# 		----------
# 		model : CClassifierSorel
# 		The deep neural network model to wrap
# 		"""
# 		if not isinstance(model, CClassifierSorel):
# 			raise ValueError(f"Input model is {type(model)} and not CClassifierSorel")
# 		super().__init__(model)
#
# 	def extract_features(self, x):
# 		"""
# 		It extracts the EMBER hand-crafted features
#
# 		Parameters
# 		----------
# 		x : CArray
# 			The sample in the input space.
# 		Returns
# 		-------
# 		CArray
# 			The feature space representation of the input sample.
# 		"""
# 		x = x.atleast_2d()
# 		clf: CClassifierSorel = self.classifier
# 		feature_vectors = CArray.zeros((x.shape[0], 2381))
# 		for i in range(x.shape[0]):
# 			x_i = x[i, :]
# 			padding_positions = x_i.find(x_i == 256)
# 			if padding_positions:
# 				feature_vectors[i, :] = clf.extract_features(x_i[0, :padding_positions[0]])
# 			else:
# 				feature_vectors[i, :] = clf.extract_features(x_i)
# 		return feature_vectors
#
# 	def predict(self, x: CArray, return_decision_function: bool = True):
# 		result = super(CSorelWrapperPhi, self).predict(x, return_decision_function=return_decision_function)
# 		return result


