import numpy as np
import torch


def matrix_to_ZYZ_angles(matrices):
    """
    Convert rotation matrices to ZYZ Euler angles.

    Uses the same convention as ``ZYZ_angles_to_matrix``:
    ``R = Rz(alpha) @ Ry(beta) @ Rz(gamma)``, for which

        R[0, 2] =  cos(alpha) sin(beta)     R[2, 0] = -sin(beta) cos(gamma)
        R[1, 2] =  sin(alpha) sin(beta)     R[2, 1] =  sin(beta) sin(gamma)
        R[2, 2] =  cos(beta)

    Args:
        matrices: torch.Tensor (or anything ``torch.tensor`` accepts) of shape
            (..., 3, 3) rotation matrices. The input is cast to float32.

    Returns:
        alpha, beta, gamma: numpy arrays of shape (...) with ZYZ Euler angles
        in radians
        - alpha: first rotation around Z-axis [0, 2π]
        - beta: rotation around Y-axis [0, π]
        - gamma: second rotation around Z-axis [0, 2π]

        In gimbal lock (|sin(beta)| < 1e-6) only alpha ± gamma is determined;
        gamma is then set to 0 and alpha carries the whole in-plane rotation.
    """
    if not isinstance(matrices, torch.Tensor):
        matrices = torch.tensor(matrices, dtype=torch.float32)

    # Ensure we're working with float tensors
    matrices = matrices.float()

    # Entries used below; rIJ is matrices[..., I-1, J-1]
    r11 = matrices[..., 0, 0]
    r13 = matrices[..., 0, 2]
    r21 = matrices[..., 1, 0]
    r23 = matrices[..., 1, 2]
    r31 = matrices[..., 2, 0]
    r32 = matrices[..., 2, 1]
    r33 = matrices[..., 2, 2]

    # Beta (middle rotation around Y): R[2, 2] = cos(beta). The clamp guards
    # against round-off pushing the entry slightly outside [-1, 1].
    beta = torch.acos(torch.clamp(r33, -1.0, 1.0))

    # Gimbal lock when sin(beta) ≈ 0 (beta ≈ 0 or π)
    is_singular = torch.abs(torch.sin(beta)) < 1e-6

    # Non-singular case. sin(beta) >= 0 on [0, π] is a common positive factor
    # of both atan2 arguments, so it cancels; dividing by it would only
    # create 0/0 = NaN at the singularity.
    alpha = torch.atan2(r23, r13)
    gamma = torch.atan2(r32, -r31)

    # Singular case (beta ≈ 0 or π)
    # When beta ≈ 0: R = Rz(alpha + gamma), so alpha + gamma is determined
    # When beta ≈ π: R[0, 0] = -cos(alpha - gamma), R[1, 0] = -sin(alpha - gamma)
    # We set gamma = 0 and solve for alpha
    alpha_singular = torch.where(
        beta < np.pi / 2,
        torch.atan2(r21, r11),  # beta ≈ 0
        torch.atan2(-r21, -r11),  # beta ≈ π
    )

    # Select based on singularity
    alpha = torch.where(is_singular, alpha_singular, alpha)
    gamma = torch.where(is_singular, torch.zeros_like(gamma), gamma)

    # Convert to numpy for matplotlib; detach/cpu so tensors that track
    # gradients or live on an accelerator can be converted as well.
    alpha, beta, gamma = (t.detach().cpu().numpy() for t in (alpha, beta, gamma))

    # Normalize angles to expected ranges
    alpha = np.mod(alpha, 2 * np.pi)  # [0, 2π]
    gamma = np.mod(gamma, 2 * np.pi)  # [0, 2π]

    return alpha, beta, gamma


