import unittest

import torch
from secml.array import CArray

from secml_malware.attack.whitebox.tests.malware_test_base import End2EndBaseTests
from secml_malware.models import MalConv


class MalConvTestSuite(End2EndBaseTests):
	"""End-to-end tests for the MalConv model and its classifier wrapper.

	Fixtures used here (``self.classifier``, ``self.baseline``, ``self.X``,
	``self.ember_path``) are presumably created by ``End2EndBaseTests.setUp``;
	they are not defined in this file.
	"""

	def setUp(self):
		"""Run the base fixture setup, then build a bare ``MalConv`` instance.

		No weights are loaded here; tests that use ``self.m`` load them
		explicitly.
		"""
		super(MalConvTestSuite, self).setUp()
		self.m = MalConv()

	def _predict_baseline_label(self):
		"""Return the wrapper's predicted label (not score) for ``self.baseline``."""
		return self.classifier.predict(
			CArray(torch.tensor(self.baseline)), return_decision_function=False
		)

	def test_empty_baseline_is_goodware(self):
		"""The raw MalConv network scores the empty baseline at or below 0.5."""
		self.m.load_simplified_model(self.ember_path)
		y_pred = self.m(torch.tensor(self.baseline))
		# ``Tensor.numpy()`` only works on CPU tensors. ``.cpu()`` is a no-op
		# for a tensor already on the CPU, so this handles CUDA and any other
		# accelerator device without a device-specific branch.
		scores = y_pred.detach().cpu().numpy()
		# Presumably output is (batch, 1): take the first sample's score — TODO confirm.
		y = int(scores[0][0] > 0.5)
		self.assertEqual(y, 0)

	def test_malconv_load_ember_no_path(self):
		"""Loading pretrained weights with no explicit path still labels the baseline 0."""
		self.classifier.load_pretrained_model()
		y = self._predict_baseline_label()
		self.assertEqual(y, 0)

	def test_empty_baseline_pytorch_classifier_wrapper(self):
		"""Loading weights from ``self.ember_path`` labels the baseline 0 via the wrapper."""
		self.classifier.load_pretrained_model(self.ember_path)
		y = self._predict_baseline_label()
		self.assertEqual(y, 0)

	def test_batch_prediction_malconv(self):
		"""Batched predictions on ``self.X`` match per-row predictions of its first two rows."""
		self.classifier.load_pretrained_model(self.ember_path)
		y_pred = self.classifier.predict(self.X)

		y_pred1 = self.classifier.predict(self.X[0, :])
		y_pred2 = self.classifier.predict(self.X[1, :])

		self.assert_array_equal(y_pred1, y_pred[0])
		self.assert_array_equal(y_pred2, y_pred[1])


if __name__ == "__main__":
	# Run this module's tests directly; verbosity=1 is unittest's default and
	# can still be overridden on the command line (e.g. ``-v``).
	unittest.main(verbosity=1)
