from .soft_dtw import SoftDTW  
import torch.nn as nn 
import torch 

def get_mask_from_len(x_len, y_len, max_len):
	"""Build a boolean padding mask over a square max_len x max_len grid.

	Entry [b, r, c] is True when r >= x_len[b] or c >= y_len[b], i.e. when
	the cell lies outside the valid x_len[b] x y_len[b] region of sample b.

	Args:
		x_len: 1-D tensor of per-sample lengths along the row axis, shape [B].
		y_len: 1-D tensor of per-sample lengths along the column axis, shape [B].
		max_len: side length used for both grid axes.

	Returns:
		Bool tensor of shape [B, max_len, max_len] on x_len's device.
	"""
	device = x_len.device
	# One index vector per axis; broadcasting against the [B, 1, 1] lengths
	# yields the full [B, max_len, max_len] grid.
	row_idx = torch.arange(max_len, device=device).view(1, max_len, 1)
	col_idx = torch.arange(max_len, device=device).view(1, 1, max_len)
	row_is_pad = row_idx >= x_len.unsqueeze(-1).unsqueeze(-1)
	col_is_pad = col_idx >= y_len.unsqueeze(-1).unsqueeze(-1)

	return row_is_pad | col_is_pad

def l1_dist(x, y, x_len, y_len):
	"""Pairwise L1 distance between the frames of x and y, with padding zeroed.

	Args:
		x: tensor of shape [B, n, d].
		y: tensor of shape [B, m, d]; m may differ from n.
		x_len: 1-D tensor [B], number of valid frames of x per sample.
		y_len: 1-D tensor [B], number of valid frames of y per sample.

	Returns:
		A tuple (dist, mask):
		dist: [B, n, m] tensor, dist[b, i, j] = sum_k |x[b, i, k] - y[b, j, k]|,
			set to 0.0 wherever mask is True.
		mask: [B, n, m] bool tensor, True where i >= x_len[b] or j >= y_len[b].
	"""
	n = x.size(1)
	m = y.size(1)

	# [B, n, 1, d] - [B, 1, m, d] broadcasts to every (i, j) frame pair.
	dist = torch.abs(x.unsqueeze(2) - y.unsqueeze(1)).sum(3)

	# The mask is built per axis so that it is [B, n, m] even when n != m;
	# a square mask sized only by n cannot be applied to a [B, n, m] tensor.
	device = x_len.device
	row_idx = torch.arange(n, device=device).view(1, n, 1)
	col_idx = torch.arange(m, device=device).view(1, 1, m)
	row_is_pad = row_idx >= x_len.unsqueeze(-1).unsqueeze(-1)
	col_is_pad = col_idx >= y_len.unsqueeze(-1).unsqueeze(-1)
	mask = row_is_pad | col_is_pad

	return dist.masked_fill(mask, 0.0), mask

def similarity(s1, s2, sdtw, nc_term ,x_len, y_len, sweight):
	"""Turn a length-normalised soft-DTW value into a similarity score.

	Computes exp(-sweight * sdtw(s1, s2, x_len, y_len) / nc_term), where any
	zero entry of nc_term is replaced by 2 so the division is always defined.

	Args:
		s1, s2: the two sequence batches, passed straight through to sdtw.
		sdtw: callable invoked as sdtw(s1, s2, x_len, y_len).
		nc_term: 1-D tensor of per-sample normalisation terms; it is not
			modified by this function.
		x_len, y_len: per-sample lengths, passed straight through to sdtw.
		sweight: scalar weight applied inside the exponential.

	Returns:
		Tensor holding exp(-sweight * normalised soft-DTW) per sample.
	"""
	# Vectorised and out-of-place: avoids a Python loop over the batch (one
	# host sync per element on GPU) and leaves the caller's tensor untouched.
	safe_nc_term = nc_term.masked_fill(nc_term == 0, 2)
	normalised = sdtw(s1, s2, x_len, y_len) / safe_nc_term

	return torch.exp(-sweight * normalised)