def ZYZ_angles_to_matrix(alpha, beta, gamma):
    """
    Convert ZYZ Euler angles to rotation matrices.

    Computes ``R = Rz(alpha) @ Ry(beta) @ Rz(gamma)`` with the right-handed
    elementary rotations

        Rz(t) = [[cos t, -sin t, 0], [sin t, cos t, 0], [0, 0, 1]]
        Ry(t) = [[cos t, 0, sin t], [0, 1, 0], [-sin t, 0, cos t]]

    Args:
        alpha: first rotation around Z, in radians.
        beta: rotation around Y, in radians.
        gamma: second rotation around Z, in radians.
        Each may be a torch.Tensor (used as-is) or anything ``torch.tensor``
        accepts, such as a Python number or numpy array (converted to float32).

    Returns:
        torch.Tensor of shape (..., 3, 3). Each elementary matrix has its
        angle's shape plus (3, 3); the leading dimensions of the product
        follow ``torch.matmul`` broadcasting.
    """
    # Accept plain numbers / arrays, the same way xyz_angles_to_matrix does;
    # torch.cos/torch.sin reject non-tensor arguments.
    if not isinstance(alpha, torch.Tensor):
        alpha = torch.tensor(alpha, dtype=torch.float32)
    if not isinstance(beta, torch.Tensor):
        beta = torch.tensor(beta, dtype=torch.float32)
    if not isinstance(gamma, torch.Tensor):
        gamma = torch.tensor(gamma, dtype=torch.float32)

    def rot_z(t):
        # Rotation about Z by angle(s) t, shape t.shape + (3, 3)
        c, s = torch.cos(t), torch.sin(t)
        zero, one = torch.zeros_like(t), torch.ones_like(t)
        return torch.stack(
            [
                torch.stack([c, -s, zero], dim=-1),
                torch.stack([s, c, zero], dim=-1),
                torch.stack([zero, zero, one], dim=-1),
            ],
            dim=-2,
        )

    def rot_y(t):
        # Rotation about Y by angle(s) t, shape t.shape + (3, 3)
        c, s = torch.cos(t), torch.sin(t)
        zero, one = torch.zeros_like(t), torch.ones_like(t)
        return torch.stack(
            [
                torch.stack([c, zero, s], dim=-1),
                torch.stack([zero, one, zero], dim=-1),
                torch.stack([-s, zero, c], dim=-1),
            ],
            dim=-2,
        )

    # Final rotation: R = Rz(alpha) @ Ry(beta) @ Rz(gamma)
    return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)


def matrix_to_xyz_angles(matrices):
    """
    Convert rotation matrices to roll/pitch/yaw (xyz) Euler angles.

    The angles satisfy ``R = Rz(yaw) @ Ry(pitch) @ Rx(roll)``, the convention
    of ``xyz_angles_to_matrix``. This is rotation about the fixed x, then y,
    then z axes (extrinsic x-y-z, equivalently intrinsic Z-Y'-X''). For it

        R[2, 0] = -sin(pitch)
        R[2, 1] =  cos(pitch) sin(roll)    R[2, 2] = cos(pitch) cos(roll)
        R[1, 0] =  sin(yaw) cos(pitch)     R[0, 0] = cos(yaw) cos(pitch)

    Args:
        matrices: torch.Tensor (or anything ``torch.tensor`` accepts) of shape
            (..., 3, 3) rotation matrices. The input is cast to float32.

    Returns:
        roll, pitch, yaw: torch tensors of shape (...) in radians
        - roll: rotation around x-axis [-π, π]
        - pitch: rotation around y-axis [-π/2, π/2]
        - yaw: rotation around z-axis [-π, π]

        In gimbal lock (|cos(pitch)| < 1e-6) only yaw ∓ roll is determined;
        roll is then set to 0 and yaw carries the whole rotation.
    """
    if not isinstance(matrices, torch.Tensor):
        matrices = torch.tensor(matrices, dtype=torch.float32)

    # Ensure we're working with float tensors
    matrices = matrices.float()

    # Entries used below
    r00 = matrices[..., 0, 0]
    r10 = matrices[..., 1, 0]
    r11 = matrices[..., 1, 1]
    r12 = matrices[..., 1, 2]
    r20 = matrices[..., 2, 0]
    r21 = matrices[..., 2, 1]
    r22 = matrices[..., 2, 2]

    # Pitch (Y rotation) - middle angle with limited range; the clamp guards
    # against round-off pushing -R[2, 0] outside [-1, 1].
    pitch = torch.asin(torch.clamp(-r20, -1.0, 1.0))

    # Gimbal lock when cos(pitch) ≈ 0
    is_singular = torch.abs(torch.cos(pitch)) < 1e-6

    # Non-singular case. cos(pitch) >= 0 on [-π/2, π/2] is a common positive
    # factor of both atan2 arguments, so it cancels; dividing by it would
    # only create 0/0 = NaN (and NaN gradients) at the singularity.
    roll = torch.atan2(r21, r22)
    yaw = torch.atan2(r10, r00)

    # Singular case (pitch ≈ ±π/2), with roll fixed to 0:
    #   pitch ≈ +π/2: R[1, 1] = cos(yaw - roll), R[1, 2] =  sin(yaw - roll)
    #   pitch ≈ -π/2: R[1, 1] = cos(yaw + roll), R[1, 2] = -sin(yaw + roll)
    yaw_singular = torch.where(
        pitch > 0,
        torch.atan2(r12, r11),  # pitch ≈ π/2
        torch.atan2(-r12, r11),  # pitch ≈ -π/2
    )

    # Select based on singularity
    roll = torch.where(is_singular, torch.zeros_like(pitch), roll)
    yaw = torch.where(is_singular, yaw_singular, yaw)

    return roll, pitch, yaw


