import numpy as np
import copy
from scipy.spatial.transform import Rotation as R
import torch

from rdkit.Geometry import Point3D



def compute_batch_ligand_centers(batch):
    """
    Centroid of every ligand in a padded batch.

    Parameters:
    ----------
    batch : ComplexBatch
        Must expose ``batch.ligand.pos`` (batch_size, max_atoms, 3) and
        ``batch.ligand.num_atoms`` (batch_size,).

    Returns:
    -------
    torch.Tensor
        Mean positions of the ligands in the batch. Shape: (batch_size, 3)
    """
    ligand = batch.ligand
    # Relies on padded atom rows being exactly zero, so summing over the whole
    # atom axis equals summing over the real atoms only.
    coord_sums = ligand.pos.sum(axis=1)
    return coord_sums / ligand.num_atoms[:, None]


def compute_batch_rmsds(batch):
    """
    Per-ligand RMSD between the current and the original atom positions.

    Parameters:
    ----------
    batch : ComplexBatch
        Must expose ``batch.ligand.pos`` and ``batch.ligand.orig_pos`` (same
        padded shape) and ``batch.ligand.num_atoms``.

    Returns:
    -------
    torch.Tensor
        RMSDs of the ligands in the batch. Shape: (batch_size, )
    """
    ligand = batch.ligand
    # Assumes zero padding in both pos and orig_pos, so padded rows add nothing.
    squared_dev = (ligand.pos - ligand.orig_pos) ** 2
    sum_sq_per_ligand = squared_dev.sum(axis=2).sum(axis=1)
    mean_sq = sum_sq_per_ligand / ligand.num_atoms
    return torch.sqrt(mean_sq)


def compute_per_atom_rmsds(batch):
    """
    Displacement of every real (non-padded) atom from its original position.

    Returns a 1-D tensor of Euclidean distances ``|pos - orig_pos|``, one entry per
    atom whose ``batch.ligand.is_padded_mask`` entry is False, in row-major order.
    """
    ligand = batch.ligand
    keep = ~ligand.is_padded_mask
    displacement = ligand.pos - ligand.orig_pos
    distances = torch.linalg.norm(displacement, dim=-1)
    return distances[keep]


def rotvec_to_rotmat(rotvec):
    """
    Converts rotation vector(s) to rotation matrices using Rodrigues' rotation formula.

    The direction of each vector is the rotation axis and its norm is the rotation
    angle in radians.

    Args:
        rotvec (torch.Tensor): A single rotation vector of shape (3,), or a batch of
            shape (..., 3) (e.g. (batch_size, 3)).

    Returns:
        torch.Tensor: A (3, 3) rotation matrix for a single vector, or a tensor of
            shape (..., 3, 3) for a batch.
    """
    # Rotation angle of each vector; shape (...) (0-d for a single vector).
    theta = torch.norm(rotvec, dim=-1)

    # Unit rotation axis. normalize() clamps the denominator, so a zero vector yields a
    # zero axis (and, since theta == 0, the identity matrix) instead of NaNs.
    n = torch.nn.functional.normalize(rotvec, dim=-1)

    # Index the last axis so batched inputs pick components, not rows.
    n1, n2, n3 = n[..., 0], n[..., 1], n[..., 2]

    # Precompute trigonometric terms
    cos_theta = torch.cos(theta)
    sin_theta = torch.sin(theta)
    one_minus_cos_theta = 1 - cos_theta

    # R = cos(t) I + sin(t) [n]_x + (1 - cos(t)) n n^T
    r11 = cos_theta + n1**2 * one_minus_cos_theta
    r12 = n1 * n2 * one_minus_cos_theta - n3 * sin_theta
    r13 = n1 * n3 * one_minus_cos_theta + n2 * sin_theta

    r21 = n1 * n2 * one_minus_cos_theta + n3 * sin_theta
    r22 = cos_theta + n2**2 * one_minus_cos_theta
    r23 = n2 * n3 * one_minus_cos_theta - n1 * sin_theta

    r31 = n1 * n3 * one_minus_cos_theta - n2 * sin_theta
    r32 = n2 * n3 * one_minus_cos_theta + n1 * sin_theta
    r33 = cos_theta + n3**2 * one_minus_cos_theta

    # Stack row-major and reshape to (..., 3, 3); a (3,) input gives (3, 3).
    rotation_matrix = torch.stack([r11, r12, r13, r21, r22, r23, r31, r32, r33], dim=-1)
    return rotation_matrix.reshape(*rotvec.shape[:-1], 3, 3)


