"""
Layer and loss definitions for RSC.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
import pdb

epsilon = 1.0e-12  # tiny positive constant used to keep denominators away from zero

class RSC(nn.Module):
	"""Per-class quadratic-form classification layer.

	For every class k the layer holds a factor matrix B_k of shape
	(n_input_features, n_rank) and a shift vector s_k. The shift vector is
	learnable when ``use_shift_vectors`` is True, otherwise it is fixed at zero.
	The affinity of an input x for class k is the quadratic form
	(x - s_k)^T B_k B_k^T (x - s_k). It is non-negative because B_k B_k^T is
	positive semi-definite. Affinities are L1-normalised across classes to give
	per-sample class probabilities.

	Args:
		n_input_features: size D of each input feature vector.
		n_classes: number of classes K.
		n_rank: inner dimension of each B_k. Falls back to n_input_features
			when None (or any falsy value).
		bias: accepted for interface compatibility only; the layer has no bias
			(the ``bias`` parameter is registered as None).
		use_shift_vectors: if True, learn one shift vector per class,
			initialised uniformly in [-1, 1]. Otherwise use fixed zeros.
	"""

	def __init__(self, n_input_features, n_classes, n_rank=None, bias=False,use_shift_vectors=False):
		super(RSC, self).__init__()

		self.n_input_features = n_input_features
		self.n_rank = n_rank or self.n_input_features
		print("RSC rank: ", self.n_rank)

		self.n_classes = n_classes
		self.use_shift_vectors = use_shift_vectors
		if self.use_shift_vectors:
			# (K, 1, D): one learnable shift vector per class.
			self.shift_vectors = nn.Parameter(torch.Tensor(self.n_classes,1,self.n_input_features))
			self.shift_vectors.data.uniform_(-1., 1.)
		else:
			# Fixed zero shifts. This is a plain tensor (neither parameter nor
			# buffer), so .to()/.cuda() on the module do not move it;
			# forward() moves it to the input's device instead.
			self.shift_vectors = torch.zeros((self.n_classes,1,self.n_input_features),requires_grad=False)

		# Random projection matrix (K, D) intended to project the shift vectors
		# into an R^n_classes sub-space. This layer only returns it from
		# forward(); it does not use it itself.
		# NOTE(review): P is a plain tensor, so it is not stored in state_dict.
		# A model rebuilt from a checkpoint gets a different random P.
		self.P = torch.randn((n_classes, n_input_features), requires_grad=False)

		# (K, D, n_rank) factors; B_k B_k^T is the PSD matrix for class k.
		self.B_s = nn.Parameter(torch.Tensor(self.n_classes,self.n_input_features,self.n_rank))
		self.register_parameter('bias', None)

		# Simple uniform initialisation; perhaps not the best choice.
		self.B_s.data.uniform_(-.1,.1)

	def forward(self, x):
		"""Compute class probabilities for a batch of feature vectors.

		Args:
			x: input features. Presumably shaped (N, D); it must broadcast
				against the (K, 1, D) shift vectors.

		Returns:
			A list ``[probs, shift_vectors_or_None, B_s, P, x]``. ``probs`` has
			shape (N, K), and each row is the L1-normalised vector of class
			affinities. The second entry is the learnable shift vectors when
			``use_shift_vectors`` is set, otherwise None.
		"""
		x_in = x

		# One PSD matrix B_k B_k^T per class: (K, D, D).
		bbt = torch.bmm(self.B_s, self.B_s.permute(0, 2, 1))

		# Keep the fixed zero shifts on the input's device, whatever that is.
		# The learnable Parameter is moved with the module and must not be
		# reassigned here (assigning a plain tensor to a Parameter slot raises).
		if not self.use_shift_vectors and self.shift_vectors.device != x.device:
			self.shift_vectors = self.shift_vectors.to(x.device)

		# (K, N, D): slice k holds every input minus the shift vector of class k.
		centered = x - self.shift_vectors

		# (K, N, D): slice k holds B_k B_k^T (x_n - s_k) for every input n.
		bbt_xs = torch.bmm(bbt, centered.permute(0, 2, 1)).permute(0, 2, 1)

		# Sum over features to get the quadratic form (x_n - s_k)^T B_k B_k^T
		# (x_n - s_k), i.e. the affinity of input n for class k: (K, N).
		xs = torch.sum(centered * bbt_xs, dim=2)

		# Small offset in case an affinity is exactly 0 (e.g. x is all zeros).
		xs = xs + 1e-16

		# Normalise over classes (dim 0) with the L1 norm, then transpose so
		# probabilities run along dim 1: (N, K).
		xs = F.normalize(xs, dim=0, p=1).permute(1, 0)

		if self.use_shift_vectors:
			return [xs, self.shift_vectors, self.B_s, self.P, x_in]
		return [xs, None, self.B_s, self.P, x_in]





class rsc_loss(_Loss):
	"""Loss for the RSC layer.

	The loss is the NLL classification loss plus an optional regulariser,
	-kappa * logdet(covariance of the projected shift vectors).

	Args:
		model: the network being trained. Only its ``training`` flag is read,
			to label the diagnostics printed on every call.
		kappa: weight of the log-determinant regulariser; 0 disables it.
	"""

	def __init__(self,model,kappa=0.):
		print("using RSC loss function\n")
		super(rsc_loss,self).__init__()
		self.model = model
		self.kappa = kappa
		# Not used by __call__; kept so existing attribute access keeps working.
		self.d_epsilon = torch.tensor(1e-15)

	def _shift_vector_log_det_loss(self, shift_vectors, proj_matrix):
		"""Return -kappa * logdet(cov of projected shift vectors), or None.

		None means "no regulariser". That is the case when there are no
		learnable shift vectors, when kappa == 0 (this also avoids 0 * NaN), or
		when the log-determinant is not finite.
		"""
		if shift_vectors is None or self.kappa == 0:
			return None

		n_classes = shift_vectors.shape[0]
		# (K, 1, D) -> (D, K): column k is the shift vector of class k.
		# A transpose is required here. A plain .view(-1, K) would scramble
		# entries across classes.
		sv_matrix = shift_vectors.reshape(n_classes, -1).t()

		proj_matrix = proj_matrix.to(device=sv_matrix.device, dtype=sv_matrix.dtype)
		sv_proj = torch.matmul(proj_matrix, sv_matrix)  # (K, K) projected shift vectors

		# Subtract the mean column vector, then form the (K, K) covariance.
		sv_centered = sv_proj - sv_proj.mean(1).view(-1, 1)
		sv_cov = torch.matmul(sv_centered, sv_centered.t())

		# NOTE(review): centring K columns in K dimensions leaves rank <= K-1,
		# so sv_cov is singular in exact arithmetic. Its logdet is dominated by
		# round-off and can be -inf or NaN, hence the finiteness guard below.
		sv_cov_log_det = torch.logdet(sv_cov)
		if not bool(torch.isfinite(sv_cov_log_det)):
			return None
		return -self.kappa * sv_cov_log_det

	def __call__(self, input_batch, target_batch, epoch):
		"""Compute the loss for one batch and print a diagnostics line.

		Args:
			input_batch: the list returned by ``RSC.forward``. Only index 0
				(probabilities), 1 (shift vectors or None) and 3 (projection
				matrix) are read.
			target_batch: class-index targets, as expected by ``F.nll_loss``.
			epoch: unused; kept for interface compatibility.

		Returns:
			Scalar tensor: classification loss plus the shift-vector
			log-determinant regulariser (0 when disabled or non-finite).
		"""
		p_out = input_batch[0]
		shift_vectors = input_batch[1]
		proj_matrix = input_batch[3]

		# Mean prediction entropy; used only for the printed diagnostics.
		h_p = (-p_out * torch.log(p_out)).sum(1).mean()

		# Classification loss on log-probabilities.
		c_loss = F.nll_loss(torch.log(p_out), target_batch)

		d_log_loss = self._shift_vector_log_det_loss(shift_vectors, proj_matrix)
		if d_log_loss is None:
			# Created on c_loss's device so the sum below does not mix devices.
			d_log_loss = torch.tensor(0., requires_grad=False).to(c_loss.device)

		if self.model.training:
			print("\nrsc train loss details: %f,%f,%f\n" %(c_loss.mean(),h_p,d_log_loss))
		else:
			print("\nrsc validation loss details: %f,%f,%f\n" %(c_loss.mean(),h_p,d_log_loss))

		return c_loss + d_log_loss










# class RSC2(nn.Module):
# 	def __init__(self, input_features, output_features, bias=False):
# 		super(RSC2, self).__init__()
# 		#pdb.set_trace()
# 		self.input_features = input_features
# 		self.output_features = output_features


# 		# nn.Parameter is a special kind of Tensor, that will get
# 		# automatically registered as Module's parameter once it's assigned
# 		# as an attribute. Parameters and buffers need to be registered, or
# 		# they won't appear in .parameters() (doesn't apply to buffers), and
# 		# won't be converted when e.g. .cuda() is called. You can use
# 		# .register_buffer() to register buffers.
# 		# nn.Parameters require gradients by default.
# 		#self.n_classes = input_features
# 		self.n_classes = self.output_features

# 		#self.shift_vectors = nn.Parameter(torch.Tensor(self.n_classes,1,self.n_classes))
# 		#self.shift_vectors = torch.zeros((self.n_classes,1,self.input_features),requires_grad=False)
# 		#self.B_s = nn.Parameter(torch.Tensor(self.n_classes,self.n_classes,self.n_classes))
# 		self.B_s = nn.Parameter(torch.Tensor(self.n_classes,self.input_features,1))

# 		self.register_parameter('bias', None)

# 		# Not a very smart way to initialize weights
# 		# self.shift_vectors.data.uniform_(-1., 1.)
# 		self.B_s.data.uniform_(-.1,.1)

# 	def forward(self, x):
# 		#pdb.set_trace()
# 		# x is a tensor consisting of feature activations
# 		n_classes = self.n_classes

# 		# the operation below ensures that each matrix in bbt is PSD
# 		bbt = torch.bmm(self.B_s,self.B_s.permute(0,2,1))

# 		# (x - self.shift_vectors) gives a (KxNxK) tensor where K is the number of classes. 
# 		# the matrix in the outer dimension m is the result of subtracting the x's from the 
# 		#shift vector of the m'th class 

# 		if x.device.type == 'cuda':
# 			shift_vectors = torch.zeros((self.n_classes,1,self.input_features),requires_grad=False).to(x.device)
# 		else:
# 			shift_vectors = torch.zeros((self.n_classes,1,self.input_features),requires_grad=False)

# 		bbt_xs = torch.bmm(bbt, (x-shift_vectors).permute(0,2,1))
# 		bbt_xs = bbt_xs.permute(0,2,1)

# 		xs = (x-shift_vectors)*bbt_xs



# 		#sum across columns  (i.e for a given depth and row, corresponding to class k and input i,
# 		# to get the "affinity" that input i has for class k) 
# 		xs = torch.sum(xs,dim=2) 
# 		#in the above, we lose a dimension, with class affinities for an input spraid across the 0th
# 		#dimension. So we normalize across the 0'th dimension.
# 		#and then transpose since  probabilities are expected to be along dimension 1.
# 		#p = 1, since we are normalizing on L-1 norm.
# 		xs = F.normalize(xs,dim=0,p=1).permute(1,0)
# 		#xs = torch.log(xs) # log probabilities. will be passed to NLL loss.
# 		#return [xs,self.shift_vectors]
# 		return [xs, self.B_s]









# class rsc_loss_2(_Loss):
# 	def __init__(self,model,kappa=0.):
# 		print("using RSC loss function 2\n")
# 		super(rsc_loss_2,self).__init__()
# 		self.model = model
# 		#self.kappa = kappa
# 		#self.d_epsilon = torch.tensor(1e-15)

# 		#pdb.set_trace()

# 	def __call__(self, input_batch, target_batch):
# 		"""
# 		input_batch is the output of the NN
# 		target_batch are the labels
# 		"""
# 		#pdb.set_trace()
# 		p_out  = input_batch[0]
# 		#shift_vectors = input_batch[1]
# 		B_s = input_batch[1]
# 		#entropy:
# 		h_p = (-p_out * torch.log(p_out)).sum(1).mean()
# 		#shift vector determinant
# 		# d,r,c = shift_vectors.shape
# 		# sv_matrix = shift_vectors.view(-1,c)
# 		# sv_matrix_normalized  = sv_matrix - sv_matrix.mean(1) #subtract mean column vector
# 		# sv_cov = torch.matmul(sv_matrix_normalized, sv_matrix_normalized.permute(1,0)) #covariance of normalized shift vectors
# 		# #sv_det = torch.det(shift_vectors.view(-1,c))
# 		# sv_cov_det = torch.det(sv_cov)
		