def dpp_kernel_build(quality_vector, prosody_vector, d_len, mask_idxs, sweight=0.02):
	'''Assemble the batched kernel L[b, i, j] = q[b, i] * S[b, i, j] * q[b, j].

	S[b, i, j] is the soft-DTW based similarity between prosody sequences i
	and j of sample b, normalised by d_len[i, b] + d_len[j, b].

	quality_vector.shape = [b, 2+nc, 1]     (only index 0 of the last axis is read)
	prosody_vector.shape = [b, 2+nc, t, 1]
	d_len.shape          = [2+nc, b]        (valid length of every sequence)
	mask_idxs : iterable of batch indices whose first kernel row and first
	            kernel column are zeroed out
	sweight   : the weight for similarity factors (forwarded to similarity)

	returns (kernel, kernel_mask), both of shape [b, 2+nc, 2+nc]; kernel_mask
	is True exactly at the entries of kernel that were zeroed.'''

	batch_size, n_items = prosody_vector.shape[:2]
	kernel = torch.zeros(
		batch_size, n_items, n_items,
		device=prosody_vector.device, dtype=prosody_vector.dtype,
	)
	# NOTE: the CUDA path of SoftDTW is requested unconditionally.
	sdtw = SoftDTW(use_cuda=True, gamma=1.0, dist_func=l1_dist)

	for row in range(n_items):
		# Everything that depends only on the row item is looked up once.
		row_seq = prosody_vector[:, row, :, :]      # [B, pad_len, 1]
		row_len = d_len[row, :]
		row_quality = quality_vector[:, row, 0]
		for col in range(n_items):
			col_seq = prosody_vector[:, col, :, :]
			col_len = d_len[col, :]
			# Fresh tensor per pair: total valid length used for normalisation.
			pair_len = row_len + col_len
			pair_sim = similarity(row_seq, col_seq, sdtw, pair_len, row_len, col_len, sweight)
			kernel[:, row, col] = row_quality * pair_sim * quality_vector[:, col, 0]

	kernel_mask = torch.zeros_like(kernel, dtype=torch.bool)
	for batch_idx in mask_idxs:
		# Flag the whole first row and the whole first column of this sample.
		kernel_mask[batch_idx, :1, :] = True
		kernel_mask[batch_idx, :, :1] = True

	return kernel.masked_fill(kernel_mask, 0.0), kernel_mask


# Script entry point. The guarded body consists solely of two bare
# triple-quoted string expressions, so running this module directly executes
# no checks: the strings hold disabled manual sanity-check code kept for
# reference.
if __name__ == '__main__':
	# Disabled check 1: builds random CUDA inputs (with one zero length in
	# d_len) and calls dpp_kernel_build on them.
	'''
	print("---Sanity check---")
	b , num_can, pad_len = 10, 4, 10 
	quality_vector = torch.randn(b, 2+num_can,1).cuda() 
	quality_vector.requires_grad = True 
	prosody_vector = torch.randn(b, 2+num_can, pad_len, 1).cuda()
	prosody_vector.requires_grad = True 
	d_len = torch.randint(1,6,(num_can+2,b)).cuda()
	mask_idxs = [2,5]
	d_len[0,2] = 0      

	kernel, kernel_mask = dpp_kernel_build(quality_vector, prosody_vector, d_len, mask_idxs)
	print("---Sanity check done!---")
	'''
	# Disabled check 2: runs SoftDTW on a zero-padded batch and `v_SoftDTW` on
	# the corresponding unpadded single sequences, back-propagates, and prints
	# both results. NOTE(review): `v_SoftDTW` is not brought into scope by the
	# import lines visible at the top of this file; it would need importing
	# before this code could be re-enabled.
	'''
	a = [ [2,3,4,0,0], [1,0,0,0,0]]
	b = [ [-1,2,3,0,0], [5,-2,1,0,0]]

	a_ = [[2,3,4], [1]]
	b_ = [[-1,2,3], [5,-2,1]]


	a = torch.Tensor(a).unsqueeze(-1).cuda()
	b = torch.Tensor(b).unsqueeze(-1).cuda()
	va = torch.Tensor(a_[0]).unsqueeze(0).unsqueeze(-1).cuda()
	vb = torch.Tensor(b_[0]).unsqueeze(0).unsqueeze(-1).cuda()
	va_2 = torch.Tensor(a_[1]).unsqueeze(0).unsqueeze(-1).cuda()
	vb_2 = torch.Tensor(b_[1]).unsqueeze(0).unsqueeze(-1).cuda()

	a.requires_grad = True 
	b.requires_grad = True 
	va.requires_grad = True  
	vb.requires_grad = True 
	va_2.requires_grad = True 
	vb_2.requires_grad = True 

	a_len = torch.Tensor([3,1]).cuda()
	b_len = torch.Tensor([3,3]).cuda()
	
	a1_len = torch.Tensor([3]).cuda()
	a2_len = torch.Tensor([1]).cuda()
	b1_len = torch.Tensor([3]).cuda()
	b2_len = torch.Tensor([3]).cuda()

	sdtw = SoftDTW(use_cuda=True, gamma=0.1, dist_func=l1_dist)
	v_sdtw = v_SoftDTW(use_cuda=True, gamma=0.1, dist_func=l1_dist)
	
	dis = sdtw(a, b, a_len, b_len)
	v_dis1 = v_sdtw(va, vb, a1_len, b1_len)
	v_dis2 = v_sdtw(va_2, vb_2, a2_len, b2_len)

	dis = torch.sum(dis)
	dis.backward()
	v_dis1.backward()
	v_dis2.backward()
	print(dis)
	print(v_dis1, v_dis2)
'''