import torch
import torch.nn as nn
import torch.nn.functional as F
import math

EPS = 1.0e-8  # Small default epsilon (used by exp_map_r3_to_quat) to keep denominators away from zero.

def quat_normalize(q, eps=1e-8):
    """Scale quaternions to unit length along the last dimension.

    ``eps`` is added to the norm before dividing, so an all-zero quaternion
    stays all-zero instead of producing NaN.
    """
    denom = q.norm(dim=-1, keepdim=True).add(eps)
    return q.div(denom)

def quat_mul(q1, q2):
    """Hamilton product ``q1 * q2`` of quaternions stored as (w, x, y, z).

    Operands are split along the last dimension, so any broadcast-compatible
    leading shapes are accepted. Returns a tensor whose last dimension is 4.
    """
    aw, ax, ay, az = q1.unbind(dim=-1)
    bw, bx, by, bz = q2.unbind(dim=-1)

    w = aw * bw - ax * bx - ay * by - az * bz
    x = aw * bx + ax * bw + ay * bz - az * by
    y = aw * by - ax * bz + ay * bw + az * bx
    z = aw * bz + ax * by - ay * bx + az * bw

    return torch.stack((w, x, y, z), dim=-1)

def axis_angle_to_quat(axis, angle):
    """Build unit quaternions (w, x, y, z) from rotation axes and angles.

    Args:
        axis: tensor of shape (..., D) (normally D == 3); normalized along
            the last dimension here. A zero axis gives a zero vector part.
        angle: rotation angle(s) in radians. Either a Python number (converted
            to a tensor with axis' dtype/device) or a tensor broadcastable
            against ``axis.shape[:-1]``. A tensor keeps its own dtype.

    Returns:
        Tensor of shape (*broadcast(axis.shape[:-1], angle.shape), 1 + D),
        passed through quat_normalize.
    """
    axis = F.normalize(axis, dim=-1)
    if not torch.is_tensor(angle):
        # A 0-d tensor is expanded to the batch shape below. Using
        # axis.size(0) here would be wrong for an unbatched (3,) axis.
        angle = torch.tensor(float(angle), device=axis.device, dtype=axis.dtype)
    # Broadcast angle and axis to one batch shape. This lets a scalar or 0-d
    # angle pair with a batch of axes, and an unbatched axis pair with a scalar.
    batch_shape = torch.broadcast_shapes(axis.shape[:-1], angle.shape)
    axis = axis.expand(*batch_shape, axis.shape[-1])
    half = 0.5 * angle.expand(batch_shape)
    s = torch.sin(half).unsqueeze(-1)
    c = torch.cos(half).unsqueeze(-1)
    q = torch.cat([c, axis * s], dim=-1)
    return quat_normalize(q)

def quat_to_rotmat(q):
    """Convert quaternions (w, x, y, z) to 3x3 rotation matrices.

    The input is normalized first, so non-unit quaternions are accepted.

    Args:
        q: tensor of shape (..., 4); any number of leading batch dims,
            including none.

    Returns:
        Tensor of shape (..., 3, 3).
    """
    q = quat_normalize(q)
    w, x, y, z = q.unbind(-1)
    # Build each row as a tuple of entries and stack them. Unlike writing
    # into a preallocated (q.size(0), 3, 3) buffer, this does not assume
    # exactly one batch dimension.
    rows = (
        (1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)),
        (2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)),
        (2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)),
    )
    return torch.stack([torch.stack(row, dim=-1) for row in rows], dim=-2)

