import torch
from torch.func import jacrev


def tr_loss(v_value, x0, x1):
    """
    Squared-error translation loss between the predicted velocity and the straight-line velocity.

    For every row the loss is ``sum_d (v_value + x0 - x1)[d] ** 2``, i.e. the squared distance
    between ``v_value`` and ``x1 - x0``.

    Args:
        v_value (Tensor): The tensor representing the value of v(x(t), t); at least 2-D, the
            squared residual is summed over dim 1.
        x0 (Tensor): The initial translation tensor (same shape as v_value).
        x1 (Tensor): The target translation tensor (same shape as v_value).

    Returns:
        Tensor: Per-row translation loss (dim 1 reduced by summation).
    """
    residual = v_value + x0 - x1
    return torch.sum(residual.pow(2), dim=1)


def tor_loss(v_value, theta0, theta1, bond_periods):
    """
    Per-torsion squared-error loss against the shortest periodic angular difference.

    The target is ``theta1 - theta0`` wrapped into ``[-bond_periods / 2, bond_periods / 2)``,
    and the loss per torsion is ``2 * (v_value - target) ** 2``.

    Args:
        v_value (Tensor): The tensor representing the torsion value of v(x(t), t); flattened.
        theta0 (Tensor): The initial torsion tensor; flattened.
        theta1 (Tensor): The target torsion tensor; flattened.
        bond_periods (Tensor): The bond periods; must broadcast against the flattened torsions.

    Returns:
        Tensor: Unreduced 1-D tensor with one loss value per torsion.
    """
    pred = v_value.flatten()
    start = theta0.flatten()
    end = theta1.flatten()

    half_period = bond_periods / 2
    # Python-style modulo keeps the wrapped difference in [-period/2, period/2).
    shortest = torch.remainder(end - start + half_period, bond_periods) - half_period
    # NOTE: returned per torsion; the original author noted it still needs averaging over
    # the bond dimension by the caller.
    return 2 * (pred - shortest) ** 2


def batched_tor_loss(v_value, theta0, theta1, tor_is_padded_mask, bond_periods):
    """
    Padded-batch torsion loss: per-sample mean over real torsions, then mean over the batch.

    Each torsion contributes ``2 * (v_value - wrapped(theta1 - theta0)) ** 2``, where the
    difference is wrapped into ``[-bond_periods / 2, bond_periods / 2)``. Padded entries are
    zeroed. Each sample's sum is divided by its number of non-padded torsions. A sample with no
    real torsions uses a denominator of 1, so it contributes 0 but still counts in the batch mean.

    Args:
        v_value (Tensor): Predicted torsion velocities, 2-D (samples x padded torsions).
        theta0 (Tensor): The initial torsion tensor, same layout as v_value.
        theta1 (Tensor): The target torsion tensor, same layout as v_value.
        tor_is_padded_mask (Tensor): Boolean mask, True where the entry is padding.
        bond_periods (Tensor): The bond periods, broadcastable against the torsion tensors.

    Returns:
        Tensor: Scalar loss.
    """
    half_period = bond_periods / 2
    wrapped_diff = (theta1 - theta0 + half_period) % bond_periods - half_period
    per_torsion = 2 * (v_value - wrapped_diff) ** 2
    per_torsion[tor_is_padded_mask] = 0

    num_real = tor_is_padded_mask.shape[1] - tor_is_padded_mask.sum(1)
    # Avoid dividing by zero for samples that have no torsions at all.
    num_real = torch.where(num_real != 0, num_real, torch.ones_like(num_real))

    per_sample = per_torsion.sum(1) / num_real
    return per_sample.mean()


def rot_loss(v_value, tangents):
    """
    Rotation loss: per-row ``2 * sum((v_value - tangents) ** 2)``.

    Args:
        v_value (Tensor): The rotation value of v(x(t), t); must be 2-D.
        tangents (Tensor): The target tangent vectors; must have exactly v_value's shape.

    Returns:
        Tensor: 1-D tensor with one loss value per row.

    Raises:
        ValueError: If the shapes differ or the inputs are not 2-D.
    """
    if v_value.shape != tangents.shape:
        raise ValueError('v_values.shape != tangents.shape')
    if v_value.dim() != 2:
        raise ValueError('len(v_value.shape) != 2')

    residual = v_value - tangents
    return 2 * residual.pow(2).sum(dim=1)


