import torch
from torch.distributions import Normal
import numpy as np

def euler_to_SO3(angs):
    """Build rotation matrices from three Euler angles per batch row.

    Args:
        angs: tensor of shape [B, 3] holding (azp, altp, phip) per row; only the
            first three columns are read.

    Returns:
        float32 tensor of shape [B, 3, 3]. Written out, each matrix equals the
        transpose of ``Rz(azp) @ Rx(altp) @ Rz(phip)`` (a z-x-z sequence of the
        usual right-handed elementary rotations). The output is always cast to
        float32, whatever the input dtype.
    """
    az, alt, phi = (angs[:, col] for col in range(3))

    c_az, s_az = torch.cos(az), torch.sin(az)
    c_alt, s_alt = torch.cos(alt), torch.sin(alt)
    c_phi, s_phi = torch.cos(phi), torch.sin(phi)

    entries = (
        # first row
        c_phi * c_az - c_alt * s_az * s_phi,
        c_phi * s_az + c_alt * c_az * s_phi,
        s_alt * s_phi,
        # second row
        -s_phi * c_az - c_alt * s_az * c_phi,
        -s_phi * s_az + c_alt * c_az * c_phi,
        s_alt * c_phi,
        # third row
        s_alt * s_az,
        -s_alt * c_az,
        c_alt,
    )
    matrices = torch.stack(entries, dim=-1).reshape(angs.shape[0], 3, 3)
    return matrices.to(torch.float32)
    
def axis_angle_to_SO3(axis_angle):
    r"""Convert axis-angle rotation vectors to 3x3 rotation matrices (Rodrigues).

    The direction of each vector is the rotation axis and its Euclidean norm is
    the rotation angle in radians.

    For very small angles (squared angle <= 1e-6) the first-order expansion
    :math:`R \approx I + [v]_\times` is used instead of the closed form. This
    makes the zero vector map exactly to the identity with a finite, correct
    gradient (the closed form's ``sqrt`` has a 0/0 gradient there).

    Args:
        axis_angle: tensor of axis-angle rotations with shape :math:`(*, 3)`
            (typically :math:`(N, 3)`).

    Returns:
        tensor of rotation matrices with shape :math:`(*, 3, 3)`.

    Raises:
        ValueError: if the last dimension of ``axis_angle`` is not 3.

    Example:
        >>> input = torch.tensor([[0., 0., 0.]])
        >>> axis_angle_to_SO3(input)
        tensor([[[1., 0., 0.],
                 [0., 1., 0.],
                 [0., 0., 1.]]])

        >>> input = torch.tensor([[1.5708, 0., 0.]])
        >>> axis_angle_to_SO3(input)
        tensor([[[ 1.0000e+00,  0.0000e+00,  0.0000e+00],
                 [ 0.0000e+00, -3.6200e-06, -1.0000e+00],
                 [ 0.0000e+00,  1.0000e+00, -3.6200e-06]]])

    """
    if not axis_angle.shape[-1] == 3:
        raise ValueError(f"Input size must be a (*, 3) tensor. Got {axis_angle.shape}")

    # Squared-angle threshold below which the first-order expansion is used.
    small_theta2 = 1e-6

    batch_shape = axis_angle.shape[:-1]
    vec = axis_angle.reshape(-1, 3)                      # (M, 3)
    theta2 = (vec * vec).sum(dim=1, keepdim=True)        # (M, 1) squared angle
    is_small = theta2 <= small_theta2

    # Closed-form Rodrigues branch. Rows that will take the small-angle branch
    # get a dummy squared angle of 1 so that neither the sqrt nor the division
    # sees 0; otherwise torch.where would still back-propagate NaN from here.
    safe_theta2 = torch.where(is_small, torch.ones_like(theta2), theta2)
    theta = torch.sqrt(safe_theta2)
    wx, wy, wz = torch.chunk(vec / theta, 3, dim=1)      # unit rotation axis
    cos_theta = torch.cos(theta)
    sin_theta = torch.sin(theta)
    one_minus_cos = 1.0 - cos_theta
    rodrigues = torch.cat([
        cos_theta + wx * wx * one_minus_cos,             # r00
        wx * wy * one_minus_cos - wz * sin_theta,        # r01
        wy * sin_theta + wx * wz * one_minus_cos,        # r02
        wz * sin_theta + wx * wy * one_minus_cos,        # r10
        cos_theta + wy * wy * one_minus_cos,             # r11
        -wx * sin_theta + wy * wz * one_minus_cos,       # r12
        -wy * sin_theta + wx * wz * one_minus_cos,       # r20
        wx * sin_theta + wy * wz * one_minus_cos,        # r21
        cos_theta + wz * wz * one_minus_cos,             # r22
    ], dim=1)

    # First-order branch I + [v]_x: exact at zero and with the right gradient.
    rx, ry, rz = torch.chunk(vec, 3, dim=1)
    one = torch.ones_like(rx)
    taylor = torch.cat([one, -rz, ry, rz, one, -rx, -ry, rx, one], dim=1)

    rotation_matrix = torch.where(is_small, taylor, rodrigues)   # (M, 9)
    return rotation_matrix.view(*batch_shape, 3, 3)