# 		# #c_loss is the classification loss
# 		c_loss = F.nll_loss(torch.log(p_out), target_batch)
		
# 		# #d_loss is the determininant loss
# 		# #d_loss = self.kappa * 1./(sv_det + epsilon)  #epsilon to prevent division by 0.
		
# 		# #pdb.set_trace()
# 		# #if sv_cov_det==torch.tensor(float("Inf")).cuda():
# 		# #is there a better way to compare? the one above gives "variables located on different CUDA device" errors
# 		# if sv_cov_det.data.cpu().numpy() == torch.tensor(float("Inf")).numpy():
# 		# 	#d_loss = torch.tensor(0., requires_grad=False).cuda()
# 		# 	# this looks overly complicated for something as simple as creating a 0. tensor 
# 		# 	# but we need the data to be on the GPU, and also on the same GPU as sv_cov_det, otherwise
# 		# 	# pytorch complains of "arguments are located on different GPUs"
# 		# 	d_loss = torch.tensor(0., requires_grad=False).to(sv_cov_det.device)
# 		# else:
# 		# 	d_loss = self.kappa * 1./(sv_cov_det + epsilon)
# 		# 	#pdb.set_trace()
# 		# 	#d_loss = self.kappa * 1./torch.log(1.+sv_cov_det)
		

# 		# #d_loss = torch.max(d_loss, self.d_epsilon.cuda())

# 		# if self.model.training:
# 		# 	#pass
# 		# 	print("\nrsc train loss details: %f,%f,%f,%f\n" %(c_loss.mean(),d_loss,h_p,sv_cov_det))
# 		# else:
# 		# 	#pass
# 		# 	print("\nrsc validation loss details: %f,%f,%f,%f\n" %(c_loss.mean(),d_loss,h_p,sv_cov_det))
		
# 		#return c_loss + d_loss
# 		return c_loss