def rotation_matrix_to_quaternion_safe(R, eps=1e-8):
    """Convert 3x3 rotation matrices to unit quaternions (w, x, y, z).

    Uses Shepperd's branch selection. If the trace is positive, w is
    recovered from a square root. Otherwise the component matching the
    largest diagonal entry (x, y or z) is. The remaining components come from
    sums/differences of off-diagonal entries divided by that root.

    Args:
        R: tensor of shape (..., 3, 3); any number of leading batch dims,
            including none.
        eps: added inside every square root and to every divisor.

    Returns:
        Tensor of shape (..., 4), passed through quat_normalize.
    """
    batch_shape = R.shape[:-2]
    M = R.reshape(-1, 3, 3)
    q = M.new_zeros(M.shape[0], 4)

    r00, r11, r22 = M[:, 0, 0], M[:, 1, 1], M[:, 2, 2]
    trace = r00 + r11 + r22

    # The four masks are mutually exclusive and cover every row (rows with
    # NaN diagonals fall through to use_z). In the x/y/z branches the chosen
    # diagonal entry is the largest and trace <= 0, so its sqrt argument
    # 1 + 2*d_max - trace is >= 1. In the w branch trace + 1 > 1. Hence t >= 2
    # for any real input.
    use_w = trace > 0
    use_x = ~use_w & (r00 > r11) & (r00 > r22)
    use_y = ~use_w & ~use_x & (r11 > r22)
    use_z = ~(use_w | use_x | use_y)

    def _fill(mask, radicand, major, numerators):
        # Write one branch's rows. t equals 4 * |component `major`|.
        t = torch.sqrt(radicand + eps) * 2.0
        q[mask, major] = 0.25 * t
        for idx, num in numerators.items():
            q[mask, idx] = num / (t + eps)

    m = use_w
    _fill(m, trace[m] + 1.0, 0, {
        1: M[m, 2, 1] - M[m, 1, 2],
        2: M[m, 0, 2] - M[m, 2, 0],
        3: M[m, 1, 0] - M[m, 0, 1],
    })
    m = use_x
    _fill(m, 1.0 + M[m, 0, 0] - M[m, 1, 1] - M[m, 2, 2], 1, {
        0: M[m, 2, 1] - M[m, 1, 2],
        2: M[m, 0, 1] + M[m, 1, 0],
        3: M[m, 0, 2] + M[m, 2, 0],
    })
    m = use_y
    _fill(m, 1.0 + M[m, 1, 1] - M[m, 0, 0] - M[m, 2, 2], 2, {
        0: M[m, 0, 2] - M[m, 2, 0],
        1: M[m, 0, 1] + M[m, 1, 0],
        3: M[m, 1, 2] + M[m, 2, 1],
    })
    m = use_z
    _fill(m, 1.0 + M[m, 2, 2] - M[m, 0, 0] - M[m, 1, 1], 3, {
        0: M[m, 1, 0] - M[m, 0, 1],
        1: M[m, 0, 2] + M[m, 2, 0],
        2: M[m, 1, 2] + M[m, 2, 1],
    })

    return quat_normalize(q).reshape(*batch_shape, 4)

def reflect_matrix_from_normal(n):
    """Householder reflection matrices ``I - 2 n n^T`` for plane normals ``n``.

    Args:
        n: tensor of shape (..., 3); normalized along the last dimension
            here. A zero normal stays zero after F.normalize, which yields
            the identity matrix.

    Returns:
        Tensor of shape (..., 3, 3).
    """
    n = F.normalize(n, dim=-1)
    nnT = n.unsqueeze(-1) @ n.unsqueeze(-2)
    # Broadcast a single identity instead of expanding it to n.size(0).
    # The old expand returned a (3, 3, 3) tensor for an unbatched (3,) normal.
    eye = torch.eye(3, device=n.device, dtype=n.dtype)
    return eye - 2.0 * nnT


def exp_map_r3_to_quat(xi, eps=EPS, max_angle=(math.pi - 1e-6)):
    """Map rotation vectors ``xi = angle * unit_axis`` to quaternions (w, x, y, z).

    The rotation angle ``theta = |xi|`` is clamped to ``max_angle``. The axis
    direction is always ``xi / theta`` (unit length), even when clamping.

    Args:
        xi: tensor whose last dimension holds the rotation vector (expected
            size 3).
        eps: epsilon passed to quat_normalize.
        max_angle: upper bound on the rotation angle in radians.

    Returns:
        Tensor with the same leading shape as ``xi`` and last dimension 4,
        normalized via quat_normalize. Non-finite entries are replaced by 0.
        Note that a NaN input therefore yields an all-zero (non-unit)
        quaternion.
    """
    theta = xi.norm(dim=-1, keepdim=True)
    theta_clamped = torch.clamp(theta, max=max_angle)
    half = 0.5 * theta_clamped

    # sinc(h) = sin(h)/h. Use a Taylor series near 0 to avoid 0/0. The
    # placeholder 1 keeps the unselected torch.where branch finite, so no NaN
    # leaks into gradients.
    half_sq = half * half
    small = half.abs() < 1e-4
    safe_half = torch.where(small, torch.ones_like(half), half)
    sinc_half = torch.where(
        small,
        1.0 - half_sq / 6.0 + (half_sq * half_sq) / 120.0,
        torch.sin(safe_half) / safe_half,
    )

    # Vector part = unit_axis * sin(half) = xi * (sin(half) / theta).
    #  - Not clamped: sin(theta/2)/theta == 0.5 * sinc(theta/2). This is
    #    smooth at theta = 0 and gives the correct first-order term there.
    #  - Clamped: divide by the true (unclamped) theta so the axis stays
    #    unit length. Dividing by theta_clamped would rotate by the wrong
    #    angle.
    clamped = theta > theta_clamped
    safe_theta = torch.where(clamped, theta, torch.ones_like(theta))
    scale = torch.where(clamped, torch.sin(half) / safe_theta, 0.5 * sinc_half)

    q = torch.cat([torch.cos(half), xi * scale], dim=-1)
    q = quat_normalize(q, eps=eps)
    # Best-effort guard kept from the original: zero out any NaN/inf.
    return torch.nan_to_num(q, nan=0.0, posinf=0.0, neginf=0.0)