def apply_tor_changes_to_pos(pos, rotatable_bonds, mask_rotate, torsion_updates, is_reverse_order,
                             bond_properties_for_angles=None, shift_center_back=True,
                             fix_moments_of_inertia=False):
    """
    Apply torsion updates to the positions of atoms in one sample.

    For each rotatable bond (u, v), the atoms selected by ``mask_rotate`` are rotated by
    ``torsion_updates[i]`` radians around the axis ``pos[u] - pos[v]``, pivoting at ``pos[v]``.
    The per-bond rotations are written into ``pos`` in place. When ``shift_center_back`` or
    ``fix_moments_of_inertia`` is set, the returned positions are a new array/tensor.

    Parameters:
    ----------
    pos : Union[np.ndarray, torch.Tensor]
        The positions of atoms in the sample, shape (num_atoms, 3).
    rotatable_bonds : Union[np.ndarray, torch.Tensor]
        Rotatable bonds in the sample, shape (num_rotatable_bonds, 2). Each bond is represented by
        two indices: (atom1, atom2).
    mask_rotate : Union[np.ndarray, torch.Tensor]
        Mask indicating which atoms to rotate for each bond, shape (num_rotatable_bonds, num_atoms).
    torsion_updates : Union[np.ndarray, torch.Tensor]
        Torsion updates to apply to each rotatable bond, shape (num_rotatable_bonds,).
    is_reverse_order : bool
        If True, bonds are applied last-to-first instead of first-to-last.
    bond_properties_for_angles : dict or None
        If given, passed to ``get_torsion_angles`` on the final positions to compute torsion angles.
    shift_center_back : bool
        Translate the result so its centroid equals the centroid before the update.
    fix_moments_of_inertia : bool
        Rigidly align (Kabsch, via ``find_rigid_alignment``) the result onto the pre-update positions.

    Returns:
    -------
    Tuple[Union[np.ndarray, torch.Tensor], Optional[angles]]
        The updated positions of atoms and the torsion angles (None when there are no bonds
        or ``bond_properties_for_angles`` is None).

    Raises:
    ------
    ValueError
        If ``rotatable_bonds`` does not have two columns.
    """
    is_torch = isinstance(pos, torch.Tensor)

    if len(rotatable_bonds) == 0:
        return pos, None

    if rotatable_bonds.shape[1] != 2:
        raise ValueError('A wrong format of rotational bonds array!')

    num_rotatable_bonds = rotatable_bonds.shape[0]
    if is_reverse_order:
        bond_order = range(num_rotatable_bonds - 1, -1, -1)
    else:
        bond_order = range(num_rotatable_bonds)

    # Initial ligand center, restored at the end when shift_center_back is set.
    pos_mean = pos.mean(0)[None, :]
    if fix_moments_of_inertia:
        init_pos_to_align = pos.clone() if is_torch else np.copy(pos)

    for idx_rot_bond in bond_order:
        u = rotatable_bonds[idx_rot_bond, 0]
        v = rotatable_bonds[idx_rot_bond, 1]
        rot_vec = pos[u] - pos[v]  # convention: positive rotation if pointing inwards

        if is_torch:
            rot_vec = rot_vec * torsion_updates[idx_rot_bond] / torch.linalg.norm(rot_vec)
            rot_mat = rotvec_to_rotmat(rot_vec)
            mask = mask_rotate[idx_rot_bond].bool()
        else:
            rot_vec = rot_vec * torsion_updates[idx_rot_bond] / np.linalg.norm(rot_vec)
            # scipy keeps full float64 precision (the previous torch.float round-trip
            # truncated every rotation matrix to float32) and matches the _ext variant.
            rot_mat = R.from_rotvec(rot_vec).as_matrix()
            mask = mask_rotate[idx_rot_bond].astype(bool)
        pos[mask] = (pos[mask] - pos[v]) @ rot_mat.T + pos[v]

    # shift to the initial center
    if shift_center_back:
        pos = pos - pos.mean(0)[None, :] + pos_mean

    if fix_moments_of_inertia:
        rot, tr = find_rigid_alignment(pos, init_pos_to_align)
        pos = (pos - pos.mean(0)) @ rot.T + tr

    angles = None
    if bond_properties_for_angles is not None:
        # get_torsion_angles only reads its inputs, so no defensive copies are needed
        # (the former .clone() calls also broke the NumPy path).
        angles = get_torsion_angles(pos, bond_properties_for_angles)

    return pos, angles


