import unittest

from secml.array import CArray

from secml_malware.attack.whitebox import CHeaderEvasion, CPaddingEvasion, CKreukEvasion, CSuciuEvasion, \
	CContentShiftingEvasion
from secml_malware.attack.whitebox.c_format_exploit_evasion import CFormatExploitEvasion
from secml_malware.attack.whitebox.c_headerfields_evasion import CHeaderFieldsEvasion
from secml_malware.attack.whitebox.tests.malware_test_base import End2EndBaseTests


class EvasionEnd2EndTestSuite(End2EndBaseTests):
	"""End-to-end white-box evasion tests against the pretrained classifier.

	Every test builds one attack around ``self.classifier``, runs it on
	``self.X`` and checks the outcome with :meth:`assert_evasion_result`.
	"""

	def setUp(self):
		"""Load pretrained weights and cache the baseline predictions.

		``self.classifier``, ``self.surrogate_classifier``, ``self.X`` and the
		two model paths are not defined here; presumably the base class
		``setUp`` provides them.
		"""
		super(EvasionEnd2EndTestSuite, self).setUp()
		self.classifier.load_pretrained_model(self.ember_path)
		self.surrogate_classifier.load_pretrained_model(self.surrogate_path)
		# Labels predicted before any manipulation; used as the attack target
		# labels and as the baseline in assert_evasion_result.
		self.Y = self.classifier.predict(CArray(self.X), return_decision_function=False)

	def _make_kreuk_attack(self, **overrides):
		"""Build a ``CKreukEvasion`` with the settings shared by the Kreuk tests.

		``overrides`` are forwarded as extra keyword arguments (for example
		``compute_slack`` or ``p_norm``).
		"""
		return CKreukEvasion(
			self.classifier,
			how_many_padding_bytes=1000,
			epsilon=0.03,
			iterations=50,
			is_debug=True,
			**overrides
		)

	def test_whitebox_ember_model_attack_random_init(self):
		"""Header evasion with ``random_init=True``."""
		attack = CHeaderEvasion(
			self.classifier,
			is_debug=True,
			random_init=True,
		)
		self.assert_evasion_result(attack)

	def test_whitebox_ember_model_attack_no_random_init(self):
		"""Header evasion with ``random_init=False``."""
		attack = CHeaderEvasion(
			self.classifier,
			is_debug=True,
			random_init=False,
		)
		# Same check as every other test: go through the shared helper instead
		# of repeating its counting logic inline.
		self.assert_evasion_result(attack)

	def test_padding_whitebox_ember_model_attack(self):
		"""Padding evasion with 1000 as second positional argument and random init."""
		padding_attack = CPaddingEvasion(
			self.classifier,
			1000,
			random_init=True,
			is_debug=True
		)
		self.assert_evasion_result(padding_attack)

	def test_kreuk_whitebox_attack(self):
		"""Kreuk evasion with the shared settings only."""
		self.assert_evasion_result(self._make_kreuk_attack())

	def test_kreuk_whitebox_no_slack_attack_p2(self):
		"""Kreuk evasion with ``compute_slack=False`` and ``p_norm=2``."""
		kreuk_attack = self._make_kreuk_attack(compute_slack=False, p_norm=2)
		self.assert_evasion_result(kreuk_attack)

	def test_kreuk_whitebox_attack_p2(self):
		"""Kreuk evasion with ``compute_slack=True`` and ``p_norm=2``."""
		kreuk_attack = self._make_kreuk_attack(compute_slack=True, p_norm=2)
		self.assert_evasion_result(kreuk_attack)

	def test_kreuk_whitebox_no_slack_attack(self):
		"""Kreuk evasion with ``compute_slack=False``."""
		kreuk_attack = self._make_kreuk_attack(compute_slack=False)
		self.assert_evasion_result(kreuk_attack)

	def test_suciu_appending_whitebox_attack(self):
		"""Suciu evasion with 1000 padding bytes and ``epsilon=0.03``."""
		suciu_attack = CSuciuEvasion(
			self.classifier,
			how_many_padding_bytes=1000,
			epsilon=0.03,
			is_debug=True
		)
		self.assert_evasion_result(suciu_attack)

	def test_suciu_appending_whitebox_no_slack_attack(self):
		"""Suciu evasion as above, with ``compute_slack=False``."""
		suciu_attack = CSuciuEvasion(
			self.classifier,
			how_many_padding_bytes=1000,
			epsilon=0.03,
			is_debug=True,
			compute_slack=False
		)
		self.assert_evasion_result(suciu_attack)

	def test_section_shift_attack(self):
		"""Content-shifting evasion with a preferred extension of 0x200, 2 iterations."""
		shift_attack = CContentShiftingEvasion(
			self.classifier,
			preferable_extension_amount=0x200,
			iterations=2,
			is_debug=True
		)
		self.assert_evasion_result(shift_attack)

	def test_pe_shift_attack(self):
		"""Format-exploit evasion using only a PE header extension of 0x200."""
		shift_attack = CFormatExploitEvasion(
			self.classifier,
			preferable_extension_amount=0,
			pe_header_extension=0x200,
			iterations=20,
			is_debug=True
		)
		self.assert_evasion_result(shift_attack)

	def test_header_fields_attack(self):
		"""Header-fields evasion with 20 iterations."""
		header_fields_attack = CHeaderFieldsEvasion(
			self.classifier,
			iterations=20,
			is_debug=True
		)
		self.assert_evasion_result(header_fields_attack)

	def assert_evasion_result(self, attack, create_adv_malware=False):
		"""Run ``attack`` on the fixture samples and assert the detection count changed.

		:param attack: object exposing ``run(x, y)``; when ``create_adv_malware``
			is set it must also expose ``create_real_sample_from_adv``.
		:param create_adv_malware: when true, call
			``attack.create_real_sample_from_adv`` for every path in
			``self.complete_malware_paths``, passing ``path + '.adv'`` as the
			third argument (presumably the output location).
		:raises AssertionError: if the number of samples labelled 1 before the
			attack equals the adjusted number labelled 1 after it.
		"""
		y_pred, _, adv_ds, _ = attack.run(self.X, self.Y)

		n_samples = self.Y.shape[0]
		# Samples the classifier labelled 1 / 0 before the attack.
		n_old_y_malw = sum(self.Y == 1)
		n_false_negative = sum(self.Y == 0)
		# NOTE(review): if y_pred covers every sample in self.Y, then an attack
		# that changes nothing still makes the two compared counts differ
		# whenever self.Y contains zeros -- confirm the subtraction is intended.
		n_new_detected_malw = sum(y_pred == 1) - n_false_negative

		if create_adv_malware:
			for i, f in enumerate(self.complete_malware_paths):
				attack.create_real_sample_from_adv(f, adv_ds.X[i, :], f + '.adv')

		self.assertNotEqual(
			n_old_y_malw,
			n_new_detected_malw,
			msg="Evasion achieved: {}/{}".format(
				n_samples - n_new_detected_malw, n_samples
			),
		)


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
	unittest.main()