def map_to_lie_algebra(v):
    """Map vectors in R^3 to skew-symmetric 3x3 matrices (the Lie algebra so(3)).

    The result is the cross-product matrix [v]_x, i.e. ``[v]_x @ w`` equals
    ``cross(v, w)``. It is formed as v0 * G_x + v1 * G_y + v2 * G_z with the
    three standard so(3) generators.

    Arg:
        v = tensor of shape (..., 3)
    Return:
        tensor of shape (..., 3, 3) with the same dtype/device as ``v``"""

    # only 3-vectors are supported
    assert v.shape[-1] == 3

    generators = v.new_tensor([
        [[ 0., 0., 0.],
         [ 0., 0.,-1.],
         [ 0., 1., 0.]],   # G_x
        [[ 0., 0., 1.],
         [ 0., 0., 0.],
         [-1., 0., 0.]],   # G_y
        [[ 0.,-1., 0.],
         [ 1., 0., 0.],
         [ 0., 0., 0.]],   # G_z
    ])

    result = generators[0] * v[..., 0, None, None]
    for axis in (1, 2):
        result = result + generators[axis] * v[..., axis, None, None]
    return result

def expmap(v):
    """Exponential map from axis-angle vectors to rotation matrices (Rodrigues).

    Computes ``R = I + (sin t / t) V + ((1 - cos t) / t^2) V @ V`` with
    ``V = map_to_lie_algebra(v)`` and ``t = |v|``. This equals
    ``I + sin(t) K + (1 - cos t) K @ K`` for the normalized ``K = V / t``, but
    stays finite at ``v == 0``, where it returns the identity. The previous
    ``v / t`` evaluated 0/0 there and produced an all-NaN matrix.

    Args:
        v: tensor of shape (..., 3); the direction is the rotation axis and the
            Euclidean norm is the angle in radians.

    Returns:
        Tensor of shape (..., 3, 3) with the dtype/device of ``v``.
    """
    V = map_to_lie_algebra(v)                              # (..., 3, 3), not normalized
    theta = v.norm(p=2, dim=-1, keepdim=True)[..., None]   # (..., 1, 1)
    is_zero = theta == 0
    # Use a dummy angle of 1 where theta == 0 so no division sees a zero.
    safe_theta = torch.where(is_zero, torch.ones_like(theta), theta)

    # sin(t)/t, whose limit at t = 0 is 1.
    a = torch.where(is_zero, torch.ones_like(theta), torch.sin(safe_theta) / safe_theta)
    # (1 - cos t)/t^2 = 0.5 * (sin(t/2) / (t/2))^2, whose limit at t = 0 is 0.5.
    # The half-angle form avoids cancellation in 1 - cos(t) for small t.
    half = 0.5 * safe_theta
    b = torch.where(is_zero, torch.full_like(theta, 0.5), 0.5 * (torch.sin(half) / half) ** 2)

    I = torch.eye(3, device=v.device, dtype=v.dtype)
    return I + a * V + b * (V @ V)