def apply_tor_changes_to_pos_ext(pos, start_indices, end_indices, mask_rotate, torsion_updates, 
                                 is_reverse_order, bond_properties_for_angles=None, shift_center_back=True):
    """
    Apply torsion updates to the positions of atoms in one sample.

    Bond i rotates the atoms selected by ``mask_rotate[i]`` by ``torsion_updates[i]`` radians
    around the axis ``pos[start_indices[i]] - pos[end_indices[i]]``, pivoting at
    ``pos[end_indices[i]]``. The per-bond rotations are written into ``pos`` in place; with
    ``shift_center_back`` the returned positions are a new array/tensor.

    Parameters:
    ----------
    pos : Union[np.ndarray, torch.Tensor]
        The positions of atoms in the sample, shape (num_atoms, 3).
    start_indices : Union[np.ndarray, torch.Tensor]
        Start indices of rotatable bonds in the sample, shape (num_rotatable_bonds,).
    end_indices : Union[np.ndarray, torch.Tensor]
        End indices of rotatable bonds in the sample, shape (num_rotatable_bonds,).
    mask_rotate : Union[np.ndarray, torch.Tensor]
        Mask indicating which atoms to rotate for each bond, shape (num_rotatable_bonds, num_atoms).
        Any dtype is accepted; nonzero entries select atoms.
    torsion_updates : Union[np.ndarray, torch.Tensor]
        Torsion updates to apply to each rotatable bond, shape (num_rotatable_bonds,).
    is_reverse_order : bool
        If True, bonds are applied last-to-first instead of first-to-last.
    bond_properties_for_angles : dict or None
        If given, passed to ``get_torsion_angles`` on the final positions.
    shift_center_back : bool
        Translate the result so its centroid equals the centroid before the update.

    Returns:
    -------
    pos : Union[np.ndarray, torch.Tensor]
        The updated positions of atoms in the sample.
    angles : Optional[Union[np.ndarray, torch.Tensor]]
        The angles of the rotatable bonds, or None when there are no bonds or no bond properties.
    """
    is_torch = isinstance(pos, torch.Tensor)

    if len(start_indices) == 0:
        return pos, None

    num_rotatable_bonds = start_indices.shape[0]
    if is_reverse_order:
        bond_order = range(num_rotatable_bonds - 1, -1, -1)
    else:
        bond_order = range(num_rotatable_bonds)

    # compute initial ligand center
    pos_mean = pos.mean(0)[None, :]

    for idx_rot_bond in bond_order:
        u = start_indices[idx_rot_bond]
        v = end_indices[idx_rot_bond]
        rot_vec = pos[u] - pos[v]  # convention: positive rotation if pointing inwards
        # Rotate v:
        if is_torch:
            rot_vec = rot_vec * torsion_updates[idx_rot_bond] / torch.linalg.norm(rot_vec)
            rot_mat = rotvec_to_rotmat(rot_vec)
        else:
            rot_vec = rot_vec * torsion_updates[idx_rot_bond] / np.linalg.norm(rot_vec)
            rot_mat = np.array(R.from_rotvec(rot_vec).as_matrix())
        mask = mask_rotate[idx_rot_bond]
        # A 0/1 integer or float mask would act as gather indices rather than a selection,
        # so force it to boolean (a no-op for masks that are already bool).
        mask = mask.bool() if isinstance(mask, torch.Tensor) else np.asarray(mask, dtype=bool)
        pos[mask] = (pos[mask] - pos[v]) @ rot_mat.T + pos[v]

    # shift to the initial center
    if shift_center_back:
        pos = pos - pos.mean(0)[None, :] + pos_mean

    angles = None
    if bond_properties_for_angles is not None:
        angles = get_torsion_angles(pos, bond_properties_for_angles)

    return pos, angles


def apply_tor_changes_to_batch_inplace(batch, tor, is_reverse_order):
    """
    Apply torsion updates to each ligand in the batch, writing into ``batch.ligand.pos``.

    Parameters:
    ----------
    batch : Batch
        The batch containing ligand information. ``rotatable_bonds`` is a flat
        concatenation over samples; ``num_rotatable_bonds`` gives each sample's share.
    tor : np.ndarray or torch.Tensor
        Flat torsion updates, one per rotatable bond across the whole batch,
        shape (total_num_rotatable_bonds,).
    is_reverse_order : bool
        Forwarded to ``apply_tor_changes_to_pos``.

    Returns:
    -------
    None
        The batch is modified in place.
    """
    ligand = batch.ligand
    bond_start = 0
    # TODO: vectorize
    for sample_idx, sample_mask_rotate in enumerate(ligand.mask_rotate):
        n_atoms = ligand.num_atoms[sample_idx]
        n_bonds = ligand.num_rotatable_bonds[sample_idx]
        bond_end = bond_start + n_bonds
        new_pos, angles = apply_tor_changes_to_pos(
            ligand.pos[sample_idx, :n_atoms, :],
            ligand.rotatable_bonds[bond_start:bond_end],
            sample_mask_rotate,
            tor[bond_start:bond_end],
            is_reverse_order=is_reverse_order,
            bond_properties_for_angles=None,
        )
        bond_start = bond_end
        ligand.pos[sample_idx, :n_atoms, :] = new_pos
        if angles is not None:
            ligand.rotatable_bonds_ext.angles[sample_idx, :n_bonds] = angles


