import torch

def qmul(q, r):
    """
    Multiply quaternion(s) q with quaternion(s) r (Hamilton product q*r).

    Quaternions are stored scalar-first as (w, x, y, z) in the last dimension.
    Expects two equally-sized tensors of shape (*, 4), where * denotes any
    number of dimensions. Inputs may be non-contiguous (e.g. the result of a
    transpose or a strided slice).

    Returns q*r as a tensor with the same shape as q.

    Raises AssertionError if the last dimension of q or r is not 4.
    """
    assert q.shape[-1] == 4
    assert r.shape[-1] == 4

    original_shape = q.shape

    # Per-pair outer product: terms[n, i, j] = r[n, i] * q[n, j].
    # reshape (rather than view) copies only when the memory layout requires
    # it, so non-contiguous inputs no longer raise a RuntimeError here.
    terms = torch.bmm(r.reshape(-1, 4, 1), q.reshape(-1, 1, 4))

    # Combine the sixteen products according to the Hamilton product rules.
    w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
    x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
    y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
    z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]

    # torch.stack returns a fresh contiguous tensor, so view is safe here.
    return torch.stack((w, x, y, z), dim=1).view(original_shape)

def qinv(q):
    """
    Return the conjugate of quaternion(s) q.

    q is a tensor of shape (*, 4) stored scalar-first (w, x, y, z). The scalar
    part is kept and the three vector components are negated. For unit
    quaternions the conjugate equals the inverse; no normalisation is applied.

    Raises AssertionError if the last dimension of q is not 4.
    """
    assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
    scalar_part = q[..., :1]
    vector_part = q[..., 1:]
    return torch.cat((scalar_part, -vector_part), dim=-1)

def qnormalize(q):
    """
    Scale quaternion(s) q to unit length.

    q is a tensor of shape (*, 4); each quaternion is divided by its Euclidean
    norm taken over the last dimension. A zero quaternion is not special-cased
    and yields NaN components (0 / 0).

    Raises AssertionError if the last dimension of q is not 4.
    """
    assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
    magnitude = q.norm(dim=-1, keepdim=True)
    return q / magnitude

def qbetween(v0, v1):
    """
    Find the unit quaternion that rotates v0 onto v1.

    v0, v1: tensors of shape (*, 3). They do not need to be normalised.
    Returns a tensor of shape (*, 4), stored scalar-first (w, x, y, z).

    The quaternion is built as normalize([|v0||v1| + v0.v1, v0 x v1]). When v0
    and v1 point in exactly opposite directions both parts are zero and the
    result is NaN; use qbetween_safe if that case can occur.

    Raises AssertionError if the last dimension of v0 or v1 is not 3.
    """
    assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
    assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'

    # dim must be explicit: without it torch.cross picks the FIRST dimension
    # of size 3, which is wrong for inputs shaped e.g. (3, 3) or (3, N, 3).
    v = torch.cross(v0, v1, dim=-1)
    w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1, keepdim=True)

    # Same normalisation as qnormalize: divide by the norm over the last dim.
    quat = torch.cat([w, v], dim=-1)
    return quat / torch.norm(quat, dim=-1, keepdim=True)

def qbetween_safe(v0: torch.Tensor, v1: torch.Tensor) -> torch.Tensor:
    """
    Compute the quaternion rotating v0 onto v1, with explicit handling of the
    (nearly) parallel and (nearly) anti-parallel cases.

    v0, v1: (..., 3) direction vectors; both are normalised internally.
    Return: (..., 4) quaternion, stored scalar-first (w, x, y, z).

    Selection per vector pair, based on the cosine of the angle between them:
      * cosine >  0.999 -> the identity quaternion (1, 0, 0, 0)
      * cosine < -0.999 -> a half-turn (w = 0) about an axis orthogonal to v0
      * otherwise       -> normalize([1 + cosine, v0 x v1])
    Zero-length inputs are not guarded and produce NaN.
    """
    unit0 = v0 / v0.norm(dim=-1, keepdim=True)
    unit1 = v1 / v1.norm(dim=-1, keepdim=True)
    cos_angle = (unit0 * unit1).sum(dim=-1, keepdim=True)  # (..., 1)

    # General case: scalar part |a||b| + a.b, vector part a x b, then normalise.
    sq_len0 = (unit0 ** 2).sum(dim=-1, keepdim=True)
    sq_len1 = (unit1 ** 2).sum(dim=-1, keepdim=True)
    scalar_part = torch.sqrt(sq_len0 * sq_len1) + cos_angle
    vector_part = torch.cross(unit0, unit1, dim=-1)
    generic = torch.cat([scalar_part, vector_part], dim=-1)
    generic = generic / generic.norm(dim=-1, keepdim=True)

    # Parallel case: no rotation needed.
    identity = torch.cat(
        [torch.ones_like(scalar_part), torch.zeros_like(vector_part)], dim=-1
    )

    # Anti-parallel case: the cross product above vanishes, so choose a
    # reference axis that is not (nearly) collinear with v0 and rotate by 180
    # degrees about the direction orthogonal to both.
    x_axis = torch.tensor([1.0, 0.0, 0.0], dtype=unit0.dtype, device=unit0.device).expand_as(unit0)
    y_axis = torch.tensor([0.0, 1.0, 0.0], dtype=unit0.dtype, device=unit0.device).expand_as(unit0)
    x_is_usable = torch.abs((unit0 * x_axis).sum(dim=-1)) < 0.9
    reference = torch.where(x_is_usable.unsqueeze(-1), x_axis, y_axis)
    flip_axis = torch.cross(unit0, reference, dim=-1)
    flip_axis = flip_axis / flip_axis.norm(dim=-1, keepdim=True)
    half_turn = torch.cat([torch.zeros_like(cos_angle), flip_axis], dim=-1)

    # The two special cases are mutually exclusive; the (..., 1) masks
    # broadcast over the four quaternion components.
    is_parallel = cos_angle > 0.999
    is_antiparallel = cos_angle < -0.999
    return torch.where(
        is_antiparallel,
        half_turn,
        torch.where(is_parallel, identity, generic),
    )