def interpolate_tr(x0, x1, t):
    """
    Linearly interpolate translations: ``x0 + t * (x1 - x0)`` with one t per row.

    Args:
        x0 (Tensor): The initial translation tensor of shape (N, D), where N is the batch size
                     and D is the dimension of the translation vector.
        x1 (Tensor): The target translation tensor of shape (N, D).
        t (Tensor): The interpolation parameter of shape (N,), each element in [0, 1].

    Returns:
        Tensor: The interpolated translations of shape (N, D).
    """
    displacement = x1 - x0
    weights = t.unsqueeze(1)  # (N,) -> (N, 1) so it broadcasts across the D columns
    return x0 + weights * displacement


def interpolate_tor(theta0, theta1, t, num_tors_per_sample_arr, bond_periods):
    """
    Interpolate between theta0 and theta1 based on t for torsion, along the shortest arc.

    The angular difference ``theta1 - theta0`` is wrapped into
    ``[-bond_periods / 2, bond_periods / 2)`` with a modulo. This is the same formula that
    ``tor_loss`` and ``batched_tor_loss`` use for their regression target, so the
    interpolated path and the target velocity agree. A single +/- one-period shift is not
    enough when the difference exceeds 1.5 periods. That can happen when ``bond_periods`` is
    smaller than the range the angles live in.

    Args:
        theta0 (Tensor): The initial torsion tensor of size (M, ), where M is the number of
                        rotatable bonds in the batch.
        theta1 (Tensor): The target torsion tensor of size (M, ).
        t (Tensor): The tensor of time values of size (N, ), where N is the batch size.
        num_tors_per_sample_arr (Tensor): Tensor of size (N, ) with the number of rotatable
                        bonds per sample; ``t`` is repeated accordingly so every torsion gets
                        its sample's time.
        bond_periods (Tensor or float): The bond periods, either of size (M, ) or a scalar
                        (e.g. 2*pi; can be smaller for bonds with symmetric parts).

    Returns:
        Tensor: The interpolated torsions of size (M, ).

    Raises:
        ValueError: If any of theta0, theta1, t, num_tors_per_sample_arr is not
            one-dimensional, or if num_tors_per_sample_arr does not add up to M.
    """
    if len(theta0.shape) > 1 or len(theta1.shape) > 1 or len(t.shape) > 1 or len(num_tors_per_sample_arr.shape) > 1:
        raise ValueError("Input arrays must be one-dimensional!")
    diff = theta1 - theta0
    # Keep the arithmetic in the angles' dtype/device; also accepts a plain Python float period.
    periods = torch.as_tensor(bond_periods, dtype=diff.dtype, device=diff.device)
    half_periods = periods / 2
    # Python-style modulo (sign follows the divisor) gives a result in [-period/2, period/2).
    diff = (diff + half_periods) % periods - half_periods
    t = torch.repeat_interleave(t, num_tors_per_sample_arr)
    if t.numel() != diff.numel():
        # Without this check a mismatch could silently broadcast (e.g. a single repeated t).
        raise ValueError(
            f"num_tors_per_sample_arr sums to {t.numel()}, but there are {diff.numel()} torsions"
        )
    return theta0 + t * diff


def interpolate_tor_new(theta0, theta1, t):
    """
    Interpolate padded torsion batches along the shorter arc of a 2*pi period.

    The difference ``theta1 - theta0`` is shifted by one period into roughly [-pi, pi]:
    values > pi lose 2*pi, and values < -pi gain 2*pi.

    Args:
        theta0 (Tensor): The initial torsion tensor of size (batch_size, max_num_tors).
        theta1 (Tensor): The target torsion tensor of size (batch_size, max_num_tors).
        t (Tensor): The tensor of time values of size (batch_size, ).

    Returns:
        Tensor: The interpolated tensor, shape (batch_size, max_num_tors).
    """
    two_pi = 2 * torch.pi
    delta = theta1 - theta0
    delta = torch.where(delta > torch.pi, delta - two_pi, delta)
    delta = torch.where(delta < -torch.pi, delta + two_pi, delta)
    # One time value per row, broadcast over the torsion columns.
    return theta0 + t.reshape(-1, 1) * delta