def apply_tor_changes_to_batch_inplace_ext(batch, tor, is_reverse_order):
    """
    Apply torsion updates to each ligand in the batch and refresh its torsion angles.

    Positions are written back into ``batch.ligand.pos``. The resulting torsion angles
    are written into ``batch.ligand.rotatable_bonds_ext.angles``.

    Parameters:
    ----------
    batch : Batch
        The batch containing ligand information (padded ``rotatable_bonds_ext`` fields).
    tor : torch.Tensor of shape (batch_size, max_num_rot_bonds)
        Torsion updates to apply to each rotatable bond.
    is_reverse_order : bool
        Forwarded to ``apply_tor_changes_to_pos_ext``.

    Returns:
    -------
    None
        The batch is modified in place.
    """
    rb_ext = batch.ligand.rotatable_bonds_ext
    for idx, num_atoms in enumerate(batch.ligand.num_atoms):
        num_rot_bonds = rb_ext.num_rotatable_bonds[idx]
        pos = batch.ligand.pos[idx, :num_atoms, :]
        torsion_updates = tor[idx, :num_rot_bonds]
        mask_rotate = rb_ext.mask_rotate[idx, :num_rot_bonds, :num_atoms]
        start = rb_ext.start[idx, :num_rot_bonds]
        end = rb_ext.end[idx, :num_rot_bonds]
        # Per-sample, unpadded bond properties. Previously a placeholder string was passed,
        # which made get_torsion_angles raise for every ligand with rotatable bonds.
        bond_properties_for_angles = {
            'start': start,
            'end': end,
            'neighbor_of_start': rb_ext.neighbor_of_start[idx, :num_rot_bonds],
            'neighbor_of_end': rb_ext.neighbor_of_end[idx, :num_rot_bonds],
            'bond_periods': rb_ext.bond_periods[idx, :num_rot_bonds],
        }
        pos, angles = apply_tor_changes_to_pos_ext(
            pos,
            start,
            end,
            mask_rotate,
            torsion_updates,
            bond_properties_for_angles=bond_properties_for_angles,
            is_reverse_order=is_reverse_order,
        )
        batch.ligand.pos[idx, :num_atoms, :] = pos
        if angles is not None:
            rb_ext.angles[idx, :num_rot_bonds] = angles


def apply_tr_rot_changes_to_batch_inplace(batch, tr, rot):
    """
    Rigidly rotate every ligand about its current center and place it at ``tr``.

    ``tr`` is the new center of mass, so the old center is not added back:
    ``new_pos = (pos - pos_mean) @ rot.T + tr``. Padded atom slots are reset to zero
    afterwards, because other helpers rely on zero padding.

    Parameters:
    ----------
    batch : Batch
        Batch whose ``ligand.pos`` (batch_size, max_atoms, 3) is updated in place.
    tr : torch.Tensor
        New ligand centers, reshaped to (batch_size, 1, 3).
    rot : torch.Tensor
        Per-ligand rotation matrices, shape (batch_size, 3, 3).

    Returns:
    -------
    None
    """
    batch_size = tr.shape[0]
    pos = batch.ligand.pos
    pos_mean = compute_batch_ligand_centers(batch)
    # einsum 'bij,bkj->bik' is a batched (pos - mean) @ rot^T.
    pos[:] = torch.einsum('bij,bkj->bik', pos - pos_mean[:, None, :],
                          rot) + tr.reshape(batch_size, 1, 3)

    # Re-zero padded slots in one vectorized step (the shift above made them nonzero).
    num_atoms = torch.as_tensor(batch.ligand.num_atoms, device=pos.device)
    atom_idx = torch.arange(pos.shape[1], device=pos.device)
    is_padding = atom_idx[None, :] >= num_atoms[:, None]
    pos[is_padding] = 0.


def apply_changes_to_batch_inplace(batch, tr, rot, tor, is_reverse_order):
    """Apply the rigid (translation + rotation) update first, then the torsion updates; mutates ``batch``."""
    apply_tr_rot_changes_to_batch_inplace(batch, tr, rot)
    apply_tor_changes_to_batch_inplace(batch, tor, is_reverse_order=is_reverse_order)