def qrot(q, v):
    """
    Rotate vector(s) v by the rotation described by quaternion(s) q.

    q: tensor of shape (*, 4), stored scalar-first (w, x, y, z).
    v: tensor of shape (*, 3); the leading dimensions must equal those of q
       exactly (no broadcasting).
    Returns a tensor with the same shape as v.

    Uses the cross-product form v' = v + 2 * (w * (u x v) + u x (u x v)),
    where u is the vector part of q. q is not normalised here.

    Raises AssertionError on a wrong last dimension or mismatched leading
    dimensions.
    """
    assert q.shape[-1] == 4
    assert v.shape[-1] == 3
    assert q.shape[:-1] == v.shape[:-1]

    out_shape = v.shape

    # Flatten all leading dimensions so every row is one (quaternion, vector) pair.
    flat_q = q.reshape(-1, 4)
    flat_v = v.reshape(-1, 3)

    scalar = flat_q[:, :1]
    axis_part = flat_q[:, 1:]
    first_cross = torch.cross(axis_part, flat_v, dim=-1)
    second_cross = torch.cross(axis_part, first_cross, dim=-1)

    rotated = flat_v + 2 * (scalar * first_cross + second_cross)
    return rotated.view(out_shape)

def rigid_transform_tensor(relative, data):
    """
    Apply a per-sample rigid transform (rotation, then x/z translation) to the
    joint positions and velocities packed in `data`.

    relative: indexed as relative[:, 0] (rotation parameter) and
        relative[:, 1:3] (offsets added to position components 0 and 2), so it
        is expected to be (B, >=3).
    data: the indexing below requires exactly three dimensions (B, T, D).
        Features [0, 66) are read as 22 joints x 3 position components and
        features [66, 132) as 22 joints x 3 velocity components.
        NOTE(review): the 22-joint layout is hard-coded; confirm it matches
        the skeleton in use.

    The quaternion built here has scalar part cos(r) and only the y vector
    component set to sin(r), with r = relative[:, 0]; its conjugate is applied
    to positions and velocities. Under the usual unit-quaternion convention
    that is a rotation of 2*r about the y axis, so r is presumably a
    half-angle -- confirm against whatever produces `relative`.

    `data` is modified in place (its first 132 features are overwritten) and
    the same tensor is returned.
    """
    num_joints = 22
    pos_end = num_joints * 3
    vel_end = num_joints * 6

    batch_size, _num_frames, *_ = data.shape
    lead_shape = data.shape[:-1]

    positions = data[..., :pos_end].reshape(*lead_shape, num_joints, 3)         # [*, 22, 3]
    velocities = data[..., pos_end:vel_end].reshape(*lead_shape, num_joints, 3)  # [*, 22, 3]

    # One rotation parameter per sample, broadcast over frames and joints.
    angle = relative[:, 0]
    cos_term = torch.cos(angle).view(batch_size, 1, 1)
    sin_term = torch.sin(angle).view(batch_size, 1, 1)
    rot_quat = torch.zeros(*positions.shape[:-1], 4, device=positions.device, dtype=positions.dtype)  # [*, 22, 4]
    rot_quat[..., 0] = cos_term
    rot_quat[..., 2] = sin_term

    # Rotate by the conjugate; velocities are rotated but never translated.
    inverse_quat = qinv(rot_quat)
    positions = qrot(inverse_quat, positions)
    velocities = qrot(inverse_quat, velocities)

    # Shift position components 0 and 2 by the per-sample offsets.
    offset = relative[:, 1:3]
    positions[:, :, :, 0] += offset[:, None, None, 0]
    positions[:, :, :, 2] += offset[:, None, None, 1]

    # Write the transformed values back into the caller's tensor.
    data[..., :pos_end] = positions.reshape(*lead_shape, pos_end)
    data[..., pos_end:vel_end] = velocities.reshape(*lead_shape, pos_end)

    return data

