import torch.nn as nn 
import torch 
from torch.nn.functional import mse_loss  
import torch.nn.functional as F  

def is_psd(mat):
	'''Return True if every eigenvalue of ``mat`` has a non-negative real part.

	For a symmetric (Hermitian) matrix this is the positive-semidefinite test.
	Uses ``torch.linalg.eigvals`` (``torch.eig`` is no longer available in
	current PyTorch). A batch of matrices ``[..., n, n]`` is also accepted, in
	which case every matrix must pass. No tolerance is applied, so a zero
	eigenvalue computed as a tiny negative number yields False.
	'''
	return bool(torch.all(torch.linalg.eigvals(mat).real >= 0))

class custom_loss(torch.autograd.Function):
	'''Diversity loss ``mean_b sum_{i >= num_cw} (B_ii - 1)`` with a manual backward.

	``B = (K + I_A)^{-1}``, where ``K`` is the kernel with masked entries set
	to 0 and ``I_A`` has ones on the first ``num_cw`` diagonal entries (masked
	entries set to 1e-8). Masked entries of ``B`` are zeroed before use.
	'''

	@staticmethod
	def forward(ctx, kernel, kernel_mask, num_cw):
		'''Compute the loss.

		Args:
			kernel: tensor of shape [batch, num_vec, num_vec].
			kernel_mask: boolean mask usable with ``Tensor.masked_fill`` on
				``kernel``; True kernel entries are ignored (replaced by 0).
			num_cw: number of leading indices; only the diagonal of
				``B[num_cw:, num_cw:]`` enters the loss.

		Returns:
			0-dim tensor: batch mean of ``-trace(I - B[num_cw:, num_cw:])``.

		Raises:
			ValueError: if ``num_cw`` is not in ``[0, num_vec]``.
		'''
		num_vec = kernel.size(-1)
		if not 0 <= num_cw <= num_vec:
			raise ValueError(f"num_cw must be in [0, {num_vec}], got {num_cw}")
		# Ones on the first num_cw diagonal entries.
		# NOTE(review): ref_diversity_loss places the identity on the *last*
		# num_vec - num_cw diagonal entries instead; confirm which is intended.
		I_A = torch.zeros_like(kernel)
		I_A.diagonal(dim1=-2, dim2=-1)[..., :num_cw] = 1.0
		I_A = I_A.masked_fill(kernel_mask, 1e-8)
		kernel = kernel.masked_fill(kernel_mask, 0)
		B = torch.inverse(kernel + I_A)
		B = B.masked_fill(kernel_mask, 0)
		# (L+I_A)^{-1} and the mask (needed to zero gradients of ignored entries).
		ctx.save_for_backward(B, kernel_mask)
		ctx.cw = num_cw
		# -trace(I - B_sub) == sum_i (B_ii - 1) over the trailing block.
		trace = (B[:, num_cw:, num_cw:].diagonal(dim1=-2, dim2=-1) - 1).sum(-1)
		return trace.sum() / len(trace)

	@staticmethod
	def backward(ctx, grad_output):
		'''Return d(loss)/d(kernel); the mask and ``num_cw`` get no gradient.

		With ``E`` the diagonal selector of indices ``>= num_cw``,
		``d trace(E B) / dK = -(B E B)^T``. The forward averages over the batch,
		so the result is divided by the batch size. The gradient is built from
		the masked ``B`` saved in forward; with nothing masked it equals the
		exact gradient of the forward output.
		'''
		B, kernel_mask = ctx.saved_tensors
		num_cw = ctx.cw
		# B @ E @ B == B[:, :, cw:] @ B[:, cw:, :] because E keeps only indices >= cw.
		BEB = torch.bmm(B[:, :, num_cw:], B[:, num_cw:, :])
		grad = -BEB.transpose(-1, -2) / B.size(0)
		# Masked kernel entries are overwritten in forward, so their gradient is 0.
		grad = grad.masked_fill(kernel_mask, 0)
		return grad_output * grad, None, None

class diversity_loss(nn.Module):
	'''``nn.Module`` front-end for the ``custom_loss`` autograd Function.'''

	def forward(self, kernel, kernel_mask, num_cw=2):
		'''Apply ``custom_loss`` to ``kernel`` (shape [batch, num_vec, num_vec]).

		``kernel_mask`` and ``num_cw`` are forwarded unchanged; the result is
		whatever ``custom_loss.apply`` returns.
		'''
		apply_loss = custom_loss.apply
		return apply_loss(kernel, kernel_mask, num_cw)

class ref_diversity_loss(nn.Module):
	'''Plain-autograd version of the diversity loss.

	Computes ``B = (K + I_A)^{-1}``, where ``K`` is the kernel with masked
	entries set to 0 and ``I_A`` is the identity on the last
	``num_vec - num_cw`` diagonal entries (masked entries set to 1e-8). It
	returns the batch mean of ``-trace(I - B[num_cw:, num_cw:])``. Gradients
	come from autograd through ``torch.inverse``.
	'''

	def forward(self, kernel, kernel_mask, num_cw=2):
		'''Compute the loss.

		Args:
			kernel: tensor of shape [batch, num_vec, num_vec].
			kernel_mask: boolean mask usable with ``Tensor.masked_fill`` on
				``kernel``; True kernel entries are ignored (replaced by 0).
			num_cw: number of leading indices excluded from the identity and
				from the trace.

		Returns:
			0-dim tensor with the batch-mean loss.
		'''
		batch, num_vec = kernel.size(0), kernel.size(1)
		# Build I_A with the kernel's dtype and device.
		eye = torch.eye(num_vec - num_cw, dtype=kernel.dtype, device=kernel.device)
		# Pad left and top so the identity occupies the bottom-right block.
		I_A = F.pad(eye.expand(batch, -1, -1), (num_cw, 0, num_cw, 0))
		# NOTE(review): custom_loss puts its ones on the *first* num_cw entries
		# and also zeroes masked entries of B; confirm which variant is intended.
		I_A = I_A.masked_fill(kernel_mask, 1e-8)
		kernel = kernel.masked_fill(kernel_mask, 0)
		B = torch.inverse(kernel + I_A)
		# -trace(I - B_sub) == sum_i (B_ii - 1) over the trailing block.
		trace = (B[:, num_cw:, num_cw:].diagonal(dim1=-2, dim2=-1) - 1).sum(-1)
		return trace.sum() / len(trace)