def apply_changes_to_batch_inplace_ext(batch, tr, rot, tor, is_reverse_order):
    """Apply the torsion updates first, then the rigid (translation + rotation) update; mutates ``batch``.

    Note the order is the reverse of ``apply_changes_to_batch_inplace``.
    """
    apply_tor_changes_to_batch_inplace_ext(batch, tor, is_reverse_order=is_reverse_order)
    apply_tr_rot_changes_to_batch_inplace(batch, tr, rot)


def normalize_angle_rad(angle):
    """
    Normalize angles to lie in (-π, π].

    Parameters:
    ----------
    angle : int, float, numpy scalar, np.ndarray, or torch.Tensor
        Angle or array of angles in radians to normalize. The input is never modified.

    Returns:
    -------
    float, numpy scalar, np.ndarray, or torch.Tensor
        Normalized angle(s); values greater than π are shifted down by 2π.

    Raises:
    ------
    ValueError
        If the input type is not supported.
    """
    two_pi = 2 * np.pi
    if isinstance(angle, (int, float, np.integer, np.floating)):
        angle = angle % two_pi
        if angle > np.pi:
            angle -= two_pi
        return angle
    if isinstance(angle, torch.Tensor):
        wrapped = angle % two_pi
        # where() instead of masked in-place subtraction: works for 0-d tensors and autograd.
        return torch.where(wrapped > np.pi, wrapped - two_pi, wrapped)
    if isinstance(angle, np.ndarray):
        # For 0-d arrays `%` returns a numpy scalar, which cannot be index-assigned; where() handles both.
        wrapped = angle % two_pi
        return np.where(wrapped > np.pi, wrapped - two_pi, wrapped)
    raise ValueError('Unsupported type')


def compute_signed_dihedral_angles_vectorized(positions):
    """
    Signed dihedral angle for each of B atom quartets, NumPy or PyTorch.

    Parameters:
    -----------
    positions: np.ndarray or torch.Tensor, shape (B, 4, 3)
        Quartets ordered (n0, start, end, n1).

    Returns:
    --------
    angles: np.ndarray or torch.Tensor, shape (B,)
        The signed dihedral angles in radians, with the sign negated relative to
        ``atan2(m1·n2, n1·n2)`` (the convention used by the rest of this module).
    """
    use_torch = isinstance(positions, torch.Tensor)

    def _cross(a, b):
        return torch.cross(a, b, dim=1) if use_torch else np.cross(a, b)

    def _unit(a):
        if use_torch:
            return a / torch.norm(a, dim=1, keepdim=True)
        return a / np.linalg.norm(a, axis=1, keepdims=True)

    def _rowdot(a, b):
        return torch.sum(a * b, dim=1) if use_torch else np.sum(a * b, axis=1)

    n0_pos, start_pos, end_pos, n1_pos = (positions[:, k] for k in range(4))
    bond_a = start_pos - n0_pos   # n0 -> start
    bond_b = end_pos - start_pos  # start -> end (the torsion axis)
    bond_c = n1_pos - end_pos     # end -> n1

    axis_unit = _unit(bond_b)
    plane_normal_a = _unit(_cross(bond_a, bond_b))
    plane_normal_b = _unit(_cross(bond_b, bond_c))
    # In-plane vector orthogonal to plane_normal_a, giving the sine component.
    ortho = _cross(plane_normal_a, axis_unit)

    cos_part = _rowdot(plane_normal_a, plane_normal_b)
    sin_part = _rowdot(ortho, plane_normal_b)

    if use_torch:
        angle = torch.atan2(sin_part, cos_part)
    else:
        angle = np.arctan2(sin_part, cos_part)
    return -angle


def get_torsion_angles(pos, bond_atoms_for_angles):
    """
    Torsion angle of every rotatable bond, wrapped by its bond period.

    ``bond_atoms_for_angles`` must provide index arrays under 'neighbor_of_start',
    'start', 'end', 'neighbor_of_end', plus 'bond_periods'. The inputs are only read.
    """
    quartet_keys = ('neighbor_of_start', 'start', 'end', 'neighbor_of_end')
    quartet_indices = [bond_atoms_for_angles[key] for key in quartet_keys]
    gathered = [pos[atom_idx] for atom_idx in quartet_indices]

    # (num_bonds, 4, 3) batch of quartets in (n0, start, end, n1) order.
    if isinstance(pos, torch.Tensor):
        atom_quartets = torch.stack(gathered, dim=1)
    else:
        atom_quartets = np.stack(gathered, axis=1)

    raw_angles = compute_signed_dihedral_angles_vectorized(atom_quartets)
    return fix_torsion_angles(raw_angles, bond_atoms_for_angles['bond_periods'])