def xyz_angles_to_matrix(roll, pitch, yaw):
    """
    Build rotation matrices from roll, pitch and yaw.

    The result is ``Rz(yaw) @ Ry(pitch) @ Rx(roll)``. Roll about x is applied
    first, then pitch about y, then yaw about z, all about the fixed frame
    axes (extrinsic x-y-z, equivalently intrinsic Z-Y'-X'').

    Args:
        roll: x rotation angle in radians, typically in [-π, π].
        pitch: y rotation angle in radians, typically in [-π/2, π/2].
        yaw: z rotation angle in radians, typically in [-π, π].
        Tensors are used as given; anything else is converted with
        ``torch.tensor(..., dtype=torch.float32)``.

    Returns:
        torch.Tensor of shape (..., 3, 3). Each elementary matrix has its
        angle's shape plus (3, 3); the leading dimensions of the product
        follow ``torch.matmul`` broadcasting.
    """

    def as_angle_tensor(value):
        if isinstance(value, torch.Tensor):
            return value
        return torch.tensor(value, dtype=torch.float32)

    def from_rows(angle, *entries):
        # Nine tensors shaped like `angle`, listed row by row -> (..., 3, 3)
        return torch.stack(entries, dim=-1).reshape(*angle.shape, 3, 3)

    roll = as_angle_tensor(roll)
    pitch = as_angle_tensor(pitch)
    yaw = as_angle_tensor(yaw)

    # Rotation about x by roll
    c, s = torch.cos(roll), torch.sin(roll)
    one, nil = torch.ones_like(roll), torch.zeros_like(roll)
    rot_x = from_rows(roll, one, nil, nil, nil, c, -s, nil, s, c)

    # Rotation about y by pitch
    c, s = torch.cos(pitch), torch.sin(pitch)
    one, nil = torch.ones_like(pitch), torch.zeros_like(pitch)
    rot_y = from_rows(pitch, c, nil, s, nil, one, nil, -s, nil, c)

    # Rotation about z by yaw
    c, s = torch.cos(yaw), torch.sin(yaw)
    one, nil = torch.ones_like(yaw), torch.zeros_like(yaw)
    rot_z = from_rows(yaw, c, -s, nil, s, c, nil, nil, nil, one)

    # Compose right-to-left: roll is applied first, yaw last
    return rot_z @ rot_y @ rot_x