# @torch.jit.script
def slerp_and_derivative(q0, q1, t, eps=1e-6):
    """
    Perform spherical linear interpolation (SLERP) between two quaternions and compute its derivative.
    Handles batching for multiple sets of quaternions.

    q0 is sign-flipped (without modifying the caller's tensor) wherever ``q0 . q1 < 0``, so the
    interpolation takes the shorter arc. q and -q encode the same rotation, so the result at
    t=0 may therefore be -q0. When ``sin(theta_0) <= eps`` (q0 and q1 almost parallel), linear
    interpolation and its derivative ``q1 - q0`` are used instead.

    Args:
        q0 (Tensor): The initial quaternion tensor of shape (B, 4), where B is the batch size.
        q1 (Tensor): The target quaternion tensor of shape (B, 4).
        t (Tensor): The interpolation parameter of shape (B,).
        eps (float): Threshold on sin(theta_0) below which linear interpolation is used.

    Returns:
        Tuple[Tensor, Tensor]: The slerp quaternions (B, 4) and their derivative w.r.t. t (B, 4).
    """
    t = t.view(-1, 1)
    dot_product = torch.sum(q0 * q1, dim=-1).view(-1, 1)
    # Out-of-place flip: the original in-place assignment mutated the caller's q0 and failed
    # under autograd for leaf tensors that require grad.
    q0 = torch.where(dot_product < 0, -q0, q0)
    dot_product = dot_product.abs()
    theta_0 = torch.acos(torch.clamp(dot_product, -1, 1))  # initial angle
    sin_theta_0 = torch.sin(theta_0)

    use_slerp = sin_theta_0 > eps
    # Denominator that is never ~0, so the branch discarded by torch.where has no inf/NaN values.
    safe_sin = torch.where(use_slerp, sin_theta_0, torch.ones_like(sin_theta_0))

    # If sin(theta_0) is (near) zero (q0 and q1 are parallel), use linear interpolation.
    slerp_t = torch.where(
        use_slerp,
        (torch.sin((1.0 - t) * theta_0) * q0 + torch.sin(t * theta_0) * q1) / safe_sin,
        (1.0 - t) * q0 + t * q1,
    )

    # d/dt of the expressions above.
    slerp_derivative_t = torch.where(
        use_slerp,
        (-torch.cos((1.0 - t) * theta_0) * q0 + torch.cos(t * theta_0) * q1) * theta_0 / safe_sin,
        q1 - q0,
    )
    return slerp_t, slerp_derivative_t


def q_to_rotmat(q):
    """
    Convert quaternions in XYZW order to rotation matrices.

    Each quaternion is normalized to unit length first, so non-unit inputs are accepted.

    Args:
        q (Tensor): The quaternion tensor of shape (B, 4), where B is the batch size.

    Returns:
        Tensor: The corresponding rotation matrices of shape (B, 3, 3).
    """
    q = q / torch.linalg.norm(q, dim=1, keepdim=True)
    x, y, z, w = (q[:, i] for i in range(4))

    xx, yy, zz, ww = x * x, y * y, z * z, w * w
    xy, xz, yz = x * y, x * z, y * z
    xw, yw, zw = x * w, y * w, z * w

    # Row-major entries of R; diagonal terms use the unit-norm form x^2 - y^2 - z^2 + w^2 etc.
    entries = [
        xx - yy - zz + ww, 2 * (xy - zw), 2 * (xz + yw),
        2 * (xy + zw), -xx + yy - zz + ww, 2 * (yz - xw),
        2 * (xz - yw), 2 * (yz + xw), -xx - yy + zz + ww,
    ]
    return torch.stack(entries, dim=1).reshape(q.shape[0], 3, 3)


