import torch
import torch.nn as nn
import numpy as np
import tools
import math
from pointnets import PointNet_cls, PointNet2_cls,PointNet2_seg,PointNet_seg, PointTransformer_cls, PointTransformer_seg

class expJ(torch.autograd.Function):
    '''
    Batched exponential map from an axis-angle vector to a 3x3 matrix
    (Rodrigues' formula: I + sin(t) K + (1 - cos(t)) K @ K).

    Only ``forward`` is implemented, so this Function has no gradient of its
    own; in the visible code it is only applied inside ``RPMG.backward``.
    '''

    @staticmethod
    def forward(ctx, w):
        '''

        :param w: [b,3,3,3]; ``w[:, :, i, j]`` is expected to hold the same
            axis-angle vector for every (i, j) -- callers build it with
            ``v.unsqueeze(2).unsqueeze(3).repeat(1, 1, 3, 3)``.
        :return: exp{J(w)}: [b,3,3] = I + sin(t) K + (1 - cos(t)) K @ K, with
            t = |w| and K the skew (cross-product) matrix of w / (|w| + 0.001).
        '''
        ctx.save_for_backward(w)
        b = w.shape[0]
        # Rotation angle, repeated over the trailing (i, j) slots: [b,3,3].
        theta = w.norm(dim=1)
        # Approximately unit axis; the +0.001 keeps w == 0 finite, which then
        # yields exactly the identity (sin(0) = 0, 1 - cos(0) = 0).
        wnorm = w / (w.norm(dim=1, keepdim=True) + 0.001)
        # Allocate on the input's device. ``Tensor.get_device()`` returns -1 for
        # CPU tensors, which is not a valid ``device`` argument.
        device = w.device
        I = torch.eye(3, device=device).repeat(b, 1, 1)
        # basis[:, k] is the cross-product matrix of the k-th unit vector.
        basis = torch.zeros((b, 3, 3, 3), device=device)
        basis[:, 0, 1, 2] = -1
        basis[:, 0, 2, 1] = 1
        basis[:, 1, 0, 2] = 1
        basis[:, 1, 2, 0] = -1
        basis[:, 2, 0, 1] = -1
        basis[:, 2, 1, 0] = 1
        # K = sum_k wnorm_k * basis_k, i.e. the skew matrix of wnorm: [b,3,3].
        Jwnorm = (basis * wnorm).sum(dim=1)

        return I + torch.sin(theta) * Jwnorm + (1 - torch.cos(theta)) * torch.bmm(Jwnorm, Jwnorm)

# Module-wide logging sink read by RPMG.backward through ``logger.add_scalar``.
# Holds the placeholder 0 until logger_init() installs a real object.
logger = 0


def logger_init(ll):
    '''Install ``ll`` as this module's logger, then announce it on stdout.

    :param ll: object exposing ``add_scalar(tag, value, step)``.
    '''
    global logger
    logger = ll
    print('logger init')

