import torch
import math
from utils.sh import SHEval



def hypot(x, y, z):
    """Return the Euclidean length sqrt(x**2 + y**2 + z**2).

    The components are combined elementwise with torch operations, so they
    are expected to be tensors (``torch.sqrt`` does not accept plain Python
    numbers).
    """
    sum_of_squares = x ** 2 + y ** 2 + z ** 2
    return torch.sqrt(sum_of_squares)




def fromRotation(rad, axis):
    """Build a rotation of ``rad`` radians about ``axis`` as a flat mat4.

    Follows gl-matrix's ``mat4.fromRotation``: the result is a 16-element
    tensor in column-major order, i.e. ``out[4 * c + r]`` is row ``r``,
    column ``c`` of the 4x4 matrix. The translation part is zero and
    ``out[15]`` is 1.

    Args:
        rad: rotation angle in radians (tensor or Python number).
        axis: 3-element rotation axis (tensor or sequence). It is normalized
            internally; the caller's object is NOT modified.

    Returns:
        torch.Tensor of shape (16,) created with ``torch.zeros([16])``.

    Raises:
        ValueError: if ``axis`` has (numerically) zero length, where the
            rotation is undefined.
    """
    epsilon = 1e-6  # same threshold as gl-matrix's glMatrix.EPSILON

    # as_tensor lets callers pass Python numbers / lists as well as tensors.
    axis = torch.as_tensor(axis)
    rad = torch.as_tensor(rad)

    # Indexing a tensor yields views into its storage, so only out-of-place
    # arithmetic is used below to keep the caller's axis untouched.
    x, y, z = axis[0], axis[1], axis[2]
    length = torch.sqrt(x ** 2 + y ** 2 + z ** 2)
    if length < epsilon:
        raise ValueError("fromRotation: rotation axis must have non-zero length")

    inv_length = 1 / length
    x = x * inv_length
    y = y * inv_length
    z = z * inv_length

    s = torch.sin(rad)
    c = torch.cos(rad)
    t = 1 - c

    out = torch.zeros([16])  # entries not written below stay 0
    # Column 0
    out[0] = x * x * t + c
    out[1] = y * x * t + z * s
    out[2] = z * x * t - y * s
    # Column 1
    out[4] = x * y * t - z * s
    out[5] = y * y * t + c
    out[6] = z * y * t + x * s
    # Column 2
    out[8] = x * z * t + y * s
    out[9] = y * z * t - x * s
    out[10] = z * z * t + c
    # Column 3: no translation, homogeneous 1
    out[15] = 1
    return out




def getRotationPrecomputeL(precompute_L, rotationMatrix):
    """Rotate precomputed SH lighting coefficients by ``rotationMatrix``.

    The SH rotation matrices for coefficients 1..3 (3x3) and 4..8 (5x5) are
    built from the inverse of ``rotationMatrix``. Coefficient 0 is copied
    unchanged.

    Args:
        precompute_L: tensor indexed as ``[coefficient, channel]``; rows 0..8
            and columns 0..2 are read (presumably 9 SH coefficients x RGB —
            confirm against callers). Must match the dtype of the SH rotation
            matrices for ``torch.mm``.
        rotationMatrix: 4x4 tensor; it is inverted with ``torch.inverse``.

    Returns:
        New tensor of shape (9, 3) from ``torch.zeros([9, 3])`` holding the
        rotated coefficients.
    """
    inverse_rotation = torch.inverse(rotationMatrix)
    band1_rotation = computeSquareMatrix_3by3(inverse_rotation)
    band2_rotation = computeSquareMatrix_5by5(inverse_rotation)

    rotated = torch.zeros([9, 3])
    for channel in range(3):
        # Each band's coefficients form a 1xN row vector that is
        # right-multiplied by its SH rotation matrix.
        band1 = precompute_L[1:4, channel].unsqueeze(0)
        band2 = precompute_L[4:9, channel].unsqueeze(0)
        band1_rotated = torch.mm(band1, band1_rotation)
        band2_rotated = torch.mm(band2, band2_rotation)

        rotated[0][channel] = precompute_L[0][channel]
        rotated[1:4, channel] = band1_rotated
        rotated[4:9, channel] = band2_rotated

    return rotated


def computeSquareMatrix_3by3(rotationMatrix):
    """Return the 3x3 matrix S @ A^-1 acting on SH coefficients 1..3.

    Three fixed directions are projected with ``SHEval``; their values at
    indices 1..3 form the columns of ``A``. The same directions, rotated by
    ``rotationMatrix``, give the columns of ``S``.

    Args:
        rotationMatrix: 4x4 float32 tensor; directions are multiplied as
            homogeneous column vectors with w = 0.

    Returns:
        3x3 tensor ``torch.mm(S, torch.inverse(A))``.
    """
    # A is inverted below, so these directions must give linearly
    # independent values at SH indices 1..3.
    directions = ([1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0])
    band = range(1, 4)

    # A: row k holds SH coefficient k of every direction (one column each).
    # NOTE(review): assumes SHEval(x, y, z, 3) returns an indexable with at
    # least 4 entries — confirm in utils.sh.
    base_sh = [SHEval(d[0], d[1], d[2], 3) for d in directions]
    A_inverse = torch.inverse(torch.tensor([[sh[k] for sh in base_sh] for k in band]))

    # S: same layout, for each direction after applying the rotation.
    rotated = [
        torch.mm(rotationMatrix, torch.tensor(d, dtype=torch.float32).unsqueeze(1))
        for d in directions
    ]
    rotated_sh = [SHEval(v[0].item(), v[1].item(), v[2].item(), 3) for v in rotated]
    S = torch.tensor([[sh[k] for sh in rotated_sh] for k in band])

    return torch.mm(S, A_inverse)



