import torch
import numpy as np



def skew_symmetric(omega):
    """
    Convert a batch of angular velocities to skew-symmetric matrices.

    For each row ``w = (wx, wy, wz)`` the resulting matrix is::

        [[  0, -wz,  wy],
         [ wz,   0, -wx],
         [-wy,  wx,   0]]

    omega: tensor of shape (batch_size, 3)
    Returns a tensor of shape (batch_size, 3, 3) on the same device as omega.
    Floating-point inputs keep their dtype; any other input dtype falls back
    to torch's default dtype.
    """
    batch_size = omega.shape[0]
    # Follow the input's floating dtype so that e.g. float64 inputs are not
    # silently truncated to float32 when written into the output buffer.
    dtype = omega.dtype if omega.is_floating_point() else torch.get_default_dtype()
    Omega = torch.zeros((batch_size, 3, 3), dtype=dtype, device=omega.device)
    Omega[:, 0, 1] = -omega[:, 2]
    Omega[:, 0, 2] = omega[:, 1]
    Omega[:, 1, 0] = omega[:, 2]
    Omega[:, 1, 2] = -omega[:, 0]
    Omega[:, 2, 0] = -omega[:, 1]
    Omega[:, 2, 1] = omega[:, 0]
    return Omega


def omega_from_skew_symmetric(logm_matrix):
    """
    Read the 3-vector out of a batch of skew-symmetric matrices.

    logm_matrix: batch of matrices indexed as (batch, row, col)
    Returns, per batch entry, the entries at (2, 1), (0, 2) and (1, 0),
    in that order.
    """
    # Paired row/column positions of the three entries to extract.
    rows = [2, 0, 1]
    cols = [1, 2, 0]
    return logm_matrix[:, rows, cols]


def expm_SO3(Omega, dt):
    """
    Evaluate expm(skew(Omega) * dt) for a batch of vectors in closed form
    using Rodrigues' formula.

    Omega: tensor of shape (batch_size, 3)
    dt: scalar time step
    Returns a tensor of shape (batch_size, 3, 3)
    """
    # Per-sample vector norm, shaped (batch_size, 1, 1) so it broadcasts over 3x3 matrices.
    speed = torch.norm(Omega, dim=1)[:, None, None]
    # Skew matrix divided by the norm; the small offset keeps the division
    # finite when the input vector is zero.
    axis_hat = skew_symmetric(Omega) / (speed + 1e-8)

    angle = speed * dt
    identity = torch.eye(3, device=Omega.device).unsqueeze(0)  # shape (1, 3, 3), broadcast over the batch
    first_order = axis_hat * torch.sin(angle)
    second_order = torch.bmm(axis_hat, axis_hat) * (1 - torch.cos(angle))

    return identity + first_order + second_order


def logm_SO3(R):
    """
    Compute the matrix logarithm for a batch of rotation matrices.

    R: tensor of shape (batch_size, 3, 3)
    Returns a tensor of shape (batch_size, 3, 3) equal to
    theta * (R - R^T) / (2 * sin(theta)), where
    theta = arccos((trace(R) - 1) / 2) and the trace is first limited to
    the interval [-1 + 1e-6, 3 - 1e-6].
    """
    margin = 1e-6
    # Keep the trace away from -1 and 3 so that sin(theta) in the
    # denominator below never becomes exactly zero.
    diag_sum = torch.einsum('bii->b', R).clamp(min=-1 + margin, max=3 - margin)

    angle = torch.arccos((diag_sum - 1) / 2)[:, None, None]
    antisym = R - R.transpose(1, 2)
    return angle * antisym / (2 * torch.sin(angle))


def geodesic_average_rotation_matrices(rotations, eps: float = 1e-7, max_iter: int = 100) -> np.ndarray:
    """
    Iteratively estimate the mean of a batch of rotation matrices.

    Starting from the first rotation, each iteration expresses every rotation
    relative to the current estimate, takes the matrix logarithm
    (``logm_SO3``), averages the resulting 3-vectors, and right-multiplies the
    estimate by the exponential of that average (``expm_SO3``). Iteration
    stops once the averaged vector has norm below ``eps`` or after
    ``max_iter`` updates.

    rotations: tensor of shape (n, 3, 3)
    eps: convergence threshold on the norm of the averaged update vector
    max_iter: maximum number of update iterations
    Returns the estimate as a NumPy array of shape (3, 3). The result is
    detached from autograd and moved to the CPU before conversion.
    """
    n = len(rotations)
    R = rotations[0].clone()  # Initialize with first rotation

    if n == 1:
        # Tensor.numpy() refuses tensors that require grad or live off-CPU,
        # so detach and move to host memory before converting.
        return R.detach().cpu().numpy()

    for _ in range(max_iter):
        # Relative rotations R^T @ rotations[i], mapped to 3-vectors.
        logm_res = logm_SO3(torch.bmm(R.T.repeat(n, 1, 1), rotations))
        r_all = omega_from_skew_symmetric(logm_res)
        r = r_all.mean(dim=0, keepdim=True)

        if torch.linalg.norm(r) < eps:
            break

        # Update R
        expm_res = expm_SO3(r, dt=1)[0]
        R = R @ expm_res

    return R.detach().cpu().numpy()


def compute_mean_rotation(quaternions):
    """Compute the mean rotation from quaternions using vectorized averaging.

    The quaternions are sign-aligned with the first one (q and -q describe
    the same rotation), averaged component-wise, and re-normalized.

    Args:
        quaternions: Tensor of rotation quaternions of shape (N, 4), or a
            sequence of quaternions, each of shape (4,) or (k, 4).

    Returns:
        Mean quaternion of shape (4,). When the input has length 1, its
        single element is returned as-is (no sign alignment or
        re-normalization).
    """

    if len(quaternions) == 1:
        return quaternions[0]

    # Convert to torch tensor if not already
    if not isinstance(quaternions, torch.Tensor):
        # Bring every entry to shape (k, 4) before concatenating; a plain
        # torch.cat over shape-(4,) entries would produce a flat (4N,) tensor
        # and break the row-wise reductions below.
        quaternions = torch.cat(
            [torch.as_tensor(q).reshape(-1, 4) for q in quaternions]
        )

    # Handle quaternion antipodality by ensuring all quaternions are in same hemisphere
    # Compare all to first quaternion and flip if needed
    dots = torch.sum(quaternions[0:1] * quaternions, dim=1)
    quaternions = torch.where(dots.unsqueeze(1) < 0, -quaternions, quaternions)

    # Simple linear average
    mean_quat = torch.mean(quaternions, dim=0)

    # Normalize
    mean_quat = mean_quat / torch.norm(mean_quat)

    return mean_quat
