# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations
from functools import lru_cache
from typing import Tuple, Any, Sequence, Callable, Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch_scatter import scatter_sum

def rot_matmul(
    a: torch.Tensor, 
    b: torch.Tensor
) -> torch.Tensor:
    """
        Performs matrix multiplication of two rotation matrix tensors. Written
        out by hand to avoid AMP downcasting.

        Args:
            a: [*, 3, 3] left multiplicand
            b: [*, 3, 3] right multiplicand
        Returns:
            The product ab
    """
    return a@b

    # torch.einsum('...ij, ...jk -> ...ik', a, b)

    # def row_mul(i):
    #     return torch.stack(
    #         [
    #             a[..., i, 0] * b[..., 0, 0]
    #             + a[..., i, 1] * b[..., 1, 0]
    #             + a[..., i, 2] * b[..., 2, 0],
    #             a[..., i, 0] * b[..., 0, 1]
    #             + a[..., i, 1] * b[..., 1, 1]
    #             + a[..., i, 2] * b[..., 2, 1],
    #             a[..., i, 0] * b[..., 0, 2]
    #             + a[..., i, 1] * b[..., 1, 2]
    #             + a[..., i, 2] * b[..., 2, 2],
    #         ],
    #         dim=-1,
    #     )

    # return torch.stack(
    #     [
    #         row_mul(0), 
    #         row_mul(1), 
    #         row_mul(2),
    #     ], 
    #     dim=-2
    # )


def rot_vec_mul(
    r: torch.Tensor, 
    t: torch.Tensor
) -> torch.Tensor:
    """
        Applies a rotation to a vector. Written out by hand to avoid transfer
        to avoid AMP downcasting.

        Args:
            r: [*, 3, 3] rotation matrices
            t: [*, 3] coordinate tensors
        Returns:
            [*, 3] rotated coordinates
    """
    return torch.einsum('...ij, ...j->...i', r, t)
   
    # x, y, z = torch.unbind(t, dim=-1)
    # return torch.stack(
    #     [
    #         r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
    #         r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z,
    #         r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z,
    #     ],
    #     dim=-1,
    # )

@lru_cache(maxsize=None)
def identity_rot_mats(
    batch_dims: Tuple[int], 
    dtype: Optional[torch.dtype] = None, 
    device: Optional[torch.device] = None, 
    requires_grad: bool = True,
) -> torch.Tensor:
    rots = torch.eye(
        3, dtype=dtype, device=device, requires_grad=requires_grad
    )
    rots = rots.view(*((1,) * len(batch_dims)), 3, 3)
    rots = rots.expand(*batch_dims, -1, -1)
    rots = rots.contiguous()

    return rots


@lru_cache(maxsize=None)
def identity_trans(
    batch_dims: Tuple[int], 
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None, 
    requires_grad: bool = True,
) -> torch.Tensor:
    trans = torch.zeros(
        (*batch_dims, 3), 
        dtype=dtype, 
        device=device, 
        requires_grad=requires_grad
    )
    return trans


@lru_cache(maxsize=None)
def identity_quats(
    batch_dims: Tuple[int], 
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None, 
    requires_grad: bool = True,
) -> torch.Tensor:
    quat = torch.zeros(
        (*batch_dims, 4), 
        dtype=dtype, 
        device=device, 
        requires_grad=requires_grad
    )

    with torch.no_grad():
        quat[..., 0] = 1

    return quat


_quat_elements = ["a", "b", "c", "d"]
_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}


def _to_mat(pairs):
    mat = np.zeros((4, 4))
    for pair in pairs:
        key, value = pair
        ind = _qtr_ind_dict[key]
        mat[ind // 4][ind % 4] = value

    return mat


_QTR_MAT = np.zeros((4, 4, 3, 3))
_QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
_QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
_QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
_QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
_QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
_QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
_QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
_QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
_QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])


def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
    """
        Converts a quaternion to a rotation matrix.

        Args:
            quat: [*, 4] quaternions
        Returns:
            [*, 3, 3] rotation matrices
    """
    # [*, 4, 4]
    quat = quat[..., None] * quat[..., None, :]

    # [4, 4, 3, 3]
    mat = _get_quat("_QTR_MAT", dtype=quat.dtype, device=quat.device)

    # [*, 4, 4, 3, 3]
    shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
    quat = quat[..., None, None] * shaped_qtr_mat

    # [*, 3, 3]
    return torch.sum(quat, dim=(-3, -4))


def rot_to_quat(
    rot: torch.Tensor,
):
    if(rot.shape[-2:] != (3, 3)):
        raise ValueError("Input rotation is incorrectly shaped")

    rot = [[rot[..., i, j] for j in range(3)] for i in range(3)]
    [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot 

    k = [
        [ xx + yy + zz,      zy - yz,      xz - zx,      yx - xy,],
        [      zy - yz, xx - yy - zz,      xy + yx,      xz + zx,],
        [      xz - zx,      xy + yx, yy - xx - zz,      yz + zy,],
        [      yx - xy,      xz + zx,      yz + zy, zz - xx - yy,]
    ]

    k = (1./3.) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2)

    _, vectors = torch.linalg.eigh(k.float())
    return vectors[..., -1].to(k.dtype)


_QUAT_MULTIPLY = np.zeros((4, 4, 4))
_QUAT_MULTIPLY[:, :, 0] = [[ 1, 0, 0, 0],
                          [ 0,-1, 0, 0],
                          [ 0, 0,-1, 0],
                          [ 0, 0, 0,-1]]

_QUAT_MULTIPLY[:, :, 1] = [[ 0, 1, 0, 0],
                          [ 1, 0, 0, 0],
                          [ 0, 0, 0, 1],
                          [ 0, 0,-1, 0]]