def s2s1rodrigues(s2_el, s1_el):
    """Rodrigues' formula from an axis and an angle given as (cos, sin).

    Args:
        s2_el: (..., 3) rotation axis. It is used as given and not normalized here.
        s1_el: (..., 2) angle encoded as (cos(theta), sin(theta)) in the last dim.

    Returns:
        (..., 3, 3) tensor ``I + sin(theta) K + (1 - cos(theta)) K @ K`` with
        ``K = map_to_lie_algebra(s2_el)``.
    """
    skew = map_to_lie_algebra(s2_el)
    cos_part = s1_el[..., 0].unsqueeze(-1).unsqueeze(-1)
    sin_part = s1_el[..., 1].unsqueeze(-1).unsqueeze(-1)
    eye = torch.eye(3, device=s2_el.device, dtype=s2_el.dtype)
    return eye + sin_part * skew + (1. - cos_part) * (skew @ skew)

def s2s2_to_SO3(v1, v2=None):
    '''Build rotation matrices from two 3-vectors by Gram-Schmidt.

    The first vector is normalized. The second is made orthogonal to it and
    normalized. Their cross product gives the third row. The three unit vectors
    are stacked as the ROWS of the result (dim=-2).

    Args:
        v1: (..., 3) first vector, or (..., 6) holding both vectors
            concatenated when ``v2`` is None.
        v2: optional (..., 3) second vector.

    Returns:
        (..., 3, 3) tensor.'''
    if v2 is None:
        assert v1.shape[-1] == 6
        v1, v2 = v1[..., 0:3], v1[..., 3:]

    first = torch.nn.functional.normalize(v1, dim=-1)
    along_first = torch.sum(first * v2, dim=-1, keepdim=True) * first
    second = torch.nn.functional.normalize(v2 - along_first, dim=-1)
    third = torch.cross(first, second, dim=-1)
    return torch.stack((first, second, third), dim=-2)

def SO3_to_s2s2(r):
    '''Map a batch of 3x3 matrices to the s2s2 representation.

    The representation is the first two rows of each matrix (the first 6
    entries in row-major order), concatenated into a (..., 6) tensor.

    ``reshape`` is used instead of ``view``, so non-contiguous inputs (e.g.
    transposed matrices ``R.transpose(-1, -2)``) work instead of raising.
    The result is always contiguous.'''
    return r.reshape(*r.shape[:-2], 9)[..., :6].contiguous()

def SO3_to_quaternions(r):
    """Map a batch of SO(3) matrices to quaternions.

    This is the inverse (up to the sign of q) of ``quaternions_to_SO3`` in this
    module. The LAST component is the scalar part, so the identity maps to
    ``(0, 0, 0, 1)``.

    For every matrix the four standard candidate formulas are evaluated. The
    one with the largest ``denom`` is selected, so the divisions use the
    largest available denominator.

    Args:
        r: tensor of shape (..., 3, 3). It may be non-contiguous (e.g. a
            transposed or permuted batch).

    Returns:
        Tensor of shape (..., 4).
    """
    batch_dims = r.shape[:-2]
    assert list(r.shape[-2:]) == [3, 3], 'Input must be 3x3 matrices'
    # reshape (not view) so non-contiguous inputs work. The explicit batch size
    # (instead of -1) also keeps an empty batch well-defined.
    r = r.reshape(batch_dims.numel(), 3, 3)
    n = r.shape[0]

    diags = [r[:, 0, 0], r[:, 1, 1], r[:, 2, 2]]
    denom_pre = torch.stack([
        1 + diags[0] - diags[1] - diags[2],
        1 - diags[0] + diags[1] - diags[2],
        1 - diags[0] - diags[1] + diags[2],
        1 + diags[0] + diags[1] + diags[2]
    ], 1)
    # The 1E-6 keeps the sqrt (and the divisions below) away from exactly zero.
    denom = 0.5 * torch.sqrt(1E-6 + torch.abs(denom_pre))

    case0 = torch.stack([
        denom[:, 0],
        (r[:, 0, 1] + r[:, 1, 0]) / (4 * denom[:, 0]),
        (r[:, 0, 2] + r[:, 2, 0]) / (4 * denom[:, 0]),
        (r[:, 1, 2] - r[:, 2, 1]) / (4 * denom[:, 0])
    ], 1)
    case1 = torch.stack([
        (r[:, 0, 1] + r[:, 1, 0]) / (4 * denom[:, 1]),
        denom[:, 1],
        (r[:, 1, 2] + r[:, 2, 1]) / (4 * denom[:, 1]),
        (r[:, 2, 0] - r[:, 0, 2]) / (4 * denom[:, 1])
    ], 1)
    case2 = torch.stack([
        (r[:, 0, 2] + r[:, 2, 0]) / (4 * denom[:, 2]),
        (r[:, 1, 2] + r[:, 2, 1]) / (4 * denom[:, 2]),
        denom[:, 2],
        (r[:, 0, 1] - r[:, 1, 0]) / (4 * denom[:, 2])
    ], 1)
    case3 = torch.stack([
        (r[:, 1, 2] - r[:, 2, 1]) / (4 * denom[:, 3]),
        (r[:, 2, 0] - r[:, 0, 2]) / (4 * denom[:, 3]),
        (r[:, 0, 1] - r[:, 1, 0]) / (4 * denom[:, 3]),
        denom[:, 3]
    ], 1)

    cases = torch.stack([case0, case1, case2, case3], 1)  # (n, 4 cases, 4)

    # Build the row index on r's device so it matches the argmax index tensor.
    rows = torch.arange(n, dtype=torch.long, device=r.device)
    quaternions = cases[rows, torch.argmax(denom.detach(), 1)]
    return quaternions.view(*batch_dims, 4)


