import os
from os import listdir

import magic
import numpy as np
from secml.array import CArray
from secml.testing import CUnitTest

from secml_malware.models import CClassifierEmber
from secml_malware.models.c_classifier_end2end_malware import CClassifierEnd2EndMalware
from secml_malware.models.malconv import MalConv


class End2EndBaseTests(CUnitTest):
	"""Shared fixture for tests that exercise the end-to-end MalConv classifiers."""

	def setUp(self):
		"""Create the classifiers, resolve the data paths and load the samples.

		After this method runs the instance exposes:

		* ``classifier``, ``malconv_plus``, ``surrogate_classifier``: three
		  ``CClassifierEnd2EndMalware`` wrappers, each around a fresh ``MalConv()``.
		* ``ember_path`` / ``surrogate_path``: both resolve to the same
		  ``pretrained_malconv.pth`` file.
		* ``X`` / ``Y``: a ``CArray`` built from every regular file in
		  ``malware_folder`` whose ``magic`` description contains ``"PE32"``,
		  and a list holding the label ``1`` for each of them.
		* ``complete_malware_paths``: the paths of the files kept in ``X``.
		* ``byte_malware`` / ``malware``: the raw bytes of ``single_malware_path``
		  and their ``MalConv.bytes_to_numpy`` conversion wrapped in a 1-row array.
		* ``baseline``: a ``(1, max_length)`` array filled with ``padding_value``.
		"""
		self.max_length = 2 ** 20
		self.padding_value = 256

		# Root is obtained by applying os.path.dirname three times to __file__.
		containing_dir = os.path.dirname(__file__)
		self.root_module_path = os.path.dirname(os.path.dirname(containing_dir))

		def data_path(relative):
			# Every data location is expressed relative to root_module_path.
			return os.path.join(self.root_module_path, relative)

		# The three networks are instantiated in this exact order.
		self.classifier = CClassifierEnd2EndMalware(MalConv())
		self.malconv_plus = CClassifierEnd2EndMalware(MalConv(), plus_version=True)
		self.surrogate_classifier = CClassifierEnd2EndMalware(MalConv())

		self.ember_path = data_path("../data/trained/pretrained_malconv.pth")
		self.surrogate_path = data_path("../data/trained/pretrained_malconv.pth")
		self.malware_folder = data_path("../data/malware_samples/test_folder")
		self.goodware_folder = data_path("../data/goodware_samples/")
		self.single_malware_path = data_path("../data/malware_samples/test_malware")

		# One row of max_length entries, all equal to the padding value.
		self.baseline = np.zeros((1, self.max_length)) + self.padding_value

		self.complete_malware_paths = []
		samples = []
		for entry in listdir(self.malware_folder):
			candidate = os.path.join(self.malware_folder, entry)
			# Keep only regular files that libmagic describes as PE32.
			if not os.path.isfile(candidate) or "PE32" not in magic.from_file(candidate):
				continue
			with open(candidate, "rb") as handle:
				print(f'>Using {os.path.dirname(candidate)}')
				self.complete_malware_paths.append(candidate)
				raw_bytes = handle.read()
				samples.append(
					MalConv.bytes_to_numpy(raw_bytes, self.max_length, self.padding_value, False)
				)

		# Note: the padding-only baseline is not appended to X / Y, so every label is 1.
		self.X = CArray(samples)
		self.Y = [1] * len(samples)

		with open(self.single_malware_path, "rb") as single:
			self.byte_malware = bytearray(single.read())
		converted = MalConv.bytes_to_numpy(
			self.byte_malware, self.max_length, self.padding_value, False
		)
		self.malware = np.array([converted])


class EmberBaseTests(CUnitTest):
	"""Shared fixture for tests that exercise the EMBER tree-based classifier."""

	def setUp(self):
		"""Load the EMBER classifier and the PE32 samples of the test folder.

		After this method runs the instance exposes:

		* ``classifier``: a ``CClassifierEmber`` built from ``ember_path``.
		* ``malware_folder`` / ``goodware_folder``: sample directories resolved
		  relative to ``root_module_path``.
		* ``X``: a ``CArray`` with one ``MalConv.bytes_to_numpy`` conversion per
		  accepted file, all converted with the length of the longest sample.
		* ``Y``: a list holding the label ``1`` for each accepted file.
		"""
		# Root is obtained by applying os.path.dirname three times to __file__.
		self.root_module_path = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))

		self.ember_path = os.path.join(self.root_module_path, "../data/trained/ember_model.txt")
		self.classifier = CClassifierEmber(tree_path=self.ember_path)

		self.malware_folder = os.path.join(self.root_module_path, "../data/malware_samples/test_folder")
		self.goodware_folder = os.path.join(self.root_module_path, "../data/goodware_samples/")

		X = []
		y = []
		# Must be initialised ONCE, outside the loop: resetting it per file would
		# leave it equal to the size of the last file instead of the longest one.
		max_length = 0
		for f in listdir(self.malware_folder):
			complete_path = os.path.join(self.malware_folder, f)
			if not os.path.isfile(complete_path):
				continue
			# Keep only files that libmagic describes as PE32.
			if "PE32" not in magic.from_file(complete_path):
				continue
			with open(complete_path, "rb") as malware:
				print(f'>Using {f}')
				code = malware.read()
			max_length = max(max_length, len(code))
			X.append(code)
			y.append(1)

		# Every sample is converted with the same target length so the rows line up.
		# NOTE(review): presumably bytes_to_numpy pads shorter samples with 256 up
		# to max_length -- confirm against MalConv.bytes_to_numpy.
		self.X = CArray([MalConv.bytes_to_numpy(x, max_length, 256, False) for x in X])
		self.Y = y