def robust_blogm_so3(R, eps=1e-7):
    """
    Robust batch logarithm for SO(3) matrices using Rodrigues' formula.

    For a rotation R = exp(theta * [n]_x) (unit axis n, angle theta in
    [0, π]) this returns theta * [n]_x, where [n]_x is the skew-symmetric
    cross-product matrix of n. The angle comes from the trace,
    cos(theta) = (tr(R) - 1) / 2, and three regimes are handled separately:

    * near identity (cos(theta) >= 1 - eps): first-order approximation
      log(R) ≈ (R - R^T) / 2;
    * regular angles (theta < π - 0.2):
      log(R) = theta / (2 sin(theta)) * (R - R^T);
    * near 180° (theta >= π - 0.2): sin(theta) is small, so the axis is read
      from the symmetric part (R + R^T)/2 - cos(theta) I = (1 - cos(theta)) n n^T,
      which fixes n only up to sign. The sign is then chosen to agree with
      the skew part (R - R^T)/2 = sin(theta) [n]_x.

    Args:
        R: [..., 3, 3] rotation matrices.
        eps: small value for numerical stability. cos(theta) is clamped to
            [-1 + eps, 1 - eps] before ``acos``, whose derivative is infinite
            at ±1. This also caps theta slightly below π. eps additionally
            guards the divisions.

    Returns:
        log_R: [..., 3, 3] skew-symmetric matrices.

    Raises:
        AssertionError: if the last two dimensions of R are not (3, 3).
    """
    # Ensure we're working with 3x3 matrices
    assert R.shape[-2:] == (3, 3), f"Expected 3x3 matrices, got {R.shape}"

    batch_shape = R.shape[:-2]
    R_flat = R.reshape(-1, 3, 3)

    # Rotation angle from the trace. The unclamped cosine is kept for the
    # regime test and for the near-π axis extraction.
    trace = torch.diagonal(R_flat, dim1=-2, dim2=-1).sum(-1)
    cos_raw = (trace - 1) / 2
    theta = torch.acos(torch.clamp(cos_raw, -1 + eps, 1 - eps))

    # Antisymmetric part shared by all branches: R - R^T = 2 sin(theta) [n]_x
    skew = R_flat - R_flat.transpose(-2, -1)

    # Initialize output; entries matching no branch (e.g. NaN input) stay 0
    log_R = torch.zeros_like(R_flat)

    # Case 1: near identity. The clamp keeps theta >= acos(1 - eps) ≈ sqrt(2 eps),
    # which is > eps, so a `theta < eps` test could never fire. Test the
    # cosine instead, i.e. exactly where the clamp distorts theta.
    small_mask = cos_raw >= 1 - eps
    if small_mask.any():
        log_R[small_mask] = skew[small_mask] / 2

    # Case 2: regular angles, standard Rodrigues formula
    regular_mask = ~small_mask & (theta < np.pi - 0.2)
    if regular_mask.any():
        theta_regular = theta[regular_mask]
        factor = theta_regular / (2 * torch.sin(theta_regular) + eps)
        log_R[regular_mask] = factor[:, None, None] * skew[regular_mask]

    # Case 3: near 180-degree rotations
    near_pi_mask = theta >= np.pi - 0.2
    if near_pi_mask.any():
        theta_pi = theta[near_pi_mask]
        R_pi = R_flat[near_pi_mask]
        skew_pi = skew[near_pi_mask]
        I = torch.eye(3, device=R.device, dtype=R.dtype)

        # (R + R^T)/2 - cos(theta) I = (1 - cos(theta)) n n^T holds for every
        # theta. R + I equals 2 n n^T only at theta = π exactly; elsewhere its
        # skew component tilts the extracted axis.
        nnT = 0.5 * (R_pi + R_pi.transpose(-2, -1)) - cos_raw[near_pi_mask][:, None, None] * I

        # The column with the largest diagonal entry is the best-conditioned
        # multiple of n (column k = (1 - cos) n_k n)
        max_idx = torch.argmax(torch.diagonal(nnT, dim1=-2, dim2=-1), dim=-1)
        cols = torch.gather(nnT, -1, max_idx[:, None, None].expand(-1, 3, 1)).squeeze(-1)
        axes = cols / torch.norm(cols, dim=-1, keepdim=True).clamp_min(eps)

        # Resolve the sign. R n = n holds for both n and -n, so use the skew
        # part instead: vee((R - R^T)) = 2 sin(theta) n with sin(theta) >= 0.
        # At theta = π exactly it vanishes and both signs are valid.
        skew_vec = torch.stack([skew_pi[:, 2, 1], skew_pi[:, 0, 2], skew_pi[:, 1, 0]], dim=-1)
        flip = (axes * skew_vec).sum(-1, keepdim=True) < 0
        axes = torch.where(flip, -axes, axes)

        # Convert to skew-symmetric matrix: log(R) = θ * [n]_×
        n1, n2, n3 = axes.unbind(-1)
        zero = torch.zeros_like(n1)
        skew_axis = torch.stack(
            [
                torch.stack([zero, -n3, n2], dim=-1),
                torch.stack([n3, zero, -n1], dim=-1),
                torch.stack([-n2, n1, zero], dim=-1),
            ],
            dim=-2,
        )
        log_R[near_pi_mask] = theta_pi[:, None, None] * skew_axis

    # Reshape back to original batch shape
    return log_R.reshape(*batch_shape, 3, 3)


def quaternion_to_matrix(q):
    """
    Convert quaternions to rotation matrices.

    Args:
        q: tensor of shape (..., 4) with components ordered (w, x, y, z),
            scalar part first. Expected to be unit-norm. Every entry is a
            homogeneous quadratic in q, so a non-unit quaternion yields
            |q|^2 times the rotation matrix.

    Returns:
        Tensor of shape (..., 3, 3).
    """
    w, x, y, z = q.unbind(-1)

    # Squares feed the diagonal
    w2, x2, y2, z2 = w * w, x * x, y * y, z * z
    # Cross products feed the off-diagonal entries
    xy, xz, yz = x * y, x * z, y * z
    wx, wy, wz = w * x, w * y, w * z

    # Row-major list of the nine matrix entries
    entries = [
        w2 + x2 - y2 - z2, 2 * (xy - wz), 2 * (xz + wy),
        2 * (xy + wz), w2 - x2 + y2 - z2, 2 * (yz - wx),
        2 * (xz - wy), 2 * (yz + wx), w2 - x2 - y2 + z2,
    ]
    return torch.stack(entries, dim=-1).reshape(*w.shape, 3, 3)