def fix_torsion_angles(angles, bond_periods):
    """Wrap ``angles`` into [-period/2, period/2) for each bond's symmetry period."""
    half_period = bond_periods / 2
    return (angles + half_period) % bond_periods - half_period


def pos_to_point3d(pos):
    """Build an RDKit ``Point3D`` from the first three entries of ``pos`` (cast to float)."""
    x, y, z = (float(pos[axis]) for axis in range(3))
    return Point3D(x, y, z)


def get_bond_properties_for_angles(rotatable_bonds_ext):
    """Collect the fields ``get_torsion_angles`` needs from a rotatable-bonds record into a dict."""
    return {
        'start': rotatable_bonds_ext.start,
        'end': rotatable_bonds_ext.end,
        'neighbor_of_start': rotatable_bonds_ext.neighbor_of_start,
        'neighbor_of_end': rotatable_bonds_ext.neighbor_of_end,
        'bond_periods': rotatable_bonds_ext.bond_periods,
    }


def get_angle_histogram(conf_angle_values, n_bins=100):
    """
    Per-torsion normalized histograms over [-π, π].

    Parameters:
    ----------
    conf_angle_values : array-like, shape (num_conformers, num_angles)
        Torsion angles (radians) of each conformer.
    n_bins : int
        Number of equal-width bins spanning [-π, π].

    Returns:
    -------
    np.ndarray, shape (num_angles, n_bins), dtype float32
        ``np.histogram(..., density=True)`` of each column. A molecule with no torsions
        yields an empty (0, n_bins) array instead of raising (``np.stack([])`` fails).
    """
    conf_angle_values = np.asarray(conf_angle_values)
    num_angles = conf_angle_values.shape[1]
    histograms = np.empty((num_angles, n_bins), dtype=np.float32)
    for angle_ind in range(num_angles):
        histograms[angle_ind], _ = np.histogram(conf_angle_values[:, angle_ind], bins=n_bins,
                                                range=(-np.pi, np.pi), density=True)
    return histograms


def compute_angle_density(angles, angle_histograms):
    """
    Look up, for each i, histogram row i at the bin containing ``angles[i]``.

    Bins are the ``angle_histograms.shape[1]`` equal-width intervals over [-π, π].
    The bin index is (number of edges strictly below the angle) - 1. An angle at or
    below the first edge therefore gets index -1, which wraps to the last bin.
    NOTE(review): angles above the last edge would index past the end.
    """
    num_bins = angle_histograms.shape[1]
    edges = torch.linspace(-np.pi, np.pi, num_bins + 1, device=angles.device)
    below = angles.unsqueeze(0) > edges.unsqueeze(1)  # (num_bins + 1, num_angles)
    bin_idx = below.sum(dim=0) - 1
    rows = list(range(len(angle_histograms)))
    return angle_histograms[rows, bin_idx]


def find_rigid_alignment(pos_a, pos_b):
    """
    Kabsch alignment of ``pos_a`` onto ``pos_b`` (NumPy or PyTorch).

    Borrowed and slightly modified from
    https://gist.github.com/bougui505/23eb8a39d7a601399edc7534b28de3d4

    Returns ``(rot, tr)``: a proper rotation (det = +1) such that
    ``(pos_a - pos_a.mean(0)) @ rot.T + tr`` best matches ``pos_b`` in the least-squares
    sense, and ``tr = pos_b.mean(0)``.
    """
    centroid_b = pos_b.mean(0)
    cov_mat = (pos_a - pos_a.mean(0)).T @ (pos_b - centroid_b)

    linalg = torch.linalg if isinstance(pos_a, torch.Tensor) else np.linalg
    U, _, Vt = linalg.svd(cov_mat)
    V = Vt.T
    # A negative determinant means V @ U.T is a reflection; flip the last singular direction.
    if linalg.det(V @ U.T) < 0:
        V[:, -1] = -V[:, -1]
    return V @ U.T, centroid_b


def compute_angle_error(pred_angles, true_angles, bond_periods):
    """
    Root-sum-square of the period-wrapped torsion differences.

    Each difference ``pred - true`` is wrapped into [-period/2, period/2). The result is
    the L2 norm over bonds; it is not averaged. Returns 0 when there are no bonds, a
    Python float for tensor input, and a NumPy scalar otherwise.
    """
    if len(bond_periods) == 0:
        return 0
    half_period = bond_periods / 2
    wrapped = (pred_angles - true_angles + half_period) % bond_periods - half_period
    if isinstance(wrapped, torch.Tensor):
        return torch.sqrt(torch.sum(wrapped ** 2)).item()
    return np.sqrt(np.sum(wrapped ** 2))