class RPMG(torch.autograd.Function):
    '''
    Rotation layer with a manifold-aware backward pass.

    ``forward(in_nd, lr, lam, rgt, iter)`` maps the raw prediction ``in_nd``
    (representation picked from ``in_nd.shape[1]``: 4 -> quaternion,
    6 -> ortho6d, 9 -> matrix via symmetric orthogonalization) to rotation
    matrices r0 [b, 3, 3] using ``tools``.

    ``backward`` does not back-propagate through that mapping. Instead it
      1. builds a goal rotation ``r_new``: ``rgt`` when ``lr == -1``, otherwise
         ``r0 @ exp(J(-lr * g))``, where ``g`` holds the components of the
         incoming gradient along the tangent directions ``r0 @ J_k``;
      2. returns ``in_nd - proj + lam * (proj - goal)`` as the gradient of
         ``in_nd``, where ``goal`` is ``r_new`` in the raw representation and
         ``proj`` is a projection of ``in_nd`` towards it.

    Parameters of ``forward``:
        lr   : step size of the tangent update (``-1`` = use ``rgt`` as goal).
        lam  : weight of the ``proj - goal`` regularization term.
        rgt  : ground-truth rotations, used as the goal when ``lr == -1`` and
               for logging (presumably [b, 3, 3] -- TODO confirm).
        iter : training step; every 100 steps statistics are written to the
               module-level ``logger`` (install it with ``logger_init``).
    No gradients are produced for ``lr``, ``lam``, ``rgt`` or ``iter``.
    '''
    @staticmethod
    def forward(ctx, in_nd, lr, lam, rgt, iter):
        proj_kind = in_nd.shape[1]
        if proj_kind == 6:
            r0 = tools.compute_rotation_matrix_from_ortho6d(in_nd)
        elif proj_kind == 9:
            r0 = tools.symmetric_orthogonalization(in_nd)
        elif proj_kind == 4:
            r0 = tools.compute_rotation_matrix_from_quaternion(in_nd)
        else:
            raise NotImplementedError
        ctx.save_for_backward(in_nd, r0, rgt)
        # Scalar hyper-parameters are kept as plain ctx attributes (the usual way
        # to pass non-tensor values to backward). Packing them into a float32
        # tensor rounded ``iter`` above 2**24 and handed 0-d tensors to the
        # logger as ``global_step``.
        ctx.lr = lr
        ctx.lam = lam
        ctx.iter = iter
        return r0

    @staticmethod
    def _tangent_goal(r0, grad_in, lr):
        '''Return (``r0 @ exp(J(delta_w))``, ``delta_w``) with ``delta_w = -lr * g`` [b, 3].'''
        # so(3) generators J_x, J_y, J_z on r0's device/dtype (a hard-coded
        # ``.cuda()`` float32 buffer broke CPU use and float64 ``bmm``).
        gens = torch.zeros((3, 3, 3), device=r0.device, dtype=r0.dtype)
        gens[0, 2, 1] = 1
        gens[0, 1, 2] = -1
        gens[1, 0, 2] = 1
        gens[1, 2, 0] = -1
        gens[2, 0, 1] = -1
        gens[2, 1, 0] = 1
        # g_k = <grad_in, r0 @ J_k> (Frobenius inner product) -> g: [b, 3]
        g = torch.stack(
            [(grad_in * torch.matmul(r0, J)).reshape(-1, 9).sum(dim=1) for J in gens],
            dim=1,
        )
        delta_w = -lr * g
        # expJ expects the vector repeated over a trailing 3x3 grid.
        step_rot = expJ.apply(delta_w.unsqueeze(2).unsqueeze(3).repeat(1, 1, 3, 3))
        return torch.bmm(r0, step_rot), delta_w

    @staticmethod
    def backward(ctx, grad_in):
        in_nd, r0, rgt = ctx.saved_tensors
        lr, lam, step = ctx.lr, ctx.lam, ctx.iter
        proj_kind = in_nd.shape[1]
        log_now = step % 100 == 0

        # 1) Goal rotation.
        if lr == -1:
            r_new = rgt
        else:
            r_new, delta_w = RPMG._tangent_goal(r0, grad_in, lr)
            if log_now:
                logger.add_scalar('next_goal_angle_mean', delta_w.norm(dim=1).mean(), step)
                logger.add_scalar('next_goal_angle_max', delta_w.norm(dim=1).max(), step)
                R0_Rgt = tools.compute_geodesic_distance_from_two_matrices(r0, rgt)
                logger.add_scalar('r0_rgt_angle', R0_Rgt.mean(), step)

        # 2) Gradient for the raw representation.
        if proj_kind == 6:
            c0, c1 = r_new[:, :, 0], r_new[:, :, 1]
            x1, x2 = in_nd[:, :3], in_nd[:, 3:]
            # Project x1 onto span(c0) and x2 onto span(c0, c1)
            # (columns of r_new assumed orthonormal).
            r_proj_1 = (c0 * x1).sum(dim=1, keepdim=True) * c0
            r_proj_2 = (c0 * x2).sum(dim=1, keepdim=True) * c0 \
                + (c1 * x2).sum(dim=1, keepdim=True) * c1
            r_reg_1 = lam * (r_proj_1 - c0)
            r_reg_2 = lam * (r_proj_2 - c1)
            gradient_nd = torch.cat([x1 - r_proj_1 + r_reg_1, x2 - r_proj_2 + r_reg_2], 1)
        elif proj_kind == 9:
            SVD_proj = tools.compute_SVD_nearest_Mnlsew(in_nd.reshape(-1, 3, 3), r_new)
            gradient_nd = in_nd - SVD_proj + lam * (SVD_proj - r_new.reshape(-1, 9))
            if log_now:
                # Only needed for the statistics, so skip this extra
                # orthogonalization on non-logging steps.
                R_proj_g = tools.symmetric_orthogonalization(SVD_proj)
                logger.add_scalar('9d_reflection', (((R_proj_g - r_new).reshape(-1, 9).abs().sum(dim=1)) > 5e-1).sum(), step)
                logger.add_scalar('reg', (SVD_proj - r_new.reshape(-1, 9)).norm(dim=1).mean(), step)
                logger.add_scalar('main', (in_nd - SVD_proj).norm(dim=1).mean(), step)
        elif proj_kind == 4:
            q_1 = tools.compute_quaternions_from_rotation_matrices(r_new)
            q_2 = -q_1
            normalized_nd = tools.normalize_vector(in_nd)
            # q and -q encode the same rotation; keep the sign closer to the prediction.
            q_new = torch.where(
                (q_1 - normalized_nd).norm(dim=1, keepdim=True) < (q_2 - normalized_nd).norm(dim=1, keepdim=True),
                q_1, q_2)
            q_proj = (in_nd * q_new).sum(dim=1, keepdim=True) * q_new
            gradient_nd = in_nd - q_proj + lam * (q_proj - q_new)
        else:
            raise NotImplementedError

        # One entry per forward input: (in_nd, lr, lam, rgt, iter).
        return gradient_nd, None, None, None, None

