from secml_malware.attack.whitebox.tests.malware_test_base import EmberBaseTests


class EmberTestSuite(EmberBaseTests):
	"""Prediction consistency checks for the classifier under test.

	Uses ``self.classifier``, ``self.X`` and ``self.assert_array_equal``,
	none of which are defined in this class (presumably supplied by
	``EmberBaseTests`` -- confirm against that base class).
	"""

	def test_batch_prediction(self):
		"""Check that batch prediction agrees with per-row prediction.

		Features are extracted from ``self.X`` once; the prediction over the
		full feature matrix is then compared, entry by entry, against
		predictions made separately on row 0 and row 1.
		"""
		feature_matrix = self.classifier.extract_features(self.X)
		batch_labels = self.classifier.predict(feature_matrix)

		# Both single-row predictions are computed up front, before any
		# assertion runs, so a failing comparison never skips a predict call.
		row_labels = [
			self.classifier.predict(feature_matrix[row, :]) for row in (0, 1)
		]

		for row, single_label in enumerate(row_labels):
			self.assert_array_equal(single_label, batch_labels[row])