def compute_angle_MAE(pred_angles, true_angles, bond_periods):
    """
    Mean absolute period-wrapped torsion error over bonds.

    Each difference ``pred - true`` is wrapped into [-period/2, period/2) before taking
    its absolute value. Returns 0 when there are no bonds, a Python float for tensor
    input, and a NumPy scalar otherwise.
    """
    if len(bond_periods) == 0:
        return 0
    half_period = bond_periods / 2
    wrapped = (pred_angles - true_angles + half_period) % bond_periods - half_period
    if isinstance(wrapped, torch.Tensor):
        return torch.abs(wrapped).mean().item()
    return np.abs(wrapped).mean()


def compute_batch_angles(batch):
    """
    Torsion angles for every ligand in a padded batch.

    Returns a tensor shaped like ``batch.ligand.rotatable_bonds_ext.angles``. Entry
    ``[i, :num_rotatable_bonds[i]]`` holds ligand i's angles; the padding stays zero.
    """
    ext = batch.ligand.rotatable_bonds_ext
    angles = torch.zeros_like(ext.angles)
    property_keys = ('start', 'end', 'neighbor_of_start', 'neighbor_of_end', 'bond_periods')
    for i in range(len(batch)):
        n_rot = batch.ligand.num_rotatable_bonds[i]
        # Unpadded per-sample view of each bond property.
        bond_properties_for_angles = {key: getattr(ext, key)[i, :n_rot] for key in property_keys}
        sample_pos = batch.ligand.pos[i, :batch.ligand.num_atoms[i]]
        angles[i, :n_rot] = get_torsion_angles(sample_pos, bond_atoms_for_angles=bond_properties_for_angles)
    return angles


def get_batch_pred_torsion_updates(cur_batch):
    """
    Torsion updates that take the current angles to the predicted ones.

    Returns ``(pred_tor_mask, pred_torsion_updates)``. The updates are flat over all
    rotatable bonds, equal (current - predicted), and are wrapped into
    [-period/2, period/2) using ``cur_batch.ligand.bond_periods``. With no predictions,
    returns an all-False mask and zero updates shaped like ``ligand.init_tor``.
    """
    ligand = cur_batch.ligand
    if ligand.pred_tor_angles is None:
        no_prediction_mask = torch.zeros_like(ligand.init_tor, dtype=torch.bool)
        return no_prediction_mask, torch.zeros_like(ligand.init_tor)

    init_angles = compute_batch_angles(cur_batch)
    # Drop per-sample padding so the layout matches the flat pred_tor_angles.
    flat_init_angles = torch.cat(
        [init_angles[i, :ligand.num_rotatable_bonds[i]] for i in range(len(cur_batch))],
        dim=0,
    )

    pred_tor_angles = ligand.pred_tor_angles
    pred_tor_mask = ligand.pred_tor_mask

    delta = flat_init_angles - pred_tor_angles
    periods = ligand.bond_periods.float()
    half_period = periods / 2
    return pred_tor_mask, (delta + half_period) % periods - half_period