def rotmat_to_q(matrix):
    """
    Convert rotation matrices to unit quaternions (XYZW convention).

    Any number of leading batch dimensions is supported (including none): the input is
    flattened to (num_rotations, 3, 3), converted, and reshaped back.

    Args:
        matrix (Tensor): The rotation matrix tensor of shape (..., 3, 3).

    Returns:
        Tensor: The corresponding quaternion tensor of shape (..., 4), following the XYZW convention.

    Raises:
        ValueError: If the last two dimensions are not (3, 3).
    """
    # Taken from https://github.com/naver/roma/blob/ef10b1e4006c00cf8ee3bc7480994fb274d18841/roma/mappings.py#L328
    if matrix.shape[-2:] != (3, 3):
        raise ValueError(f"expected a tensor of shape (..., 3, 3), got {tuple(matrix.shape)}")
    batch_shape = matrix.shape[:-2]
    # Explicit size (not -1) so an empty batch still reshapes correctly.
    matrix = matrix.reshape(batch_shape.numel(), 3, 3)

    num_rotations = matrix.shape[0]
    # Columns 0..2 hold the diagonal, column 3 the trace; the largest entry selects the
    # numerically most stable formula for each rotation.
    decision_matrix = torch.empty((num_rotations, 4), dtype=matrix.dtype, device=matrix.device)
    decision_matrix[:, :3] = matrix.diagonal(dim1=1, dim2=2)
    decision_matrix[:, -1] = decision_matrix[:, :3].sum(axis=1)
    choices = decision_matrix.argmax(axis=1)

    quat = torch.empty((num_rotations, 4), dtype=matrix.dtype, device=matrix.device)

    # Case: a diagonal element dominates -> solve for the matching vector component first.
    ind = torch.nonzero(choices != 3, as_tuple=True)[0]
    i = choices[ind]
    j = (i + 1) % 3
    k = (j + 1) % 3

    quat[ind, i] = 1 - decision_matrix[ind, -1] + 2 * matrix[ind, i, i]
    quat[ind, j] = matrix[ind, j, i] + matrix[ind, i, j]
    quat[ind, k] = matrix[ind, k, i] + matrix[ind, i, k]
    quat[ind, 3] = matrix[ind, k, j] - matrix[ind, j, k]

    # Case: the trace dominates -> solve for the scalar (w) component first.
    ind = torch.nonzero(choices == 3, as_tuple=True)[0]
    quat[ind, 0] = matrix[ind, 2, 1] - matrix[ind, 1, 2]
    quat[ind, 1] = matrix[ind, 0, 2] - matrix[ind, 2, 0]
    quat[ind, 2] = matrix[ind, 1, 0] - matrix[ind, 0, 1]
    quat[ind, 3] = 1 + decision_matrix[ind, -1]

    quat = quat / torch.norm(quat, dim=1)[:, None]
    return quat.reshape(batch_shape + (4,))


def jac_q_to_rotmat(q):
    """
    Per-sample Jacobian of ``q_to_rotmat`` with respect to the quaternion.

    Args:
        q (Tensor): Quaternions, one per row (presumably shape (B, 4)).

    Returns:
        Tensor: For each sample, d R[i, j] / d q[k]; presumably shape (B, 3, 3, 4).
    """
    def _rotmat_of_single(quat):
        # q_to_rotmat expects a batch, so wrap the single quaternion as a batch of one.
        return q_to_rotmat(quat[None, :])[0]

    per_sample_jacobian = jacrev(_rotmat_of_single)
    return torch.vmap(per_sample_jacobian)(q)


def interpolate_rot_and_tangents(R0, R1, t):
    """
    SLERP between two batches of rotation matrices, plus the body-frame angular velocity.

    The matrices are converted to quaternions, slerped, and converted back. The tangent is read
    off the (ideally skew-symmetric) matrix ``Rt^T @ dRt/dt``, where dRt/dt comes from the
    chain rule through the quaternion-to-matrix Jacobian.

    Args:
        R0 (Tensor): The initial rotation matrices of shape (batch_size, 3, 3).
        R1 (Tensor): The target rotation matrices of shape (batch_size, 3, 3).
        t (Tensor): The interpolation parameter of shape (batch_size,).

    Returns:
        Tuple[Tensor, Tensor]:
            - Rt: The interpolated rotation matrices of shape (batch_size, 3, 3).
            - tangents: Components (W[2, 1], W[0, 2], W[1, 0]) of W = Rt^T dRt/dt,
              shape (batch_size, 3).
    """
    quat_start = rotmat_to_q(R0)
    quat_end = rotmat_to_q(R1)
    quat_t, quat_dot = slerp_and_derivative(quat_start, quat_end, t)

    Rt = q_to_rotmat(quat_t)
    # Chain rule: dR/dt = (dR/dq) . (dq/dt), contracting over the quaternion component axis.
    Rt_dot = torch.einsum('nrcq,nq->nrc', jac_q_to_rotmat(quat_t), quat_dot)
    # Rt^T . dR/dt
    body_omega = torch.einsum('nkr,nkc->nrc', Rt, Rt_dot)
    tangents = torch.stack(
        (body_omega[:, 2, 1], body_omega[:, 0, 2], body_omega[:, 1, 0]),
        dim=-1,
    )
    return Rt, tangents