class Model(nn.Module):
    '''
    Point-cloud backbone whose raw output is converted to rotation matrices.

    :param out_rotation_mode: raw rotation representation regressed by the
        backbone: "Quaternion" (4 channels), "ortho6d" (6), "svd9d" (9),
        "10d" (10), "euler" (3) or "axisangle" (4).
    :param kind: backbone -- 1 PointNet_cls, 2 PointNet2_seg, 3 PointNet2_cls,
        4 PointNet_seg, 5 PointTransformer_cls, 6 PointTransformer_seg.
    :raises NotImplementedError: for an unknown ``out_rotation_mode`` or ``kind``.
    '''

    # Backbones producing one prediction per cloud vs. one per point.
    _CLS_KINDS = (1, 3, 5)
    _SEG_KINDS = (2, 4, 6)

    def __init__(self, out_rotation_mode="Quaternion", kind=1):
        super(Model, self).__init__()

        self.out_rotation_mode = out_rotation_mode

        out_channels = {
            "Quaternion": 4,
            "ortho6d": 6,
            "svd9d": 9,
            "10d": 10,
            "euler": 3,
            "axisangle": 4,
        }
        if out_rotation_mode not in out_channels:
            raise NotImplementedError
        self.out_channel = out_channels[out_rotation_mode]

        print(out_rotation_mode)

        self.kind = kind
        backbones = {
            1: PointNet_cls,
            2: PointNet2_seg,
            3: PointNet2_cls,
            4: PointNet_seg,
            5: PointTransformer_cls,
            6: PointTransformer_seg,
        }
        if kind not in backbones:
            raise NotImplementedError
        self.model = backbones[kind](self.out_channel)

    def _to_rmat(self, nd):
        '''Convert raw predictions ``nd`` to rotation matrices (b*3*3) using the
        ``tools`` converter that matches ``self.out_rotation_mode``.'''
        mode = self.out_rotation_mode
        if mode == "Quaternion":
            return tools.compute_rotation_matrix_from_quaternion(nd)
        if mode == "ortho6d":
            return tools.compute_rotation_matrix_from_ortho6d(nd)
        if mode == "svd9d":
            return tools.symmetric_orthogonalization(nd)
        if mode == "10d":
            return tools.compute_rotation_matrix_from_10d(nd)
        if mode == "euler":
            return tools.compute_rotation_matrix_from_euler(nd)
        if mode == "axisangle":
            return tools.compute_rotation_matrix_from_axisAngle(nd)
        raise NotImplementedError

    def forward(self, input):
        '''
        Run the backbone and convert its output to rotation matrices.

        :param input: point-cloud batch, passed unchanged to the backbone.
            NOTE(review): an older comment said ``b*point_num*3``, while the
            previous segmentation code treated ``input.shape[2]`` as the point
            count -- confirm the layout the backbones expect.
        :return: (out_rmat, out_nd, mean_rmat, mean_nd)
            out_nd    raw predictions; presumably [b, C] for classification
                      kinds, one row per point ([b*N, C]) for segmentation kinds;
            mean_nd   one row per cloud; the per-point mean for segmentation kinds;
            out_rmat, mean_rmat  the corresponding rotation matrices.
        '''
        out_data = self.model(input)
        if self.kind in self._CLS_KINDS:
            out_nd = out_data
            mean_nd = out_data
        elif self.kind in self._SEG_KINDS:
            # Segmentation output is assumed [b, C, N]; flatten to one row per
            # point. The row width comes from the output's channel dimension, so
            # the result no longer depends on input.shape[2] being equal to N.
            per_point = out_data.transpose(1, 2)
            out_nd = per_point.reshape(-1, out_data.shape[1])
            mean_nd = per_point.mean(dim=1)

        out_rmat = self._to_rmat(out_nd)
        mean_rmat = self._to_rmat(mean_nd)
        return out_rmat, out_nd, mean_rmat, mean_nd


        
        
        
        
        
        
        
        
    