def convert_separate_velocities_to_atom_level(v_tr, v_rot, v_tor, batch):
    """
    Combine translation, rotation and torsion velocities into per-atom velocities.

    For each sample, three contributions are summed into a tensor shaped like
    ``batch.ligand.pos``. Padded atom slots are never written and stay zero.

    1. Torsion (if ``v_tor`` is given and the sample has rotatable bonds). For each
       bond (u, v), the atoms in its ``mask_rotate`` row get ``ω × (r - pos[v])``, with
       ``ω = v_tor[bond] * unit(pos[u] - pos[v])``.
    2. Translation (if ``v_tr`` is given). ``v_tr[b]`` minus the mean torsion velocity
       of the sample is added to every real atom, which cancels the net drift
       introduced by the torsion term.
    3. Rotation (if ``v_rot`` is given). ``v_rot[b] × (r - center)`` about the ligand
       center from ``compute_batch_ligand_centers``.

    Positions are only read, never modified.

    Parameters
    ----------
    v_tr, v_rot : per-sample vectors indexed by batch position, or None.
        NOTE(review): presumably shape (batch_size, 3) — confirm.
    v_tor : flat per-bond torsion velocities sliced with ``batch.ligand.tor_ptr``, or None.
    batch : batch exposing ``ligand.pos``, ``num_atoms``, ``num_rotatable_bonds``,
        ``tor_ptr``, ``rotatable_bonds`` and ``mask_rotate``.

    Returns
    -------
    torch.Tensor
        Per-atom velocities, same shape as ``batch.ligand.pos``.
    """
    batch_size = len(batch.ligand.pos)
    atom_velocities = torch.zeros_like(batch.ligand.pos)
    # Mean torsion-induced velocity per sample; stays zero when there is no torsion term.
    center_shift_velocities = torch.zeros(batch_size, 3, device=batch.ligand.pos.device)

    # Get ligand centers for rotation
    ligand_centers = compute_batch_ligand_centers(batch)  # [batch_size, 3]

    for batch_idx in range(batch_size):
        num_atoms = batch.ligand.num_atoms[batch_idx]
        max_atoms = batch.ligand.pos.shape[1]
        atom_pos = batch.ligand.pos[batch_idx, :num_atoms]  # [num_atoms, 3]

        # 1. Torsion contribution - affects atoms based on rotatable bonds
        if v_tor is not None and batch.ligand.num_rotatable_bonds[batch_idx] > 0:
            # Get rotatable bonds and masks for this batch element
            # (tor_ptr holds offsets into the flat bond / v_tor arrays)
            start_bond_idx = batch.ligand.tor_ptr[batch_idx]
            end_bond_idx = batch.ligand.tor_ptr[batch_idx + 1]
            rotatable_bonds = batch.ligand.rotatable_bonds[start_bond_idx:end_bond_idx]  # [num_rot_bonds, 2]
            batch_v_tor = v_tor[start_bond_idx:end_bond_idx]
            # NOTE(review): a row of this mask is written into a [:num_atoms] slice and used to
            # index atom_pos (num_atoms rows), so it must have length num_atoms, not max_atoms.
            mask_rotate = batch.ligand.mask_rotate[batch_idx]
            num_rotatable_bonds = rotatable_bonds.shape[0]

            for idx_rot_bond in range(num_rotatable_bonds):
                u = rotatable_bonds[idx_rot_bond, 0]
                v = rotatable_bonds[idx_rot_bond, 1]

                # Get rotation axis (normalized bond direction)
                axis = atom_pos[u] - atom_pos[v]
                axis = axis / torch.linalg.norm(axis)

                # Angular velocity vector: ω = velocity_magnitude * axis_direction
                angular_velocity = batch_v_tor[idx_rot_bond] * axis

                # Get affected atoms
                affected_atoms_mask = mask_rotate[idx_rot_bond]
                # Same selection padded to max_atoms, for indexing the padded atom_velocities row
                affected_atoms_mask_padded = torch.zeros(max_atoms, dtype=bool, device=atom_pos.device)
                affected_atoms_mask_padded[:num_atoms] = affected_atoms_mask

                affected_positions = atom_pos[affected_atoms_mask]

                if affected_positions.shape[0] > 0:
                    # Rotation center is atom v (as in transforms.py)
                    rotation_center = atom_pos[v]
                    relative_positions = affected_positions - rotation_center

                    # Velocity = ω × (r - center)
                    tor_velocities = torch.cross(
                        angular_velocity.unsqueeze(0).expand(affected_positions.shape[0], -1),
                        relative_positions,
                        dim=1
                    )
                    # Contributions of successive bonds accumulate for atoms in several masks
                    atom_velocities[batch_idx, affected_atoms_mask_padded] += tor_velocities

            # At this point atom_velocities[batch_idx] holds only torsion terms; their mean is
            # the centroid drift the translation step below subtracts.
            torsion_velocity_sum = atom_velocities[batch_idx, :num_atoms].sum(dim=0)
            center_shift_velocities[batch_idx] = torsion_velocity_sum / num_atoms

        # 2. Translation contribution - same for all atoms
        # NOTE(review): the torsion drift correction is only applied when v_tr is given — confirm intended.
        if v_tr is not None:
            atom_velocities[batch_idx, :num_atoms] += (v_tr[batch_idx] - center_shift_velocities[batch_idx]).unsqueeze(0)

        # 3. Rotation contribution - v = ω × (r - center)
        if v_rot is not None:
            center = ligand_centers[batch_idx]  # [3]
            relative_pos = atom_pos - center.unsqueeze(0)  # [num_atoms, 3]
            # Cross product: ω × r
            rot_velocities = torch.cross(v_rot[batch_idx].unsqueeze(0).expand(num_atoms, -1),
                                            relative_pos, dim=1)
            atom_velocities[batch_idx, :num_atoms] += rot_velocities

    return atom_velocities