def computeSquareMatrix_5by5(rotationMatrix):
    """Return the 5x5 matrix S @ A^-1 acting on SH coefficients 4..8.

    Five fixed directions are projected with ``SHEval``; their values at
    indices 4..8 form the columns of ``A``. The same directions, rotated by
    ``rotationMatrix``, give the columns of ``S``.

    Args:
        rotationMatrix: 4x4 float32 tensor; directions are multiplied as
            homogeneous column vectors with w = 0.

    Returns:
        5x5 tensor ``torch.mm(S, torch.inverse(A))``.
    """
    inv_sqrt2 = 1 / math.sqrt(2)
    # A is inverted below, so these directions must give linearly
    # independent values at SH indices 4..8.
    directions = (
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [inv_sqrt2, inv_sqrt2, 0, 0],
        [inv_sqrt2, 0, inv_sqrt2, 0],
        [0, inv_sqrt2, inv_sqrt2, 0],
    )
    band = range(4, 9)

    # A: row k holds SH coefficient k of every direction (one column each).
    # NOTE(review): assumes SHEval(x, y, z, 3) returns an indexable with at
    # least 9 entries — confirm in utils.sh.
    base_sh = [SHEval(d[0], d[1], d[2], 3) for d in directions]
    A_inverse = torch.inverse(torch.tensor([[sh[k] for sh in base_sh] for k in band]))

    # S: same layout, for each direction after applying the rotation.
    rotated = [
        torch.mm(rotationMatrix, torch.tensor(d, dtype=torch.float32).unsqueeze(1))
        for d in directions
    ]
    rotated_sh = [SHEval(v[0].item(), v[1].item(), v[2].item(), 3) for v in rotated]
    S = torch.tensor([[sh[k] for sh in rotated_sh] for k in band])

    return torch.mm(S, A_inverse)


def mat4Matrix2mathMatrix(rotationMatrix):
    """Unpack a flat 16-element column-major mat4 into a 4x4 tensor.

    Element ``rotationMatrix[4 * c + r]`` ends up at row ``r``, column ``c``
    of the result. Values are copied into a ``torch.zeros([4, 4])`` buffer,
    so the result has that buffer's dtype regardless of the input's.
    Any indexable of length >= 16 (tensor or list) is accepted.
    """
    grid = torch.zeros([4, 4])
    for flat_index in range(16):
        row, col = divmod(flat_index, 4)
        grid[row, col] = rotationMatrix[flat_index]
    # The flat layout is column-major, so the row-major fill is transposed.
    return grid.t()

if __name__ == '__main__':
    # Smoke run: build a 0.5 rad rotation about +Y as a flat mat4 and
    # unpack it into a 4x4 tensor. Nothing is printed.
    angle = torch.tensor(0.5)
    y_axis = torch.tensor([0, 1, 0], dtype=torch.float32)
    r = mat4Matrix2mathMatrix(fromRotation(angle, y_axis))
	# res = computeSquareMatrix_3by3(r)
# 	light_shs=torch.tensor([[0.84866, -0.578497, 0.0247782, -0.0661633, 0.051255, -0.00724286, 0.0207288, -0.00976639, -0.0628752],
# [1.19363, -0.771226, 0.0366317, -0.0985846, 0.0794252, -0.0100492, 0.0265743, -0.0171396, -0.076629],
# [1.55631, -1.16267, 0.0444034, -0.142632, 0.113748, -0.0212884, -0.00654665, -0.0309487, -0.144114]])  
	
# light_shs=torch.tensor([[ 0.8487,   1.1936 ,  1.5563],
# [-0.5785,  -0.7712 , -1.1627],
# [-0.0423,  -0.0632,  -0.0960],
# [-0.0566,  -0.0841,  -0.1144],
# [ 0.0338,   0.0514,   0.0794],
# [ 0.0392,   0.0614,   0.0842],
# [-0.0475,  -0.0621,  -0.1123],
# [-0.0408,  -0.0486,  -0.0475],
# [-0.0235,  -0.0254,  -0.0830]])
# # l = light_shs.transpose(0,1) 
# 	# print(l)
# print(getRotationPrecomputeL(light_shs, r))
        # [[ 1.6474,  1.2585,  0.7744],
        #  [-0.2088, -0.2499, -0.2270],
        #  [-0.0356, -0.0149,  0.0050],
        #  [ 0.2138,  0.2172,  0.2092],
        #  [-0.0285, -0.0473, -0.0561],
        #  [ 0.1723,  0.1651,  0.1214],
        #  [ 0.0130,  0.0104, -0.0360],
        #  [-0.0608, -0.0714, -0.0678],
        #  [-0.0383,  0.0157,  0.0399]]
# pq = torch.zeros([9,3])
# print(pq)

# [0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.]

