import os
import sys

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from secml.array import CArray
#from secml.ml.classifiers.pytorch.c_classifier_pytorch import CClassifierPyTorch
from secml_malware.models.c_classifier_pytorch import CClassifierPyTorch
from secml.settings import SECML_PYTORCH_USE_CUDA

from secml_malware.models.basee2e import End2EndModel

# True when CUDA is available AND secml is configured to use it.
use_cuda = torch.cuda.is_available() and SECML_PYTORCH_USE_CUDA
# Apple-silicon MPS backend availability. Older torch releases have no
# `torch.backends.mps`; treat that as "not available" instead of failing at import.
try:
	use_mps = torch.backends.mps.is_available()
except AttributeError:
	use_mps = False


class CClassifierEnd2EndMalware(CClassifierPyTorch):
	"""secml PyTorch classifier wrapping an end-to-end malware model.

	Besides the standard classifier interface inherited from
	`CClassifierPyTorch`, it exposes helpers that work in the model's
	embedding space: embedding a sample, predicting from an embedded sample,
	and computing gradients of the model output w.r.t. the embedding.
	"""

	def __init__(
			self,
			model: End2EndModel,
			epochs=100,
			batch_size=256,
			train_transform=None,
			preprocess=None,
			softmax_outputs=False,
			random_state=None,
			plus_version=False,
			input_shape=None,
			verbose=0,
	):
		"""
		Parameters
		----------
		model : End2EndModel
			the wrapped network; `model.max_input_size` fixes the classifier input shape
		epochs : int, optional, default 100
			number of training epochs (forwarded to the parent class)
		batch_size : int, optional, default 256
			batch size (forwarded to the parent class)
		train_transform : callable, optional, default None
			transform applied to training samples; when None, each sample is
			reshaped to a flat vector of length `input_shape[1]`
		preprocess : optional, default None
			secml preprocessor (forwarded to the parent class)
		softmax_outputs : bool, optional, default False
			forwarded to the parent class
		random_state : optional, default None
			forwarded to the parent class
		plus_version : bool, optional, default False
			stored as `self.plus_version`; not read inside this class
		input_shape : tuple, optional, default None
			shape used only by the default `train_transform`; when None it is
			`(1, model.max_input_size)`, i.e. the same shape given to the parent
		verbose : int, optional, default 0
			stored as `self.verbose`
		"""
		if input_shape is None:
			# Derive the reshape length from the model so the default train
			# transform always agrees with the classifier input shape below.
			input_shape = (1, model.max_input_size)
		super(CClassifierEnd2EndMalware, self).__init__(
			model,
			loss=F.binary_cross_entropy,
			epochs=epochs,
			batch_size=batch_size,
			preprocess=preprocess,
			input_shape=(1, model.max_input_size),
			softmax_outputs=softmax_outputs,
			random_state=random_state,
			optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True, weight_decay=1e-3),
		)
		self.plus_version = plus_version
		self.verbose = verbose
		if train_transform is None:
			flat_length = input_shape[1]
			train_transform = transforms.Lambda(lambda p: p.reshape(flat_length))
		self.train_transform = train_transform

	def gradient(self, x, w=None):
		"""Compute the gradient of the model output w.r.t. the embedding at x.

		NOTE: unlike the generic secml contract, `w` is currently ignored (the
		gradient is NOT pre-multiplied by it) and `self.preprocess` is not
		applied; `gradient_f_x` is the preprocess-aware variant.

		Parameters
		----------
		x : CArray
			the sample where the gradient is computed
		w : optional, default None
			ignored

		Returns
		-------
		CArray
			the gradient, averaged over the embedding dimension
		"""
		return self._gradient_f(x)

	def gradient_f_x(self, x, **kwargs):
		"""Returns the gradient of the function on point x.

		If a preprocessor is set, x is normalized with it first.

		Arguments:
			x {CArray} -- The point
			**kwargs -- forwarded to `_gradient_f`

		Raises:
			NotImplementedError: Model do not support gradient

		Returns:
			CArray -- the gradient computed on x
		"""
		if self.preprocess is not None:
			# Normalize data before compute the classifier gradient
			x_pre = self.preprocess.normalize(x)
		else:  # Data will not be preprocessed
			x_pre = x
		try:  # Get the derivative of decision_function
			grad_f = self._gradient_f(x_pre, **kwargs)
		except NotImplementedError as err:
			raise NotImplementedError(
				"{:} does not implement `gradient_f_x`".format(self.__class__.__name__)
			) from err
		return grad_f

	def _gradient_f(self, x, y=None, w=None, layer=None, sum_embedding=True):
		"""Gradient of the model output w.r.t. the embedding of x, as a CArray.

		Parameters
		----------
		x : CArray or numpy.ndarray
			the input sample
		y, w, layer : optional
			accepted for interface compatibility; ignored
		sum_embedding : bool, optional, default True
			when True, reduce over dim 1 of the embedding gradient. Despite the
			name this is a MEAN (`torch.mean`), not a sum.

		Returns
		-------
		CArray
			the (optionally reduced) gradient, moved to host memory
		"""
		penalty_term = torch.zeros(1, requires_grad=True)
		# compute_embedding_gradient converts CArray input itself, so plain
		# ndarrays are accepted here as well.
		gradient = self.compute_embedding_gradient(x, penalty_term)
		if sum_embedding:
			gradient = torch.mean(gradient, dim=1)
		# Always bring the result to host memory: CUDA *and* MPS tensors cannot
		# be converted to numpy directly (`.cpu()` is a no-op on CPU tensors).
		return CArray(gradient.detach().cpu().numpy())

	def load_pretrained_model(self, path: str = None):
		"""
		Load pretrained model weights

		Parameters
		----------
		path : str, optional, default None
			The path of the model, default is None, and it will load the internal default one
		"""
		if path is None:
			# The package-relative default checkpoint shipped with secml_malware.
			root = os.path.dirname(
				os.path.dirname(os.path.abspath(sys.modules["secml_malware"].__file__))
			)
			path = os.path.join(root, "secml_malware", "data", "trained", "pretrained_malconv.pth")
		self._model.load_simplified_model(path)
		self._classes = np.array([0, 1])
		# NOTE(review): hard-coded to the MalConv input size; presumably should
		# track the loaded model's max_input_size — confirm before changing.
		self._n_features = 2 ** 20

	def get_embedding_size(self):
		"""
		Get the embedding space dimensionality

		Returns
		-------
		int
			the dimensionality of the embedding space
		"""
		return self._model.embedding_size

	def get_input_max_length(self):
		"""
		Get the input window length

		Returns
		-------
		int
			the window input length
		"""
		return self._model.max_input_size

	def get_embedding_value(self):
		"""
		Get the value used as padding

		Returns
		-------
		int
			a value that is used for padding the sample
		"""
		return self._model.embedding_value

	def get_is_shifting_values(self):
		"""
		Get if the model shifts the values by one

		Returns
		-------
		bool
			return if the values are shifted by one
		"""
		return self._model.shift_values

	def embed(self, x: CArray, transpose: bool = True):
		"""
		Embed the sample inside the embedding space

		Parameters
		----------
		x : CArray
			the sample to embed
		transpose : bool, optional, default True
			set True to return the transposed feature space vector
		Returns
		-------
		torch.Tensor
			the embedded vector
		"""
		return self._model.embed(x, transpose=transpose)

	def compute_embedding_gradient(self, x: CArray, penalty_term: torch.Tensor):
		"""Compute the gradient of the model output w.r.t. the embedding layer.

		Parameters
		----------
		x : CArray or numpy.ndarray
			point where gradient will be computed
		penalty_term : torch.Tensor or float
			value added to the model output before differentiating; tensors are
			moved to the output's device first

		Returns
		----------
		torch.Tensor
			gradient w.r.t. the embedding of the FIRST sample in the batch, with
			embedding axes 1 and 2 swapped (presumably (length, embedding_size)
			given `transpose=True` in `embed` — confirm against the model)
		"""
		data = x.tondarray() if isinstance(x, CArray) else x
		emb_x = self.embed(data)
		y = self._model.embedd_and_forward(emb_x)
		if isinstance(penalty_term, torch.Tensor):
			# A 1-element CPU tensor cannot be added to a CUDA/MPS tensor; keep
			# both operands on the model output's device.
			penalty_term = penalty_term.to(y.device)
		output = y + penalty_term
		# autograd.grad is called without grad_outputs, so `output` must hold a
		# single element (i.e. one sample per call).
		g = torch.autograd.grad(output, emb_x)[0]
		return torch.transpose(g, 1, 2)[0]

	def embedding_predict(self, x):
		"""
		Produce a prediction from an already-embedded sample.

		Parameters
		----------
		x : torch.Tensor
			the embedded sample, as returned by `embed`

		Returns
		-------
		torch.Tensor
			the model output (malware score), still attached to the autograd graph
		"""
		return self._model.embedd_and_forward(x)

	def _forward(self, x):
		"""Return two-column confidences `[1 - p, p]` per sample.

		`p` is column 1 of the parent's scores for each row of x.
		"""
		x = x.atleast_2d()
		scores = super(CClassifierEnd2EndMalware, self)._forward(x)
		malware_scores = [scores[i, 1].item() for i in range(x.shape[0])]
		return CArray([[1 - p, p] for p in malware_scores])