_QUAT_MULTIPLY[:, :, 2] = [[ 0, 0, 1, 0],
                          [ 0, 0, 0,-1],
                          [ 1, 0, 0, 0],
                          [ 0, 1, 0, 0]]

_QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1],
                          [ 0, 0, 1, 0],
                          [ 0,-1, 0, 0],
                          [ 1, 0, 0, 0]]

_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]

_CACHED_QUATS = {
    "_QTR_MAT": _QTR_MAT,
    "_QUAT_MULTIPLY": _QUAT_MULTIPLY,
    "_QUAT_MULTIPLY_BY_VEC": _QUAT_MULTIPLY_BY_VEC
}

@lru_cache(maxsize=None)
def _get_quat(quat_key, dtype, device):
    return torch.tensor(_CACHED_QUATS[quat_key], dtype=dtype, device=device)


def quat_multiply(quat1, quat2):
    """Multiply a quaternion by another quaternion."""
    mat = _get_quat("_QUAT_MULTIPLY", dtype=quat1.dtype, device=quat1.device)
    reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)
    return torch.sum(
        reshaped_mat *
        quat1[..., :, None, None] *
        quat2[..., None, :, None],
        dim=(-3, -2)
      )


def quat_multiply_by_vec(quat, vec):
    """Multiply a quaternion by a pure-vector quaternion."""
    mat = _get_quat("_QUAT_MULTIPLY_BY_VEC", dtype=quat.dtype, device=quat.device)
    reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
    return torch.sum(
        reshaped_mat *
        quat[..., :, None, None] *
        vec[..., None, :, None],
        dim=(-3, -2)
    )


def invert_rot_mat(rot_mat: torch.Tensor):
    return rot_mat.transpose(-1, -2)


def invert_quat(quat: torch.Tensor):
    quat_prime = quat.clone()
    quat_prime[..., 1:] *= -1
    inv = quat_prime / torch.sum(quat ** 2, dim=-1, keepdim=True)
    return inv