def quaternions_to_SO3(q):
    '''Normalize quaternions and convert them to 3x3 matrices.

    The LAST component acts as the scalar part: ``(0, 0, 0, 1)`` maps to the
    identity. For example, ``(0, 0, sin(t/2), cos(t/2))`` gives a rotation
    about z by ``-t``. The result is therefore the transpose of the usual
    Hamilton matrix for a scalar-last (x, y, z, w) quaternion.

    Args:
        q: tensor of shape (..., 4); it need not be unit-norm.

    Returns:
        Tensor of shape (..., 3, 3).'''
    unit = q / q.norm(p=2, dim=-1, keepdim=True)
    a, b, c, d = (unit[..., k] for k in range(4))

    m00 = a*a - b*b - c*c + d*d
    m01 = 2*(a*b + c*d)
    m02 = 2*(a*c - b*d)
    m10 = 2*(a*b - c*d)
    m11 = -a*a + b*b - c*c + d*d
    m12 = 2*(b*c + a*d)
    m20 = 2*(a*c + b*d)
    m21 = 2*(b*c - a*d)
    m22 = -a*a - b*b + c*c + d*d

    flat = torch.stack([m00, m01, m02, m10, m11, m12, m20, m21, m22], -1)
    return flat.view(*q.shape[:-1], 3, 3)

def random_quaternions(n, dtype=torch.float32, device=None):
    """Sample ``n`` unit quaternions uniformly on S^3 (Shoemake's method).

    Uses a single ``torch.rand(3, n)`` draw from torch's global RNG.

    Returns:
        Tensor of shape (n, 4).
    """
    u = torch.rand(3, n, dtype=dtype, device=device)
    radius_a = torch.sqrt(1 - u[0])
    radius_b = torch.sqrt(u[0])
    angle_a = 2 * np.pi * u[1]
    angle_b = 2 * np.pi * u[2]
    components = [
        radius_a * torch.sin(angle_a),
        radius_a * torch.cos(angle_a),
        radius_b * torch.sin(angle_b),
        radius_b * torch.cos(angle_b),
    ]
    return torch.stack(components, dim=1)

def random_SO3(n, dtype=torch.float32, device=None):
    """Sample ``n`` random rotation matrices of shape (n, 3, 3).

    The matrices are ``quaternions_to_SO3`` of ``random_quaternions(n)``.
    """
    quats = random_quaternions(n, dtype=dtype, device=device)
    return quaternions_to_SO3(quats)

def random_S2S2(n, dtype=torch.float32, device=None):
    """Sample ``n`` random rotations in the (n, 6) s2s2 representation.

    The samples are ``SO3_to_s2s2`` of ``random_SO3(n)``.
    """
    rotations = random_SO3(n, dtype=dtype, device=device)
    return SO3_to_s2s2(rotations)

