from secml_malware.attack.blackbox.c_black_box_format_exploit_evasion import CBlackBoxFormatExploitEvasionProblem
from secml_malware.attack.blackbox.c_black_box_padding_evasion import CBlackBoxPaddingEvasionProblem
from secml_malware.attack.blackbox.c_blackbox_header_problem import CBlackBoxHeaderEvasionProblem
from secml_malware.attack.blackbox.c_blackbox_headerfields_problem import CBlackBoxHeaderFieldsEvasionProblem
from secml_malware.attack.blackbox.c_gamma_evasion import CGammaEvasionProblem
from secml_malware.attack.blackbox.c_gamma_sections_evasion import CGammaSectionsEvasionProblem
from secml_malware.attack.blackbox.c_wrapper_phi import CEmberWrapperPhi, CEnd2EndWrapperPhi, CSorelWrapperPhi
from secml_malware.attack.blackbox.ga.c_base_genetic_engine import CGeneticAlgorithm
from secml_malware.attack.blackbox.tests.black_box_base_test import BlackBoxBaseTests


class BlackBoxEvasionAttackTestSuite(BlackBoxBaseTests):
	"""Smoke tests for the black-box evasion problems and the classifier wrappers.

	Each ``test_blackbox_*`` method builds one black-box problem around a wrapped
	classifier, optimizes it with :class:`CGeneticAlgorithm` on the fixture data
	``self.X`` / ``self.Y`` (provided by ``BlackBoxBaseTests``), and checks that the
	attack reduced the number of samples predicted as class 1.

	The ``*_wrapper_batch_prediction`` tests check that a wrapper returns the same
	label for a sample whether it is predicted on its own or as part of a batch.
	"""

	def _run_attack_and_check(self, problem):
		"""Optimize ``problem`` with a genetic-algorithm engine and assert evasion.

		Parameters
		----------
		problem : black-box evasion problem to hand to ``CGeneticAlgorithm``.
		"""
		engine = CGeneticAlgorithm(problem)
		# ``run`` returns a 4-tuple; only the predicted labels are checked here.
		y_pred, _, _, _ = engine.run(self.X, self.Y)
		self.assert_evasion_result(y_pred)

	def test_blackbox_format_exploit_ember(self):
		problem = CBlackBoxFormatExploitEvasionProblem(
			CEmberWrapperPhi(self.ember_classifier),
			preferable_extension_amount=0x200,
			pe_header_extension=0,
			iterations=5,
			population_size=10,
		)
		self._run_attack_and_check(problem)

	def test_blackbox_format_exploit_malconv(self):
		problem = CBlackBoxFormatExploitEvasionProblem(
			CEnd2EndWrapperPhi(self.end2end_classifier),
			preferable_extension_amount=0x200,
			pe_header_extension=0,
			iterations=5,
			population_size=10,
		)
		self._run_attack_and_check(problem)

	def test_blackbox_partial_dos_header_malconv(self):
		problem = CBlackBoxHeaderEvasionProblem(
			CEnd2EndWrapperPhi(self.end2end_classifier),
			optimize_all_dos=False,
			iterations=5,
			population_size=10,
		)
		self._run_attack_and_check(problem)

	def test_blackbox_padding_malconv(self):
		problem = CBlackBoxPaddingEvasionProblem(
			CEnd2EndWrapperPhi(self.end2end_classifier),
			how_many_padding_bytes=1024,
			iterations=5,
			population_size=10,
		)
		self._run_attack_and_check(problem)

	def test_blackbox_dos_header_malconv(self):
		problem = CBlackBoxHeaderEvasionProblem(
			CEnd2EndWrapperPhi(self.end2end_classifier),
			optimize_all_dos=True,
			iterations=2,
			population_size=2,
		)
		self._run_attack_and_check(problem)

	def test_blackbox_header_fields_malconv(self):
		problem = CBlackBoxHeaderFieldsEvasionProblem(
			CEnd2EndWrapperPhi(self.end2end_classifier),
			iterations=2,
			population_size=2,
		)
		self._run_attack_and_check(problem)

	def test_blackbox_gamma_malconv(self):
		section_population, _ = CGammaEvasionProblem.create_section_population_from_folder(
			self.goodware_folder, 10
		)
		problem = CGammaEvasionProblem(
			section_population,
			CEnd2EndWrapperPhi(self.end2end_classifier),
			population_size=10,
			penalty_regularizer=1e-6,
			iterations=5,
		)
		self._run_attack_and_check(problem)

	def test_blackbox_gamma_ember(self):
		section_population, _ = CGammaEvasionProblem.create_section_population_from_folder(
			self.goodware_folder, 100
		)
		problem = CGammaEvasionProblem(
			section_population,
			CEmberWrapperPhi(self.ember_classifier),
			population_size=10,
			penalty_regularizer=1e-6,
			iterations=5,
		)
		self._run_attack_and_check(problem)

	def test_blackbox_sections_gamma_ember(self):
		section_population, _ = CGammaSectionsEvasionProblem.create_section_population_from_folder(
			self.goodware_folder, 100
		)
		problem = CGammaSectionsEvasionProblem(
			section_population,
			CEmberWrapperPhi(self.ember_classifier),
			population_size=3,
			penalty_regularizer=1e-6,
			iterations=4,
		)
		self._run_attack_and_check(problem)

	def test_sorel_dnn_wrapper_batch_prediction(self):
		self._batch_prediction_of_wrapper(CSorelWrapperPhi(self.sorel_net_classifier))

	def test_blackbox_sections_gamma_sorel(self):
		section_population, _ = CGammaSectionsEvasionProblem.create_section_population_from_folder(
			self.goodware_folder, 100
		)
		problem = CGammaSectionsEvasionProblem(
			section_population,
			CSorelWrapperPhi(self.sorel_net_classifier),
			population_size=3,
			penalty_regularizer=1e-6,
			iterations=4,
		)
		self._run_attack_and_check(problem)

	def test_gbdt_wrapper_batch_prediction(self):
		self._batch_prediction_of_wrapper(CEmberWrapperPhi(self.ember_classifier))

	def test_e2e_wrapper_batch_prediction(self):
		self._batch_prediction_of_wrapper(CEnd2EndWrapperPhi(self.end2end_classifier))

	def _batch_prediction_of_wrapper(self, wrapper):
		"""Check that batch and single-sample predictions of ``wrapper`` agree.

		Predicts the whole of ``self.X`` at once, then the first two rows one at a
		time, and asserts that each single-row label matches the corresponding
		entry of the batch result. Assumes ``self.X`` has at least two rows.

		Parameters
		----------
		wrapper : classifier wrapper exposing
			``predict(x, return_decision_function=False)``.
		"""
		batch_labels = wrapper.predict(self.X, return_decision_function=False)
		for row in (0, 1):
			single_label = wrapper.predict(self.X[row, :], return_decision_function=False)
			self.assert_array_equal(single_label, batch_labels[row])

	def assert_evasion_result(self, y_pred):
		"""Assert that the attack lowered the number of samples predicted as class 1.

		The baseline is the number of entries equal to 1 in ``self.Y``, presumably
		the malware labels or the classifier's original predictions for ``self.X``
		(set up by ``BlackBoxBaseTests`` -- confirm there). The assertion passes only
		if strictly fewer entries of ``y_pred`` are 1.

		Parameters
		----------
		y_pred : labels predicted for ``self.X`` after the attack, as returned by
			``CGeneticAlgorithm.run``.
		"""
		n_detected_before = sum(self.Y == 1)
		n_detected_after = sum(y_pred == 1)
		n_samples = self.Y.shape[0]
		# Compare raw detection counts. The earlier version subtracted
		# ``sum(self.Y == 0)`` from the post-attack count. Whenever the fixture had
		# any 0 labels, that let the test pass even if the attack changed nothing.
		self.assertLess(
			n_detected_after,
			n_detected_before,
			msg="No evasion achieved: {}/{} samples still detected".format(
				n_detected_after, n_samples
			),
		)