class Rotation:
    """
        A 3D rotation. Depending on how the object is initialized, the
        rotation is represented by either a rotation matrix or a
        quaternion, though both formats are made available by helper functions.
        To simplify gradient computation, the underlying format of the
        rotation cannot be changed in-place. Like Rigid, the class is designed
        to mimic the behavior of a torch Tensor, almost as if each Rotation
        object were a tensor of rotations, in one format or another.
    """
    def __init__(self,
        rot_mats: Optional[torch.Tensor] = None,
        quats: Optional[torch.Tensor] = None,
        normalize_quats: bool = True,
    ):
        """
            Args:
                rot_mats:
                    A [*, 3, 3] rotation matrix tensor. Mutually exclusive with
                    quats
                quats:
                    A [*, 4] quaternion. Mutually exclusive with rot_mats. If
                    normalize_quats is not True, must be a unit quaternion
                normalize_quats:
                    If quats is specified, whether to normalize quats
        """
        if((rot_mats is None and quats is None) or 
            (rot_mats is not None and quats is not None)):
            raise ValueError("Exactly one input argument must be specified")

        if((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or 
            (quats is not None and quats.shape[-1] != 4)):
            raise ValueError(
                "Incorrectly shaped rotation matrix or quaternion"
            )

        # # Force full-precision
        # if(quats is not None):
        #     quats = quats.to(dtype=torch.float32)
        # if(rot_mats is not None):
        #     rot_mats = rot_mats.to(dtype=torch.float32)

        if(quats is not None and normalize_quats):
            quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)

        self._rot_mats = rot_mats
        self._quats = quats

    @staticmethod
    def identity(
        shape,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        requires_grad: bool = True,
        fmt: str = "quat",
    ) -> Rotation:
        """
            Returns an identity Rotation.

            Args:
                shape:
                    The "shape" of the resulting Rotation object. See documentation
                    for the shape property
                dtype:
                    The torch dtype for the rotation
                device:
                    The torch device for the new rotation
                requires_grad:
                    Whether the underlying tensors in the new rotation object
                    should require gradient computation
                fmt:
                    One of "quat" or "rot_mat". Determines the underlying format
                    of the new object's rotation 
            Returns:
                A new identity rotation
        """
        if(fmt == "rot_mat"):
            rot_mats = identity_rot_mats(
                shape, dtype, device, requires_grad,
            )
            return Rotation(rot_mats=rot_mats, quats=None)
        elif(fmt == "quat"):
            quats = identity_quats(shape, dtype, device, requires_grad)
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError(f"Invalid format: f{fmt}")

    # Magic methods

    def __getitem__(self, index: Any) -> Rotation:
        """
            Allows torch-style indexing over the virtual shape of the rotation
            object. See documentation for the shape property.

            Args:
                index:
                    A torch index. E.g. (1, 3, 2), or (slice(None,))
            Returns:
                The indexed rotation
        """
        if type(index) != tuple:
            index = (index,)

        if(self._rot_mats is not None):
            rot_mats = self._rot_mats[index + (slice(None), slice(None))]
            return Rotation(rot_mats=rot_mats)
        elif(self._quats is not None):
            quats = self._quats[index + (slice(None),)]
            return Rotation(quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    def __mul__(self,
        right: torch.Tensor,
    ) -> Rotation:
        """
            Pointwise left multiplication of the rotation with a tensor. Can be
            used to e.g. mask the Rotation.

            Args:
                right:
                    The tensor multiplicand
            Returns:
                The product
        """
        if not(isinstance(right, torch.Tensor)):
            raise TypeError("The other multiplicand must be a Tensor")

        if(self._rot_mats is not None):
            rot_mats = self._rot_mats * right[..., None, None]
            return Rotation(rot_mats=rot_mats, quats=None)
        elif(self._quats is not None):
            quats = self._quats * right[..., None]
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    def __rmul__(self,
        left: torch.Tensor,
    ) -> Rotation:
        """
            Reverse pointwise multiplication of the rotation with a tensor.

            Args:
                left:
                    The left multiplicand
            Returns:
                The product
        """
        return self.__mul__(left)
    
    # Properties

    @property
    def shape(self) -> torch.Size:
        """
            Returns the virtual shape of the rotation object. This shape is
            defined as the batch dimensions of the underlying rotation matrix
            or quaternion. If the Rotation was initialized with a [10, 3, 3]
            rotation matrix tensor, for example, the resulting shape would be
            [10].
        
            Returns:
                The virtual shape of the rotation object
        """
        s = None
        if(self._quats is not None):
            s = self._quats.shape[:-1]
        else:
            s = self._rot_mats.shape[:-2]

        return s

    @property
    def dtype(self) -> torch.dtype:
        """
            Returns the dtype of the underlying rotation.

            Returns:
                The dtype of the underlying rotation
        """
        if(self._rot_mats is not None):
            return self._rot_mats.dtype
        elif(self._quats is not None):
            return self._quats.dtype
        else:
            raise ValueError("Both rotations are None")

    @property
    def device(self) -> torch.device:
        """
            The device of the underlying rotation

            Returns:
                The device of the underlying rotation
        """
        if(self._rot_mats is not None):
            return self._rot_mats.device
        elif(self._quats is not None):
            return self._quats.device
        else:
            raise ValueError("Both rotations are None")

    @property
    def requires_grad(self) -> bool:
        """
            Returns the requires_grad property of the underlying rotation

            Returns:
                The requires_grad property of the underlying tensor
        """
        if(self._rot_mats is not None):
            return self._rot_mats.requires_grad
        elif(self._quats is not None):
            return self._quats.requires_grad
        else:
            raise ValueError("Both rotations are None")

    def get_rot_mats(self) -> torch.Tensor:
        """
            Returns the underlying rotation as a rotation matrix tensor.

            Returns:
                The rotation as a rotation matrix tensor
        """
        rot_mats = self._rot_mats
        if(rot_mats is None):
            if(self._quats is None):
                raise ValueError("Both rotations are None")
            else:
                rot_mats = quat_to_rot(self._quats)

        return rot_mats 

    def get_quats(self) -> torch.Tensor:
        """
            Returns the underlying rotation as a quaternion tensor.

            Depending on whether the Rotation was initialized with a
            quaternion, this function may call torch.linalg.eigh.

            Returns:
                The rotation as a quaternion tensor.
        """
        quats = self._quats
        if(quats is None):
            if(self._rot_mats is None):
                raise ValueError("Both rotations are None")
            else:
                quats = rot_to_quat(self._rot_mats)

        return quats

    def get_cur_rot(self) -> torch.Tensor:
        """
            Return the underlying rotation in its current form

            Returns:
                The stored rotation
        """
        if(self._rot_mats is not None):
            return self._rot_mats
        elif(self._quats is not None):
            return self._quats
        else:
            raise ValueError("Both rotations are None")

    # Rotation functions

    def compose_q_update_vec(self, 
        q_update_vec: torch.Tensor, 
        normalize_quats: bool = True
    ) -> Rotation:
        """
            Returns a new quaternion Rotation after updating the current
            object's underlying rotation with a quaternion update, formatted
            as a [*, 3] tensor whose final three columns represent x, y, z such 
            that (1, x, y, z) is the desired (not necessarily unit) quaternion
            update.

            Args:
                q_update_vec:
                    A [*, 3] quaternion update tensor
                normalize_quats:
                    Whether to normalize the output quaternion
            Returns:
                An updated Rotation
        """
        quats = self.get_quats()
        new_quats = quats + quat_multiply_by_vec(quats, q_update_vec)
        return Rotation(
            rot_mats=None, 
            quats=new_quats, 
            normalize_quats=normalize_quats,
        )

    def compose_r(self, r: Rotation) -> Rotation:
        """
            Compose the rotation matrices of the current Rotation object with
            those of another.

            Args:
                r:
                    An update rotation object
            Returns:
                An updated rotation object
        """
        r1 = self.get_rot_mats()
        r2 = r.get_rot_mats()
        new_rot_mats = rot_matmul(r1, r2)
        return Rotation(rot_mats=new_rot_mats, quats=None)

    def compose_q(self, r: Rotation, normalize_quats: bool = True) -> Rotation:
        """
            Compose the quaternions of the current Rotation object with those
            of another.

            Depending on whether either Rotation was initialized with
            quaternions, this function may call torch.linalg.eigh.

            Args:
                r:
                    An update rotation object
            Returns:
                An updated rotation object
        """
        q1 = self.get_quats()
        q2 = r.get_quats()
        new_quats = quat_multiply(q1, q2)
        return Rotation(
            rot_mats=None, quats=new_quats, normalize_quats=normalize_quats
        )

    def apply(self, pts: torch.Tensor) -> torch.Tensor:
        """
            Apply the current Rotation as a rotation matrix to a set of 3D
            coordinates.

            Args:
                pts:
                    A [*, 3] set of points
            Returns:
                [*, 3] rotated points
        """
        rot_mats = self.get_rot_mats()
        return rot_vec_mul(rot_mats, pts)

    def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
        """
            The inverse of the apply() method.

            Args:
                pts:
                    A [*, 3] set of points
            Returns:
                [*, 3] inverse-rotated points
        """
        rot_mats = self.get_rot_mats()
        inv_rot_mats = invert_rot_mat(rot_mats) 
        return rot_vec_mul(inv_rot_mats, pts)

    def invert(self) -> Rotation:
        """
            Returns the inverse of the current Rotation.

            Returns:
                The inverse of the current Rotation
        """
        if(self._rot_mats is not None):
            return Rotation(
                rot_mats=invert_rot_mat(self._rot_mats), 
                quats=None
            )
        elif(self._quats is not None):
            return Rotation(
                rot_mats=None,
                quats=invert_quat(self._quats),
                normalize_quats=False,
            )
        else:
            raise ValueError("Both rotations are None")

    # "Tensor" stuff

    def unsqueeze(self, 
        dim: int,
    ) -> Rigid:
        """
            Analogous to torch.unsqueeze. The dimension is relative to the
            shape of the Rotation object.
            
            Args:
                dim: A positive or negative dimension index.
            Returns:
                The unsqueezed Rotation.
        """
        if dim >= len(self.shape):
            raise ValueError("Invalid dimension")

        if(self._rot_mats is not None):
            rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2)
            return Rotation(rot_mats=rot_mats, quats=None)
        elif(self._quats is not None):
            quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1)
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    @staticmethod
    def cat(
        rs: Sequence[Rotation], 
        dim: int,
    ) -> Rigid:
        """
            Concatenates rotations along one of the batch dimensions. Analogous
            to torch.cat().

            Note that the output of this operation is always a rotation matrix,
            regardless of the format of input rotations.

            Args:
                rs: 
                    A list of rotation objects
                dim: 
                    The dimension along which the rotations should be 
                    concatenated
            Returns:
                A concatenated Rotation object in rotation matrix format
        """
        rot_mats = [r.get_rot_mats() for r in rs]
        rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2)

        return Rotation(rot_mats=rot_mats, quats=None) 
    

    def reshape(
        self, 
        shape,
    ) -> Rigid:
        rot_mats = self.get_rot_mats()
        rot_mats = rot_mats.reshape(*shape,3,3)
        return Rotation(rot_mats=rot_mats, quats=None) 

    def map_tensor_fn(self, 
        fn: Callable[[torch.Tensor], torch.Tensor]
    ) -> Rotation:
        """
            Apply a Tensor -> Tensor function to underlying rotation tensors,
            mapping over the rotation dimension(s). Can be used e.g. to sum out
            a one-hot batch dimension.

            Args:
                fn:
                    A Tensor -> Tensor function to be mapped over the Rotation 
            Returns:
                The transformed Rotation object
        """ 
        if(self._rot_mats is not None):
            rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))
            rot_mats = torch.stack(
                list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1
            )
            rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
            return Rotation(rot_mats=rot_mats, quats=None)
        elif(self._quats is not None):
            quats = torch.stack(
                list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1
            )
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")
    
    def cuda(self, **kwargs) -> Rotation:
        """
            Analogous to the cuda() method of torch Tensors

            Returns:
                A copy of the Rotation in CUDA memory
        """
        if(self._rot_mats is not None):
            return Rotation(rot_mats=self._rot_mats.cuda(**kwargs), quats=None)
        elif(self._quats is not None):
            return Rotation(
                rot_mats=None, 
                quats=self._quats.cuda(**kwargs),
                normalize_quats=False
            )
        else:
            raise ValueError("Both rotations are None")

    def to(self, 
        **kwargs
    ) -> Rotation:
        """
            Analogous to the to() method of torch Tensors

            Args:
                device:
                    A torch device
                dtype:
                    A torch dtype
            Returns:
                A copy of the Rotation using the new device and dtype
        """
        if(self._rot_mats is not None):
            return Rotation(
                rot_mats=self._rot_mats.to(**kwargs), 
                quats=None,
            )
        elif(self._quats is not None):
            return Rotation(
                rot_mats=None, 
                quats=self._quats.to(**kwargs),
                normalize_quats=False,
            )
        else:
            raise ValueError("Both rotations are None")

    def detach(self) -> Rotation:
        """
            Returns a copy of the Rotation whose underlying Tensor has been
            detached from its torch graph.

            Returns:
                A copy of the Rotation whose underlying Tensor has been detached
                from its torch graph
        """
        if(self._rot_mats is not None):
            return Rotation(rot_mats=self._rot_mats.detach(), quats=None)
        elif(self._quats is not None):
            return Rotation(
                rot_mats=None, 
                quats=self._quats.detach(), 
                normalize_quats=False,
            )
        else:
            raise ValueError("Both rotations are None")