def random_axis_angle(n, range1=None, range2=None, dtype=torch.float32, device=None):
    """Sample ``n`` random 3-vectors of rotation angles, in radians.

    Two sampling modes are supported:

    * ``range1`` is None: return the axis-angle vectors of ``n`` random unit
      quaternions (``random_quaternions`` then ``quaternion_to_axis_angle``),
      using torch's global RNG.
    * ``range1`` given: each of the 3 components is an integer number of
      degrees drawn independently with ``np.random.choice`` (NumPy's global
      RNG). Values come from ``[range1, range2)`` with step 1 when ``range2`` is
      given, otherwise from ``np.random.choice(range1)``'s population
      (``[0, range1)`` for an int). They are then converted to radians.

    Args:
        n: number of samples.
        range1: lower bound (or population) in degrees; None selects the
            quaternion-based mode.
        range2: exclusive upper bound in degrees; only used with ``range1``.
        dtype: dtype of the returned tensor.
        device: device of the returned tensor.

    Returns:
        Tensor of shape (n, 3).
    """
    if range1 is None:
        return quaternion_to_axis_angle(random_quaternions(n, dtype, device))

    if range2 is not None:
        population = np.arange(range1, range2, 1)
    else:
        population = range1
    degrees = np.random.choice(population, size=(n, 3))
    return torch.from_numpy(np.deg2rad(degrees)).to(dtype=dtype, device=device)

def _convert_affinematrix_to_homography_impl(A: torch.Tensor) -> torch.Tensor:
    """Append a homogeneous row ``[0, ..., 0, 1]`` below each (..., 3, 4) matrix.

    Returns a new (..., 4, 4) tensor with the dtype/device of ``A``.
    """
    bottom_row = torch.zeros_like(A[..., :1, :])
    bottom_row[..., -1] += 1.0
    return torch.cat([A, bottom_row], dim=-2)


def convert_affinematrix_to_homography3d(A: torch.Tensor) -> torch.Tensor:
    """Turn a batch of Bx3x4 affine matrices into Bx4x4 homogeneous matrices.

    Raises:
        TypeError: if ``A`` is not a tensor.
        ValueError: if ``A`` is not 3-dimensional with trailing shape (3, 4).
    """
    if not torch.is_tensor(A):
        raise TypeError(f"Input type is not a Tensor. Got {type(A)}")

    has_bad_shape = A.dim() != 3 or tuple(A.shape[-2:]) != (3, 4)
    if has_bad_shape:
        raise ValueError(f"Input matrix must be a Bx3x4 tensor. Got {A.shape}")

    return _convert_affinematrix_to_homography_impl(A)

def _torch_inverse_cast(input: torch.Tensor) -> torch.Tensor:
    """Invert ``input`` with ``torch.linalg.inv`` and keep its dtype.

    Dtypes other than float32/float64 (e.g. float16) are inverted in float32
    and then cast back to the original dtype.

    Raises:
        AssertionError: if ``input`` is not a tensor.
    """
    if not isinstance(input, torch.Tensor):
        raise AssertionError(f"Input must be Tensor. Got: {type(input)}.")
    supported = (torch.float32, torch.float64)
    compute_dtype = input.dtype if input.dtype in supported else torch.float32
    inverse = torch.linalg.inv(input.to(compute_dtype))
    return inverse.to(input.dtype)


def eye_like(n: int, input: torch.Tensor, shared_memory: bool = False) -> torch.Tensor:
    """Return ``input.shape[0]`` stacked ``n x n`` identity matrices.

    Args:
        n: matrix size; must be positive.
        input: tensor whose first dimension gives the batch size and whose
            device/dtype the result adopts.
        shared_memory: if True, return a stride-0 ``expand`` view of a single
            identity (no per-batch copy). Otherwise each batch element gets its
            own copy via ``repeat``.

    Returns:
        Tensor of shape (input.shape[0], n, n).

    Raises:
        AssertionError: if ``n <= 0`` or ``input`` is 0-dimensional.
    """
    if n <= 0:
        raise AssertionError(type(n), n)
    if input.ndim < 1:
        raise AssertionError(input.shape)

    batch = input.shape[0]
    single = torch.eye(n, device=input.device).to(input.dtype).unsqueeze(0)
    if shared_memory:
        return single.expand(batch, n, n)
    return single.repeat(batch, 1, 1)


