from typing import Tuple

import torch

from openfold.np import residue_constants as rc

from openfold.utils import tensor_utils


def get_rc_tensor(rc_np, aatype):
    """Look up rows of a constant table using ``aatype`` as the index.

    Args:
        rc_np: Array-like constant table; anything ``torch.tensor`` accepts.
        aatype: Index tensor. The table is materialized on this tensor's
            device before being indexed.

    Returns:
        ``table[aatype]``, where ``table`` is ``rc_np`` converted to a tensor
        on ``aatype.device``.
    """
    # Build the table on the index tensor's device so the lookup below does
    # not mix devices.
    lookup_table = torch.tensor(rc_np, device=aatype.device)
    return lookup_table[aatype]


def atom14_to_atom37(
    atom14_data: torch.Tensor, aatype: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Re-index atom14-layout data into the atom37 layout and mask it.

    For every residue, the atom14 entries are gathered with the index table
    ``rc.RESTYPE_ATOM37_TO_ATOM14[aatype]`` and the result is multiplied by
    ``rc.RESTYPE_ATOM37_MASK[aatype]``.

    Args:
        atom14_data: Tensor whose rank is ``aatype.dim() + 1`` (one value per
            atom) or ``aatype.dim() + 2`` (one trailing feature axis per
            atom, e.g. coordinates -- TODO confirm with callers).
        aatype: Index tensor used to select rows of the residue-constant
            tables; all of its axes except the last are treated as batch
            axes.

    Returns:
        A tuple ``(atom37_data, atom37_mask)``: the gathered and masked data,
        and the mask looked up for ``aatype`` in the table's own dtype.

    Raises:
        ValueError: If the rank of ``atom14_data`` is neither
            ``aatype.dim() + 1`` nor ``aatype.dim() + 2``.
    """
    no_batch_dims = len(aatype.shape) - 1

    # Per-residue index table; cast to long because it is used as a gather
    # index.
    idx_atom37_to_atom14 = get_rc_tensor(rc.RESTYPE_ATOM37_TO_ATOM14, aatype).long()
    atom37_data = tensor_utils.batched_gather(
        atom14_data,
        idx_atom37_to_atom14,
        dim=no_batch_dims + 1,
        no_batch_dims=no_batch_dims + 1,
    )

    atom37_mask = get_rc_tensor(rc.RESTYPE_ATOM37_MASK, aatype)
    # Cast the mask to the data dtype for BOTH branches. The in-place multiply
    # cannot promote its left operand, so multiplying integer/bool data by a
    # floating mask would otherwise raise. The returned mask is left untouched.
    data_mask = atom37_mask.to(dtype=atom37_data.dtype)

    data_rank = len(atom14_data.shape)
    if data_rank == no_batch_dims + 2:
        atom37_data *= data_mask
    elif data_rank == no_batch_dims + 3:
        # Add a trailing axis so the mask broadcasts over the feature axis.
        atom37_data *= data_mask[..., None]
    else:
        raise ValueError("Incorrectly shaped data")
    return atom37_data, atom37_mask