class Rigid:
    """
        A class representing a rigid transformation. Little more than a wrapper
        around two objects: a Rotation object and a [*, 3] translation
        Designed to behave approximately like a single torch tensor with the 
        shape of the shared batch dimensions of its component parts.
    """
    def __init__(self, 
        rots: Optional[Rotation],
        trans: Optional[torch.Tensor],
    ):
        """
            Args:
                rots: A [*, 3, 3] rotation tensor
                trans: A corresponding [*, 3] translation tensor
        """
        # (we need device, dtype, etc. from at least one input)

        batch_dims, dtype, device, requires_grad = None, None, None, None
        if(trans is not None):
            batch_dims = trans.shape[:-1]
            dtype = trans.dtype
            device = trans.device
            requires_grad = trans.requires_grad
        elif(rots is not None):
            batch_dims = rots.shape
            dtype = rots.dtype
            device = rots.device
            requires_grad = rots.requires_grad
        else:
            raise ValueError("At least one input argument must be specified")

        if(rots is None):
            rots = Rotation.identity(
                batch_dims, dtype, device, requires_grad,
            )
        elif(trans is None):
            trans = identity_trans(
                batch_dims, dtype, device, requires_grad,
            )

        if((rots.shape != trans.shape[:-1]) or
           (rots.device != trans.device)):
            raise ValueError("Rots and trans incompatible")

        # # Force full precision. Happens to the rotations automatically.
        # trans = trans.to(dtype=torch.float32)

        self._rots = rots
        self._trans = trans

    @staticmethod
    def identity(
        shape: Tuple[int], 
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None, 
        requires_grad: bool = True,
        fmt: str = "quat",
    ) -> Rigid:
        """
            Constructs an identity transformation.

            Args:
                shape: 
                    The desired shape
                dtype: 
                    The dtype of both internal tensors
                device: 
                    The device of both internal tensors
                requires_grad: 
                    Whether grad should be enabled for the internal tensors
            Returns:
                The identity transformation
        """
        return Rigid(
            Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt),
            identity_trans(shape, dtype, device, requires_grad),
        )

    def __getitem__(self, 
        index: Any,
    ) -> Rigid:
        """ 
            Indexes the affine transformation with PyTorch-style indices.
            The index is applied to the shared dimensions of both the rotation
            and the translation.

            E.g.::

                r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None)
                t = Rigid(r, torch.rand(10, 10, 3))
                indexed = t[3, 4:6]
                assert(indexed.shape == (2,))
                assert(indexed.get_rots().shape == (2,))
                assert(indexed.get_trans().shape == (2, 3))

            Args:
                index: A standard torch tensor index. E.g. 8, (10, None, 3),
                or (3, slice(0, 1, None))
            Returns:
                The indexed tensor 
        """
        if type(index) != tuple:
            index = (index,)
        
        return Rigid(
            self._rots[index],
            self._trans[index + (slice(None),)],
        )

    def __mul__(self,
        right: torch.Tensor,
    ) -> Rigid:
        """
            Pointwise left multiplication of the transformation with a tensor.
            Can be used to e.g. mask the Rigid.

            Args:
                right:
                    The tensor multiplicand
            Returns:
                The product
        """
        if not(isinstance(right, torch.Tensor)):
            raise TypeError("The other multiplicand must be a Tensor")

        new_rots = self._rots * right
        new_trans = self._trans * right[..., None]

        return Rigid(new_rots, new_trans)

    def __rmul__(self,
        left: torch.Tensor,
    ) -> Rigid:
        """
            Reverse pointwise multiplication of the transformation with a 
            tensor.

            Args:
                left:
                    The left multiplicand
            Returns:
                The product
        """
        return self.__mul__(left)

    @property
    def shape(self) -> torch.Size:
        """
            Returns the shape of the shared dimensions of the rotation and
            the translation.
            
            Returns:
                The shape of the transformation
        """
        s = self._trans.shape[:-1]
        return s

    @property
    def device(self) -> torch.device:
        """
            Returns the device on which the Rigid's tensors are located.

            Returns:
                The device on which the Rigid's tensors are located
        """
        return self._trans.device

    def get_rots(self) -> Rotation:
        """
            Getter for the rotation.

            Returns:
                The rotation object
        """
        return self._rots

    def get_trans(self) -> torch.Tensor:
        """
            Getter for the translation.

            Returns:
                The stored translation
        """
        return self._trans

    def compose_q_update_vec(self, 
        q_update_vec: torch.Tensor,
    ) -> Rigid:
        """
            Composes the transformation with a quaternion update vector of
            shape [*, 6], where the final 6 columns represent the x, y, and
            z values of a quaternion of form (1, x, y, z) followed by a 3D
            translation.

            Args:
                q_vec: The quaternion update vector.
            Returns:
                The composed transformation.
        """
        q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:]
        new_rots = self._rots.compose_q_update_vec(q_vec)

        trans_update = self._rots.apply(t_vec)
        new_translation = self._trans + trans_update

        return Rigid(new_rots, new_translation)

    def compose(self,
        r: Rigid,
    ) -> Rigid:
        """
            Composes the current rigid object with another.

            Args:
                r:
                    Another Rigid object
            Returns:
                The composition of the two transformations
        """
        new_rot = self._rots.compose_r(r._rots)
        new_trans = self._rots.apply(r._trans) + self._trans
        return Rigid(new_rot, new_trans)

    def apply(self, 
        pts: torch.Tensor,
    ) -> torch.Tensor:
        """
            Applies the transformation to a coordinate tensor.

            Args:
                pts: A [*, 3] coordinate tensor.
            Returns:
                The transformed points.
        """
        rotated = self._rots.apply(pts) 
        return rotated + self._trans

    def invert_apply(self, 
        pts: torch.Tensor
    ) -> torch.Tensor:
        """
            Applies the inverse of the transformation to a coordinate tensor.

            Args:
                pts: A [*, 3] coordinate tensor
            Returns:
                The transformed points.
        """
        pts = pts - self._trans
        return self._rots.invert_apply(pts) 

    def invert(self) -> Rigid:
        """
            Inverts the transformation.

            Returns:
                The inverse transformation.
        """
        rot_inv = self._rots.invert() 
        trn_inv = rot_inv.apply(self._trans)

        return Rigid(rot_inv, -1 * trn_inv)

    def map_tensor_fn(self, 
        fn: Callable[[torch.Tensor], torch.Tensor]
    ) -> Rigid:
        """
            Apply a Tensor -> Tensor function to underlying translation and
            rotation tensors, mapping over the translation/rotation dimensions
            respectively.

            Args:
                fn:
                    A Tensor -> Tensor function to be mapped over the Rigid
            Returns:
                The transformed Rigid object
        """     
        new_rots = self._rots.map_tensor_fn(fn) 
        new_trans = torch.stack(
            list(map(fn, torch.unbind(self._trans, dim=-1))), 
            dim=-1
        )

        return Rigid(new_rots, new_trans)

    def to_tensor_4x4(self) -> torch.Tensor:
        """
            Converts a transformation to a homogenous transformation tensor.

            Returns:
                A [*, 4, 4] homogenous transformation tensor
        """
        tensor = self._trans.new_zeros((*self.shape, 4, 4))
        tensor[..., :3, :3] = self._rots.get_rot_mats()
        tensor[..., :3, 3] = self._trans
        tensor[..., 3, 3] = 1
        return tensor

    @staticmethod
    def from_tensor_4x4(
        t: torch.Tensor
    ) -> Rigid:
        """
            Constructs a transformation from a homogenous transformation
            tensor.

            Args:
                t: [*, 4, 4] homogenous transformation tensor
            Returns:
                T object with shape [*]
        """
        if(t.shape[-2:] != (4, 4)):
            raise ValueError("Incorrectly shaped input tensor")

        rots = Rotation(rot_mats=t[..., :3, :3], quats=None)
        trans = t[..., :3, 3]
        
        return Rigid(rots, trans)

    def to_tensor_7(self) -> torch.Tensor:
        """
            Converts a transformation to a tensor with 7 final columns, four 
            for the quaternion followed by three for the translation.

            Returns:
                A [*, 7] tensor representation of the transformation
        """
        tensor = self._trans.new_zeros((*self.shape, 7))
        tensor[..., :4] = self._rots.get_quats()
        tensor[..., 4:] = self._trans

        return tensor

    @staticmethod
    def from_tensor_7(
        t: torch.Tensor,
        normalize_quats: bool = False,
    ) -> Rigid:
        if(t.shape[-1] != 7):
            raise ValueError("Incorrectly shaped input tensor")

        quats, trans = t[..., :4], t[..., 4:]

        rots = Rotation(
            rot_mats=None, 
            quats=quats, 
            normalize_quats=normalize_quats
        )

        return Rigid(rots, trans)

    @staticmethod
    def from_3_points(
        p_neg_x_axis: torch.Tensor, 
        origin: torch.Tensor, 
        p_xy_plane: torch.Tensor, 
        eps: float = 1e-8
    ) -> Rigid:
        """
            Implements algorithm 21. Constructs transformations from sets of 3 
            points using the Gram-Schmidt algorithm.

            Args:
                p_neg_x_axis: [*, 3] coordinates
                origin: [*, 3] coordinates used as frame origins
                p_xy_plane: [*, 3] coordinates
                eps: Small epsilon value
            Returns:
                A transformation object of shape [*]
        """
        p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
        origin = torch.unbind(origin, dim=-1)
        p_xy_plane = torch.unbind(p_xy_plane, dim=-1)

        e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
        e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]

        denom = torch.sqrt(sum((c * c for c in e0)) + eps)
        e0 = [c / denom for c in e0]
        dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
        e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
        denom = torch.sqrt(sum((c * c for c in e1)) + eps)
        e1 = [c / denom for c in e1]
        e2 = [
            e0[1] * e1[2] - e0[2] * e1[1],
            e0[2] * e1[0] - e0[0] * e1[2],
            e0[0] * e1[1] - e0[1] * e1[0],
        ]

        rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
        rots = rots.reshape(rots.shape[:-1] + (3, 3))

        rot_obj = Rotation(rot_mats=rots, quats=None)

        return Rigid(rot_obj, torch.stack(origin, dim=-1))

    def unsqueeze(self, 
        dim: int,
    ) -> Rigid:
        """
            Analogous to torch.unsqueeze. The dimension is relative to the
            shared dimensions of the rotation/translation.
            
            Args:
                dim: A positive or negative dimension index.
            Returns:
                The unsqueezed transformation.
        """
        if dim >= len(self.shape):
            raise ValueError("Invalid dimension")
        rots = self._rots.unsqueeze(dim)
        trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1)

        return Rigid(rots, trans)

    @staticmethod
    def cat(
        ts: Sequence[Rigid], 
        dim: int,
    ) -> Rigid:
        """
            Concatenates transformations along a new dimension.

            Args:
                ts: 
                    A list of T objects
                dim: 
                    The dimension along which the transformations should be 
                    concatenated
            Returns:
                A concatenated transformation object
        """
        rots = Rotation.cat([t._rots for t in ts], dim) 
        trans = torch.cat(
            [t._trans for t in ts], dim=dim if dim >= 0 else dim - 1
        )

        return Rigid(rots, trans)
    

    def reshape(
        self,
        shape: Sequence[int], 
    ) -> Rigid:
        rots = self._rots.reshape(shape)
        trans = self._trans.reshape(*shape, 3)

        return Rigid(rots, trans)

    def apply_rot_fn(self, fn: Callable[[Rotation], Rotation]) -> Rigid:
        """
            Applies a Rotation -> Rotation function to the stored rotation
            object.

            Args:
                fn: A function of type Rotation -> Rotation
            Returns:
                A transformation object with a transformed rotation.
        """
        return Rigid(fn(self._rots), self._trans)

    def apply_trans_fn(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Rigid:
        """
            Applies a Tensor -> Tensor function to the stored translation.

            Args:
                fn: 
                    A function of type Tensor -> Tensor to be applied to the
                    translation
            Returns:
                A transformation object with a transformed translation.
        """
        return Rigid(self._rots, fn(self._trans))

    def scale_translation(self, trans_scale_factor: float) -> Rigid:
        """
            Scales the translation by a constant factor.

            Args:
                trans_scale_factor:
                    The constant factor
            Returns:
                A transformation object with a scaled translation.
        """
        fn = lambda t: t * trans_scale_factor
        return self.apply_trans_fn(fn)

    def stop_rot_gradient(self) -> Rigid:
        """
            Detaches the underlying rotation object

            Returns:
                A transformation object with detached rotations
        """
        fn = lambda r: r.detach()
        return self.apply_rot_fn(fn)

    @staticmethod
    def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
        """
            Returns a transformation object from reference coordinates.
  
            Note that this method does not take care of symmetries. If you 
            provide the atom positions in the non-standard way, the N atom will 
            end up not at [-0.527250, 1.359329, 0.0] but instead at 
            [-0.527250, -1.359329, 0.0]. You need to take care of such cases in 
            your code.
  
            Args:
                n_xyz: A [*, 3] tensor of nitrogen xyz coordinates.
                ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates.
                c_xyz: A [*, 3] tensor of carbon xyz coordinates.
            Returns:
                A transformation object. After applying the translation and 
                rotation to the reference backbone, the coordinates will 
                approximately equal to the input coordinates.
        """    
        # translation = -1 * ca_xyz
        # n_xyz = n_xyz + translation
        # c_xyz = c_xyz + translation

        # c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
        # norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2)
        # sin_c1 = -c_y / norm
        # cos_c1 = c_x / norm
        # zeros = sin_c1.new_zeros(sin_c1.shape)
        # ones = sin_c1.new_ones(sin_c1.shape)

        # c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
        # c1_rots[..., 0, 0] = cos_c1
        # c1_rots[..., 0, 1] = -1 * sin_c1
        # c1_rots[..., 1, 0] = sin_c1
        # c1_rots[..., 1, 1] = cos_c1
        # c1_rots[..., 2, 2] = 1

        # norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2 + c_z ** 2)
        # sin_c2 = c_z / norm
        # cos_c2 = torch.sqrt(c_x ** 2 + c_y ** 2) / norm

        # c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
        # c2_rots[..., 0, 0] = cos_c2
        # c2_rots[..., 0, 2] = sin_c2
        # c2_rots[..., 1, 1] = 1
        # c2_rots[..., 2, 0] = -1 * sin_c2
        # c2_rots[..., 2, 2] = cos_c2

        # c_rots = rot_matmul(c2_rots, c1_rots)
        # n_xyz = rot_vec_mul(c_rots, n_xyz)

        # _, n_y, n_z = [n_xyz[..., i] for i in range(3)]
        # norm = torch.sqrt(eps + n_y ** 2 + n_z ** 2)
        # sin_n = -n_z / norm
        # cos_n = n_y / norm

        # n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
        # n_rots[..., 0, 0] = 1
        # n_rots[..., 1, 1] = cos_n
        # n_rots[..., 1, 2] = -1 * sin_n
        # n_rots[..., 2, 1] = sin_n
        # n_rots[..., 2, 2] = cos_n

        # rots = rot_matmul(n_rots, c_rots)

        # rots = rots.transpose(-1, -2)
        # translation = -1 * translation


        def build_local_coordinate_system(p1, p2, p3, norm_eps=1e-8):
            """
            Build a local coordinate system from three points and represent it as a 3x3 rotation matrix.

            Args:
            - p1, p2, p3: Tensor of shape (batch_size, 3) representing the coordinates of three points.

            Returns:
            - origin: Tensor of shape (batch_size, 3) representing the origin of the local coordinate system.
            - rotation_matrix: Tensor of shape (batch_size, 3, 3) representing the rotation matrix of the local coordinate system.
            """
            # Calculate the vectors of the coordinate system
            v1 = p2 - p1
            v2 = p3 - p1

            # Calculate the normal vector to the plane defined by v1 and v2
            normal = torch.cross(v1, v2, dim=1)

            # Normalize the normal vector
            normal_norm = torch.norm(normal, p=2, dim=1, keepdim=True)
            normal = normal / (normal_norm + norm_eps)

            # Calculate the x-axis of the coordinate system
            axis_x = v1 / (torch.norm(v1, p=2, dim=1, keepdim=True) + norm_eps)

            # Calculate the y-axis of the coordinate system
            axis_y = torch.cross(normal, axis_x, dim=1)

            # Calculate the z-axis of the coordinate system
            axis_z = normal

            # Stack the axis vectors to form the rotation matrix
            rotation_matrix = torch.stack([axis_x, axis_y, axis_z], dim=2)

            return p1, rotation_matrix

        translation, rots = build_local_coordinate_system(n_xyz, ca_xyz, c_xyz)
        rot_obj = Rotation(rot_mats=rots, quats=None)

        return Rigid(rot_obj, translation)


    

    def cuda(self, **kwargs) -> Rigid:
        """
            Moves the transformation object to GPU memory
            
            Returns:
                A version of the transformation on GPU
        """
        return Rigid(self._rots.cuda(**kwargs), self._trans.cuda(**kwargs))

    def to(self,**kwargs) -> Rigid:
        return Rigid(self._rots.to(**kwargs), self._trans.to(**kwargs))

def positional_embeddings(E_idx, dtype, num_embeddings=None):
    # From https://github.com/jingraham/neurips19-graph-protein-design
    num_embeddings = num_embeddings
    d = E_idx[0]-E_idx[1]
    
    frequency = torch.exp(
        torch.arange(0, num_embeddings, 2, dtype=dtype, device=E_idx.device)
        * -(np.log(10000.0) / num_embeddings)
    )
    angles = d[:,None] * frequency[None,:]
    E = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
    return E

def positional_embeddings_transformer(d, num_embeddings=None):
    # From https://github.com/jingraham/neurips19-graph-protein-design
    frequency = torch.exp(
        torch.arange(0, num_embeddings, 2, dtype=torch.float32, device=d.device)
        * -(np.log(10000.0) / num_embeddings)
    )
    angles = d[:,:,:,None] * frequency[None,None,None,:]
    E = torch.cat((torch.cos(angles), torch.sin(angles)), -1)
    return E


def get_interact_feats(T, T_ts, X, edge_idx, batch_id, num_rbf=16):
    dtype = X.dtype
    device = X.device
    src_idx, dst_idx = edge_idx[0], edge_idx[1]
    num_N, num_E = X.shape[0], edge_idx.shape[1]

    def rbf_func(D, num_rbf):
        shape = D.shape
        D_min, D_max, D_count = 0., 20., num_rbf
        D_mu = torch.linspace(D_min, D_max, D_count, dtype=dtype, device=device)
        D_mu = D_mu.view([1]*(len(shape))+[-1])
        D_sigma = (D_max - D_min) / D_count
        D_expand = torch.unsqueeze(D, -1)
        RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
        return RBF

    def decouple(U):
        norm = U.norm(dim=-1, keepdim=True)
        direct = U/(norm+1e-6)
        rbf = rbf_func(norm[...,0], num_rbf)
        return torch.cat([direct, rbf], dim=-1)
    
    ## ========== new_simplified_feat_diff_vec
    diffX = F.pad(X.reshape(-1,3).diff(dim=0), (0,0,1,0)).reshape(num_N, -1, 3)
    diffX_proj = T[:,None].invert()._rots.apply(diffX)
    V_prev = F.pad(diffX_proj[:-1],(0,0,0,0,1,0))
    V_current = diffX_proj
    V_next = F.pad(diffX[1:],(0,0,0,0,0,1))
    V_next2 = F.pad(diffX[2:],(0,0,0,0,0,2))
    

    # V_prev = T[:,None].invert()._rots.apply(F.pad(diffX[1:],(0,0,0,0,1,0)))
    # V_current = T[:,None].invert()._rots.apply(diffX)
    # V_next = T[:,None].invert()._rots.apply(F.pad(diffX[:-1],(0,0,0,0,0,1)))
    # V_next2 = T[:,None].invert()._rots.apply(F.pad(diffX[:-2],(0,0,0,0,0,2)))
    # V = torch.cat([ decouple(V_prev).reshape(num_N, -1),
    #                 decouple(V_current).reshape(num_N, -1),
    #                 decouple(V_next).reshape(num_N, -1),
    #                 decouple(V_next2).reshape(num_N, -1)
    #                 ], dim=1)

    V = decouple(torch.cat([V_prev, V_current, V_next, V_next2], dim=1)).reshape(num_N, -1)
    V[torch.isnan(V)] = 0



    '''X [N,4,3]: N个氨基酸, 每个氨基酸4个原子(N,CA,C,O), 3是原子的xyz坐标
        T [N]: N个局部坐标系
    '''
    src_idx, dst_idx = edge_idx[0], edge_idx[1]
    
    diffE = T[src_idx,None].invert().apply(torch.cat([X[src_idx],X[dst_idx]], dim=1))
    diffE = decouple(diffE).reshape(num_E, -1)

    pos_embed = positional_embeddings(edge_idx, dtype, 16)
    # E = decouple(E).reshape(num_E,-1)
    E_quant = T_ts.invert()._rots._rot_mats.reshape(num_E,9)
    E_trans = T_ts._trans
    E_trans = decouple(E_trans).reshape(num_E,-1)
    E = torch.cat([diffE, E_quant, E_trans, pos_embed], dim=-1)
    return V, E



def get_interact_feats_transformer(T, T_ij, X, num_rbf=16):
    def rbf_func(D, num_rbf):
        shape = D.shape
        D_min, D_max, D_count = 0., 20., num_rbf
        D_mu = torch.linspace(D_min, D_max, D_count).to(D.device)
        D_mu = D_mu.view([1]*(len(shape))+[-1])
        D_sigma = (D_max - D_min) / D_count
        D_expand = torch.unsqueeze(D, -1)
        RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
        return RBF

    def decouple(U):
        norm = U.norm(dim=-1, keepdim=True)
        direct = U/(norm+1e-6)
        rbf = rbf_func(norm[...,0], num_rbf)
        return torch.cat([direct, rbf], dim=-1)
    

    B, L = X.shape[:2]
    diffX = F.pad(X.reshape(-1,3).diff(dim=0), (0,0,1,0)).reshape(B, L, -1, 3)


    V_prev = T[...,None].invert()._rots.apply(F.pad(diffX[:,1:],(0,0,0,0,1,0)))
    V_current = T[...,None].invert()._rots.apply(diffX)
    V_next = T[...,None].invert()._rots.apply(F.pad(diffX[:,:-1],(0,0,0,0,0,1)))
    V_next2 = T[...,None].invert()._rots.apply(F.pad(diffX[:,:-2],(0,0,0,0,0,2)))
    V = torch.cat([ decouple(V_prev).reshape(B, L, -1),
                    decouple(V_current).reshape(B, L, -1),
                    decouple(V_next).reshape(B, L, -1),
                    decouple(V_next2).reshape(B, L, -1)
                    ], dim=-1)
    V[torch.isnan(V)] = 0



    '''X [N,4,3]: N个氨基酸, 每个氨基酸4个原子(N,CA,C,O), 3是原子的xyz坐标
        T [N]: N个局部坐标系
    '''
    X_local = T[...,None].invert().apply(X)
    diffE = torch.cat([X_local[:,:,None].repeat(1,1,L,1,1),
                       T_ij[...,None].apply(X_local[:,None].repeat(1,L,1,1,1))],
                       dim=-2)
    diffE = decouple(diffE).reshape(B,L,L, -1)
    
    pos = torch.arange(L, device=V.device)
    pos = pos[None,...].repeat(B,1)
    pos_embed = positional_embeddings_transformer(pos[:,:,None]-pos[None,:,:], 16)

    E_quant = T_ij._rots._rot_mats.reshape(B,L,L,9)
    E_trans = T_ij._trans
    E_trans = decouple(E_trans).reshape(B,L,L,-1)
    E = torch.cat([diffE, E_quant, E_trans, pos_embed], dim=-1)
    return {'_V':V.to(torch.float), '_E':E.to(torch.float)}