def projection_from_Rt(rmat: torch.Tensor, tvec: torch.Tensor) -> torch.Tensor:
    """Concatenate rotations and translations into ``[R | t]`` matrices.

    Args:
        rmat: tensor of shape (..., 3, 3).
        tvec: tensor of shape (..., 3, 1).

    Returns:
        Tensor of shape (..., 3, 4).

    Raises:
        AssertionError: if either input has the wrong trailing shape.
    """
    if rmat.ndim < 2 or rmat.shape[-2:] != (3, 3):
        raise AssertionError(rmat.shape)
    if tvec.ndim < 2 or tvec.shape[-2:] != (3, 1):
        raise AssertionError(tvec.shape)
    return torch.cat((rmat, tvec), dim=-1)

def convert_SO3_to_kornia_affine_matrix(mat, center):
    """Build 3x4 affine matrices that rotate by ``mat`` about ``center``.

    This equals ``T(center) @ [[mat, 0], [0, 1]] @ T(-center)`` with the
    homogeneous last row dropped, i.e. ``[mat | center - mat @ center]``.
    A point ``p`` is mapped to ``mat @ (p - center) + center``. The closed form
    avoids a general 4x4 matrix inverse of a pure translation.

    Args:
        mat: rotation matrices, shape (B, 3, 3) (leading dims are broadcast).
        center: rotation centres, shape (B, 3); cast to ``mat``'s dtype.

    Returns:
        Tensor of shape (B, 3, 4). Per the function name it is presumably
        meant for kornia's 3-D warping utilities (not verified here).

    Raises:
        AssertionError: if ``mat`` is not (..., 3, 3) or ``center`` is not (..., 3).
    """
    if not (mat.dim() >= 2 and mat.shape[-2:] == (3, 3)):
        raise AssertionError(mat.shape)
    if center.shape[-1] != 3:
        raise AssertionError(center.shape)

    c = center.to(mat.dtype)[..., None]            # (B, 3, 1)
    translation = c - mat @ c                      # (B, 3, 1)
    return torch.cat([mat, translation], dim=-1)   # (B, 3, 4)

def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
    """Convert axis-angle rotation vectors to scalar-first unit quaternions.

    Args:
        axis_angle: tensor of shape (..., 3); its direction is the rotation axis
            and its Euclidean norm the rotation angle in radians.

    Returns:
        float32 tensor of shape (..., 4) ordered (w, x, y, z), with
        ``w = cos(angle / 2)`` and ``(x, y, z) = axis * sin(angle / 2)``.
        The output is always cast to float32, even for float64 input.
    """
    angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)   # (..., 1)
    half_angles = 0.5 * angles

    # (x, y, z) = axis_angle * sin(angle/2) / angle, computed without
    # normalizing the axis first. Near zero use the Taylor expansion
    # sin(a/2)/a ~ 1/2 - a^2/48 (as quaternion_to_axis_angle does). This keeps
    # the output unit-norm and gives the correct gradient of 1/2 at the
    # identity, where an eps-normalized axis would give 0.
    small = angles < 1e-6
    # Divide by 1 where the angle is tiny so the branch torch.where discards
    # stays finite and cannot leak NaN into gradients.
    safe_angles = torch.where(small, torch.ones_like(angles), angles)
    sin_half_over_angle = torch.where(
        small,
        0.5 - (angles * angles) / 48,
        torch.sin(half_angles) / safe_angles,
    )

    quaternions = torch.cat(
        [torch.cos(half_angles), axis_angle * sin_half_over_angle], dim=-1
    )
    return quaternions.float()