def hat_map(batch):
    """
    Map 3-vectors to their skew-symmetric (cross-product) matrices in so(3).

    For w = (wx, wy, wz) the result is [[0, -wz, wy], [wz, 0, -wx], [-wy, wx, 0]].

    Args:
        batch (Tensor): A batch of vectors in R^3 of shape (batch_size, 3).

    Returns:
        Tensor: A batch of skew-symmetric matrices of shape (batch_size, 3, 3).
    """
    n = batch.shape[0]
    wx, wy, wz = batch[:, 0], batch[:, 1], batch[:, 2]
    zero = torch.zeros(n, device=batch.device, dtype=batch.dtype)
    # Row-major flattening of the 3x3 matrix, then fold back into (n, 3, 3).
    flat = torch.stack([zero, -wz, wy, wz, zero, -wx, -wy, wx, zero], dim=1)
    return flat.reshape(n, 3, 3)


def so3_to_SO3(batch):
    """
    Map a batch of so(3) elements, given as 3-vectors, to rotation matrices in SO(3).

    Uses Rodrigues' formula (the exponential map):
        R = I + A(theta) * W + B(theta) * W @ W,
    where W = hat_map(omega), theta = |omega|, A = sin(theta) / theta and
    B = (1 - cos(theta)) / theta**2. For theta < 1e-6 the limits A = 1, B = 0.5 are used.

    Args:
        batch (Tensor): so(3) elements as vectors of shape (batch_size, 3); they are turned
            into skew-symmetric matrices with ``hat_map``.

    Returns:
        Tensor: A batch of SO(3) elements of shape (batch_size, 3, 3).
    """
    small_angle = 1e-6
    omega_hat = hat_map(batch)
    theta = torch.norm(batch, dim=1, keepdim=True)  # (batch_size, 1)

    identity_matrices = torch.eye(3, device=batch.device, dtype=batch.dtype).unsqueeze(0)

    is_small = theta < small_angle
    # Evaluate the closed forms on a denominator that is never ~0, so the branch discarded by
    # torch.where contains no inf/NaN. Such values would otherwise leak into gradients at theta == 0.
    safe_theta = torch.where(is_small, torch.ones_like(theta), theta)
    A = torch.where(is_small, torch.ones_like(theta), torch.sin(safe_theta) / safe_theta)
    # 1 - cos(theta) == 2 * sin(theta / 2)**2; the right-hand side avoids catastrophic
    # cancellation when theta is small but above the cutoff.
    B = torch.where(
        is_small,
        torch.full_like(theta, 0.5),
        2 * torch.sin(safe_theta / 2) ** 2 / safe_theta ** 2,
    )

    A = A.view(-1, 1, 1)
    B = B.view(-1, 1, 1)

    return identity_matrices + A * omega_hat + B * torch.bmm(omega_hat, omega_hat)


def so2_to_SO2(theta):
    """
    Map a batch of so(2) elements (angles) to SO(2), represented as scalar angles.

    The angles are clamped to [-pi, pi].
    NOTE(review): clamping saturates angles outside [-pi, pi] rather than wrapping them
    modulo 2*pi; confirm this is the intended behavior for callers.

    Args:
        theta (Tensor): A batch of so(2) elements of shape (batch_size,).

    Returns:
        Tensor: A batch of SO(2) elements of shape (batch_size,), clamped between -pi and pi.
    """
    return theta.clamp(min=-torch.pi, max=torch.pi)