def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
    """Convert scalar-first quaternions (w, x, y, z) to axis-angle vectors.

    The rotation angle is ``2 * atan2(|(x, y, z)|, w)``, which lies in
    [0, 2*pi]. The result is ``(x, y, z) * angle / sin(angle / 2)``. For a
    non-unit quaternion of norm ``s`` the returned vector is scaled by ``s``.

    Args:
        quaternions: tensor of shape (..., 4).

    Returns:
        Tensor of shape (..., 3).
    """
    xyz = quaternions[..., 1:]
    half_angles = torch.atan2(
        torch.norm(xyz, p=2, dim=-1, keepdim=True), quaternions[..., :1]
    )
    angles = 2 * half_angles
    tiny = angles.abs() < 1e-6
    # Divide by 1 where the angle is tiny so the branch torch.where discards
    # stays finite (a 0/0 there would leak NaN into gradients).
    denominator = torch.where(tiny, torch.ones_like(angles), angles)
    # sin(a/2)/a, replaced near a = 0 by its Taylor expansion 1/2 - a^2/48.
    ratio = torch.where(
        tiny,
        0.5 - (angles * angles) / 48,
        torch.sin(half_angles) / denominator,
    )
    return xyz / ratio

def axis_angle_to_S2S2(angles: torch.Tensor) -> torch.Tensor:
    """Convert axis-angle vectors to the 6-D s2s2 representation.

    The rotation matrix is built with ``axis_angle_to_SO3`` (Rodrigues) and its
    first two rows are kept via ``SO3_to_s2s2``. The zero vector maps to the
    identity's representation ``[1, 0, 0, 0, 1, 0]``.

    Args:
        angles: tensor of shape (N, 3), axis-angle in radians.

    Returns:
        float32 tensor of shape (N, 6).
    """
    # Do NOT chain axis_angle_to_quaternion -> quaternions_to_SO3. The former
    # returns scalar-FIRST (w, x, y, z) quaternions, while the latter treats the
    # LAST component as the scalar. Mixing them yields wrong matrices, e.g.
    # diag(1, -1, -1) for a zero rotation.
    return SO3_to_s2s2(axis_angle_to_SO3(angles)).float()

def wrap_to_360(angles):
    """Wrap angles in degrees into [0, 360).

    The second modulo maps a result that rounded to exactly 360 (e.g. from a
    tiny negative float input) back to 0. Works for Python numbers, NumPy
    arrays and torch tensors.
    """
    reduced = angles % 360
    return (reduced + 360) % 360

def _demo_axis_angle_roundtrip():
    """Print an axis-angle -> quaternion -> axis-angle -> quaternion round trip."""
    # NOTE: random_axis_angle(..., range1=...) samples with np.random.choice, so
    # this torch seed alone does not make the printed values reproducible.
    torch.manual_seed(777)
    x = random_axis_angle(10, range1=360)
    x1 = axis_angle_to_quaternion(x)
    print(x1)
    x = quaternion_to_axis_angle(x1)
    print(x.shape)
    x2 = axis_angle_to_quaternion(x)
    print(x2)
    print(torch.allclose(x1, x2))


def _demo_transform_vectors():
    """Print random 6-vectors (axis-angle | translation) and a perturbed copy.

    Not invoked from ``__main__``; call it manually to run this demo.
    """
    x = random_quaternions(100)
    theta = quaternion_to_axis_angle(x)
    trans = torch.from_numpy(np.random.uniform(low=-2, high=2, size=(100, 3))).float()
    init_transforms = torch.concatenate([theta, trans], dim=1)
    print(init_transforms[2])

    # Offsets: integer-degree rotations in [-10, 9] and translations in [-0.5, 0.5).
    arng = np.random.choice(np.arange(-10, 10, 1), size=(64, 3))
    tid = np.random.uniform(low=-0.5, high=0.5, size=(64, 3))
    arng = np.deg2rad(arng)
    arng = np.hstack([arng, np.zeros((len(arng), 3))])
    tid = np.hstack([np.zeros((len(tid), 3)), tid])
    t_vec = arng[:, :] + tid[:, :]
    transform_vectors = torch.from_numpy(t_vec).float().reshape(-1, 6)
    print(transform_vectors.shape)
    print(transform_vectors[0])
    transform_vector = init_transforms + transform_vectors[0].repeat(100, 1)
    print(transform_vector[2])


if __name__ == '__main__':
    # The unused `import kornia` was dropped: it made this demo fail whenever
    # kornia was not installed, although nothing here uses it.
    _demo_axis_angle_roundtrip()