import os
import pickle
from typing import List, Optional, Tuple, Union

import lmdb
import numpy as np
import torch
from einops import einsum
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Geometry import Point3D
from scipy.optimize import linear_sum_assignment
from scipy.spatial.transform import Rotation
from torch.types import Device
from torch_geometric.data import Data
from shepherd_score.extract_profiles import get_pharmacophores

from src.constants import BUILDING_BLOCKS, COORDS_STD, MAX_ATOMS, REACTIONS
from src.utils.indexing_utils import (
    idx_to_smarts,
    idx_to_smiles,
    onehot_to_node_indices,
    onehot_to_reaction_type_and_centers,
    pharm_indices_to_onehot
)


def leaving_atom_mask(X: torch.Tensor, E: torch.Tensor) -> torch.Tensor:
    """Mark the per-fragment atom slots that do not hold an atom.

    A slot is set to ``True`` when it belongs to an atom the fragment loses in
    one of its reactions (as flagged by ``REACTIONS``), or when it lies past the
    fragment's real atom count (padding up to ``MAX_ATOMS``).

    Args:
        X: One-hot building-block tensor; its first dimension indexes the
            fragments. Presumably [n_nodes, n_building_blocks], as documented
            by the other helpers in this module -- TODO confirm.
        E: One-hot edge tensor decoded by ``onehot_to_reaction_type_and_centers``;
            presumably [n_nodes, n_nodes, n_reactions * n_centers * n_centers + 2].

    Returns:
        Bool tensor of shape [X.shape[0], MAX_ATOMS, 3] (created on the CPU);
        all three entries of the last axis carry the same flag.
    """
    n_fragments = X.shape[0]
    slot_mask = torch.zeros(n_fragments, MAX_ATOMS, 3, dtype=torch.bool)

    # Decode the one-hot encodings into plain Python lists
    fragment_ids = onehot_to_node_indices(X).tolist()
    edges = onehot_to_reaction_type_and_centers(E).tolist()

    # Each edge is (reaction, fragment 1, fragment 2, center 1, center 2)
    for edge in edges:
        smarts = idx_to_smarts(edge[0])
        frag_a, frag_b, center_a, center_b = edge[1:]
        smiles_a = idx_to_smiles(fragment_ids[frag_a])
        smiles_b = idx_to_smiles(fragment_ids[frag_b])
        drops_a = REACTIONS[smarts][0]
        drops_b = REACTIONS[smarts][1]

        # Flag the reaction-center atom of whichever side loses it
        if drops_a:
            slot_mask[frag_a, BUILDING_BLOCKS[smiles_a][center_a]] = True
        if drops_b:
            slot_mask[frag_b, BUILDING_BLOCKS[smiles_b][center_b]] = True

    # Flag every slot beyond the fragment's own atoms as padding
    # TODO: Some of this is definitely redundant
    for frag_pos in range(n_fragments):
        fragment = Chem.MolFromSmiles(idx_to_smiles(fragment_ids[frag_pos]))
        slot_mask[frag_pos, fragment.GetNumAtoms():] = True

    return slot_mask


def select_conformer(molecule_id: int, conformers_path: str, random: bool = True) -> str:
    """Pick one stored conformer of a molecule.

    Conformers are looked up under consecutive indices starting at 0
    (``mol_{molecule_id}_final_conf_{k}``); enumeration stops at the first
    missing index.

    Args:
        molecule_id: Identifier used to build the conformer keys / file names.
        conformers_path: Either an LMDB database (path ending in ``.lmdb``) or
            a directory containing ``.xyz`` files.
        random: If True pick a conformer uniformly at random (via
            ``np.random``), otherwise return conformer 0.

    Returns:
        For an LMDB database, the unpickled value stored under the selected
        key (presumably the XYZ text -- TODO confirm against the writer).
        For a directory, the path of the selected ``.xyz`` file.

    Raises:
        ValueError: If the molecule has no conformer. The directory backend
            previously failed with a bare ``IndexError`` here; it now reports
            the problem the same way as the LMDB backend.
        KeyError: If the selected LMDB key cannot be read back.
    """
    if conformers_path.endswith(".lmdb"):
        lmdb_env = lmdb.open(conformers_path, readonly=True, lock=False)
        # try/finally so the environment is released on every exit path,
        # including a failed lookup or a failing unpickle.
        try:
            with lmdb_env.begin() as txn:
                conf_keys = []
                while True:
                    key = f"mol_{molecule_id}_final_conf_{len(conf_keys)}"
                    if txn.get(key.encode()) is None:
                        break
                    conf_keys.append(key)

                if not conf_keys:
                    raise ValueError(f"No conformers found for molecule {molecule_id}")

                selected_key = str(np.random.choice(conf_keys)) if random else conf_keys[0]

                value = txn.get(selected_key.encode())
                if value is None:
                    raise KeyError(f"Key {selected_key} not found in the database")
                # NOTE(security): pickle executes arbitrary code on load; the
                # database must come from a trusted source.
                return pickle.loads(value)
        finally:
            lmdb_env.close()

    conf_files = []
    while True:
        conf_path = os.path.join(
            conformers_path, f"mol_{molecule_id}_final_conf_{len(conf_files)}.xyz"
        )
        if not os.path.exists(conf_path):
            break
        conf_files.append(conf_path)

    if not conf_files:
        raise ValueError(f"No conformers found for molecule {molecule_id}")

    if random:
        return conf_files[np.random.randint(len(conf_files))]
    return conf_files[0]


def get_pharmacophore(key: str, pharm_path: str, n_subset: Optional[int] = None):
    """Load pharmacophore data from LMDB and pad (or subsample) it to a fixed size.

    The stored record is a pickled mapping with ``"types"`` and ``"pos"``
    entries. If more than ``n_subset`` pharmacophores are stored, a random
    subset of ``n_subset`` is drawn with one shared permutation so types and
    positions stay paired. If fewer are stored, the outputs are padded: padded
    rows get the last type channel set to 1 (the "padding class"), zero
    positions, and a zero in the mask.

    Args:
        key: LMDB key of the record.
        pharm_path: Path to the LMDB database. Only paths ending in ``.lmdb``
            are handled.
        n_subset: Number of output rows. When None, no subsetting or padding is
            done and one row per stored pharmacophore is returned (previously
            None crashed while allocating the padded tensors).

    Returns:
        ``(types_padded, pos_padded, pharm_padding_mask)`` as float32 tensors
        with ``n_subset`` rows (the mask is 1-D), or ``None`` when
        ``pharm_path`` is not an ``.lmdb`` path.

    Raises:
        KeyError: If ``key`` is not present in the database.
    """
    if not pharm_path.endswith(".lmdb"):
        # Only the LMDB backend is implemented; as before, other paths yield None.
        return None

    lmdb_env = lmdb.open(pharm_path, readonly=True, lock=False)
    # try/finally so the environment is always closed (it used to leak).
    try:
        with lmdb_env.begin() as txn:
            value = txn.get(key.encode("utf-8"))
            if value is None:
                raise KeyError(f"Key {key} not found in the database")
            # NOTE(security): pickle executes arbitrary code on load; the
            # database must come from a trusted source.
            data = pickle.loads(value)
    finally:
        lmdb_env.close()

    types = torch.tensor(data["types"])
    pos = torch.tensor(data["pos"], dtype=torch.float32)

    # If requested, subset pharmacophores (same indices for types and positions)
    if n_subset is not None and len(types) > n_subset:
        indices = torch.randperm(len(types))[:n_subset]
        types = types[indices]
        pos = pos[indices]

    n_rows = len(types) if n_subset is None else n_subset
    types_onehot = pharm_indices_to_onehot(types)

    types_padded = torch.zeros((n_rows, types_onehot.shape[1]), dtype=torch.float32)
    types_padded[:, -1] = 1  # default "padding class"
    types_padded[: len(types)] = types_onehot

    pos_padded = torch.zeros((n_rows, pos.shape[1]), dtype=torch.float32)
    pos_padded[: len(pos)] = pos

    # 1 for real pharmacophores, 0 for padded rows
    pharm_padding_mask = torch.zeros(n_rows, dtype=torch.float32)
    pharm_padding_mask[: len(types)] = 1.0

    return types_padded, pos_padded, pharm_padding_mask


def xyz_to_coordinates(
    conf_xyz: str, smiles: str, X: torch.Tensor, E: torch.Tensor, mask_value: float = 0.0
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Parse an XYZ file / XYZ text into fragment-grouped coordinates.

    The second line of the XYZ data must contain ``SMILES_noHs: <smiles>;``;
    that SMILES (with stereochemistry removed) has to equal ``smiles``.
    Atom lines are expected to start with a label whose part after the last
    ``_`` is the fragment index, followed by the x, y, z coordinates.
    Within a fragment, atoms are written to consecutive slots, skipping the
    slots flagged by ``leaving_atom_mask``.

    Args:
        conf_xyz: Path to an ``.xyz`` file, or the XYZ text itself
        smiles: SMILES string of the molecule
        X: [n_nodes, n_building_blocks]
        E: [n_nodes, n_nodes, n_reactions * n_centers * n_centers + 2]
        mask_value: Value to use for masked atoms
    Returns:
        coords_tensor: torch.Tensor of shape [n_fragments, MAX_ATOMS, 3]
        coords_mask: bool torch.Tensor of shape [n_fragments, MAX_ATOMS],
            True where a coordinate was written

    Raises:
        AssertionError: If the SMILES in the XYZ header does not match ``smiles``.
    """

    # A value ending in ".xyz" is treated as a file path, anything else as XYZ text
    if conf_xyz.endswith(".xyz"):
        with open(conf_xyz) as f:
            lines = f.readlines()
    else:
        lines = conf_xyz.split("\n")

    # Atom count of the parsed SMILES; bounds how many atom lines are read below
    n_atoms = Chem.MolFromSmiles(smiles).GetNumAtoms()

    # Extract SMILES_noHs from second line and verify it matches
    header = lines[1].strip()
    xyz_smiles_no_hs = header.split("SMILES_noHs: ")[1].split(";")[0]
    # Remove stereochemistry using RDKit
    mol = Chem.MolFromSmiles(xyz_smiles_no_hs)
    Chem.RemoveStereochemistry(mol)
    xyz_smiles_no_hs = Chem.MolToSmiles(mol)
    assert xyz_smiles_no_hs == smiles, f"SMILES mismatch: {xyz_smiles_no_hs} != {smiles}"

    # Get leaving atom mask
    lam = leaving_atom_mask(X, E)
    coords_tensor = torch.full((X.shape[0], MAX_ATOMS, 3), mask_value)  # Initialize with mask_value
    coords_mask = torch.zeros((X.shape[0], MAX_ATOMS), dtype=torch.bool)

    # Keep track of atom indices per fragment:
    #   curr_atom - number of atoms already stored for the current fragment
    #   offset    - number of masked slots skipped so far in the current fragment
    curr_atom = 0
    curr_frag = 0
    offset = 0
    for line in lines[2 : 2 + n_atoms]:
        parts = line.split()
        if len(parts) < 4 or parts[0][0] == "H":  # Skip if insufficient parts or label starts with "H"
            continue

        # Extract fragment index from the first column (text after the last "_")
        frag_idx = int(parts[0].split("_")[-1])

        # Reset the per-fragment counters when a higher fragment index starts
        # (counters are only reset when the fragment index increases)
        if frag_idx > curr_frag:
            curr_frag = frag_idx
            curr_atom = 0
            offset = 0

        # Skip indices where there are leaving atoms
        while lam[frag_idx, curr_atom + offset].any():
            offset += 1

        # Store coordinates
        coords_tensor[frag_idx, curr_atom + offset] = torch.tensor([float(x) for x in parts[1:4]])
        coords_mask[frag_idx, curr_atom + offset] = True
        curr_atom += 1

    return coords_tensor, coords_mask


def coordinates_to_mol(
    X: torch.Tensor,
    E: torch.Tensor,
    coords: torch.Tensor,
    mask_value: float = 0.0,
    padding_at_end: bool = False,
) -> Optional[Chem.Mol]:
    """Create a new RDKit molecule from coordinates, bonds and atom sequence.

    Atoms come from ``get_atom_sequence`` and bonds from ``get_bonds``; the
    coordinates are attached as a single conformer.

    Args:
        X: [n_nodes, n_building_blocks]
        E: [n_nodes, n_nodes, n_reactions * n_centers * n_centers + 2]
        coords: [n_fragments, max_atoms_per_fragment, 3] torch tensor containing 3D coordinates
               for each atom in each fragment. Coordinates for masked atoms are set to mask_value.
        mask_value: Value used to indicate masked coordinates (default: 0.0)
        padding_at_end: If True, ``coords`` is used as-is as the flat per-atom
            coordinate sequence (no masking or filtering is applied).

    Returns:
        The sanitized, kekulized molecule with 3D coordinates, or ``None`` if
        sanitization / kekulization fails.
    """
    # Get bonds and atom sequence
    bonds = get_bonds(X, E)
    atom_sequence = get_atom_sequence(X, E)

    if not padding_at_end:
        lam = leaving_atom_mask(X, E).to(coords.device)
        # Overwrite masked slots with mask_value itself. Multiplying by ~lam
        # only produced zeros, which broke the filter below for a non-zero
        # mask_value.
        coords = coords.masked_fill(lam, mask_value)
        # A slot is masked only when all three components equal mask_value, so
        # keep every atom with at least one differing component. Requiring all
        # components to differ dropped real atoms lying on a coordinate plane.
        keep = (coords != mask_value).any(dim=-1)
        # Boolean indexing preserves fragment-major, atom-minor order
        flat_coords = coords[keep]
    else:
        flat_coords = coords

    # Create empty editable mol
    mol = Chem.RWMol()

    # Add atoms
    for atom in atom_sequence:
        mol.AddAtom(atom)

    # Add bonds
    for begin_idx, end_idx, bond_type in bonds:
        mol.AddBond(begin_idx, end_idx, bond_type)

    # Create conformer and set coordinates
    conf = Chem.Conformer(len(flat_coords))
    for i, coord in enumerate(flat_coords):
        conf.SetAtomPosition(i, Point3D(float(coord[0]), float(coord[1]), float(coord[2])))

    # Convert to immutable mol and add conformer
    mol = mol.GetMol()
    mol.AddConformer(conf)

    # Sanitize and kekulize the molecule; failure is reported and signalled with None
    try:
        Chem.SanitizeMol(mol, sanitizeOps=Chem.SANITIZE_ALL)
        Chem.Kekulize(mol)
    except Exception as e:
        print(f"Failed to sanitize/kekulize molecule: {e}")
        return None

    return mol


def get_bond_lengths(
    coords: torch.Tensor,
    bonds: List[Tuple[int, int, Chem.BondType]],
    mask_value: float = 0.0,
) -> torch.Tensor:
    """Calculate bond lengths from coordinates.

    Args:
        coords: Tensor of 3D coordinates; it is flattened to [-1, 3], and the
               bond indices address rows of that flattened view
               (e.g. [n_fragments, max_atoms_per_fragment, 3]).
        bonds: List of tuples, each containing (atom1_idx, atom2_idx, bond_type);
               the bond type is ignored.
        mask_value: Unused; kept for interface compatibility.

    Returns:
        torch.Tensor: 1-D tensor with the Euclidean length of each bond, in the
        order of ``bonds``. An empty ``bonds`` list yields an empty tensor
        (previously this raised inside ``torch.stack``).
    """
    flat_coords = coords.reshape(-1, 3)
    if len(bonds) == 0:
        return flat_coords.new_zeros(0)

    # Gather both endpoints of all bonds at once instead of looping per bond
    begin = torch.tensor([b[0] for b in bonds], dtype=torch.long, device=flat_coords.device)
    end = torch.tensor([b[1] for b in bonds], dtype=torch.long, device=flat_coords.device)
    diff = flat_coords[begin] - flat_coords[end]

    return torch.sqrt(torch.sum(diff ** 2, dim=-1))


def get_bonds(
    X: torch.Tensor, E: torch.Tensor, reindex: bool = True, as_onehot_adj_tensor: bool = False
) -> Union[List[Tuple[int, int, Chem.BondType]], torch.Tensor]:
    """Get bonds from fragment order information.

    Atoms are addressed by a flat index ``fragment_position * MAX_ATOMS +
    atom_index_within_fragment``. Bonds inside each building block are kept
    unless one endpoint is flagged by ``leaving_atom_mask``; each reaction
    adds one single bond between the two fragments. When a reaction drops the
    center atom of a fragment, the bond attaches to that atom's first
    neighbor instead.

    Args:
        X: [n_nodes, n_building_blocks]
        E: [n_nodes, n_nodes, n_reactions * n_centers * n_centers + 2]
        reindex: Whether to reindex atoms to ascending order starting from 0 (default: True).
            Only atoms that take part in at least one bond receive an index.
        as_onehot_adj_tensor: Whether to return bonds as a one-hot adjacency tensor (default: False)

    Returns:
        Either:
        - List of tuples, each containing (atom1_idx, atom2_idx, bond_type)
        - Tensor of shape [5 * MAX_ATOMS, 5 * MAX_ATOMS, 5] with one-hot bond
          channels [single, double, triple, aromatic, is_masked]. This
          function only writes the first four channels; ``is_masked`` stays 0.

    Raises:
        ValueError: If ``as_onehot_adj_tensor`` is set and a bond has a type
            other than single/double/triple/aromatic. Previously such a bond
            silently reused the previous bond's channel, or failed with an
            unbound local on the first bond.
    """
    # Channel of each supported bond type in the one-hot adjacency tensor
    bond_channels = {
        Chem.BondType.SINGLE: 0,
        Chem.BondType.DOUBLE: 1,
        Chem.BondType.TRIPLE: 2,
        Chem.BondType.AROMATIC: 3,
    }

    # Decode everything into plain Python values, as leaving_atom_mask does
    lam = leaving_atom_mask(X, E)
    flat_mask = lam.reshape(X.shape[0] * MAX_ATOMS, 3)[:, 0].tolist()
    X_indices = onehot_to_node_indices(X).tolist()
    E_indices = onehot_to_reaction_type_and_centers(E).tolist()

    bonds = []

    # Bonds inside each building block, skipping those touching a masked atom
    for frag_pos, building_block_idx in enumerate(X_indices):
        base = frag_pos * MAX_ATOMS
        mol = Chem.MolFromSmiles(idx_to_smiles(building_block_idx))
        for bond in mol.GetBonds():
            begin_idx = bond.GetBeginAtomIdx() + base
            end_idx = bond.GetEndAtomIdx() + base
            if not flat_mask[begin_idx] and not flat_mask[end_idx]:
                bonds.append((begin_idx, end_idx, bond.GetBondType()))

    # One bond per reaction, joining the two fragments
    for reaction_idx, begin_idx, end_idx, center1_idx, center2_idx in E_indices:
        building_block_1 = idx_to_smiles(X_indices[begin_idx])
        building_block_2 = idx_to_smiles(X_indices[end_idx])
        center_1 = BUILDING_BLOCKS[building_block_1][center1_idx]
        center_2 = BUILDING_BLOCKS[building_block_2][center2_idx]
        reaction = idx_to_smarts(reaction_idx)
        r1_atom_dropped = REACTIONS[reaction][0]
        r2_atom_dropped = REACTIONS[reaction][1]

        local_idx1 = center_1
        local_idx2 = center_2
        if r1_atom_dropped:
            # The center atom leaves: attach to its first neighbor instead
            mol1 = Chem.MolFromSmiles(building_block_1)
            local_idx1 = mol1.GetAtomWithIdx(center_1).GetNeighbors()[0].GetIdx()
        if r2_atom_dropped:
            mol2 = Chem.MolFromSmiles(building_block_2)
            local_idx2 = mol2.GetAtomWithIdx(center_2).GetNeighbors()[0].GetIdx()

        # Assuming single bond for reaction bonds
        bonds.append(
            (
                local_idx1 + MAX_ATOMS * begin_idx,
                local_idx2 + MAX_ATOMS * end_idx,
                Chem.BondType.SINGLE,
            )
        )

    if reindex:
        # Map the bonded atoms, in ascending flat-index order, to 0..n-1
        bonded_atoms = sorted({idx for a, b, _ in bonds for idx in (a, b)})
        atom_mapping = {old: new for new, old in enumerate(bonded_atoms)}
        bonds = [(atom_mapping[a], atom_mapping[b], bond_type) for a, b, bond_type in bonds]

    if as_onehot_adj_tensor:
        # Create adjacency tensor (fixed size: 5 fragments' worth of atom slots)
        adj = torch.zeros((5 * MAX_ATOMS, 5 * MAX_ATOMS, 5), device=X.device)
        for a1, a2, bond_type in bonds:
            channel = bond_channels.get(bond_type)
            if channel is None:
                raise ValueError(f"Unsupported bond type for one-hot encoding: {bond_type}")

            # Symmetric entry for the undirected bond
            adj[a1, a2, channel] = 1
            adj[a2, a1, channel] = 1

        return adj

    return bonds


def get_atom_sequence(X: torch.Tensor, E: torch.Tensor) -> list:
    """Collect the RDKit atoms of all fragments that remain in the molecule.

    Fragments are visited in node order and atoms in their order within each
    building block; atoms whose slot is flagged by ``leaving_atom_mask`` are
    left out.

    Returns:
        List of RDKit atom objects taken from the parsed building blocks.
    """
    lam = leaving_atom_mask(X, E).to(X.device)
    # One flag per flat atom slot (all three coordinate entries share the flag)
    slot_is_masked = lam.reshape(X.shape[0] * MAX_ATOMS, 3)[:, 0]

    kept_atoms = []
    base = 0  # flat index of the current fragment's first slot
    for building_block_idx in onehot_to_node_indices(X):
        fragment = Chem.MolFromSmiles(idx_to_smiles(building_block_idx))
        for atom in fragment.GetAtoms():
            if slot_is_masked[atom.GetIdx() + base]:
                continue
            kept_atoms.append(atom)
        base += MAX_ATOMS
    return kept_atoms


def save_as_sdf(mol: Chem.Mol, filename: str):
    """Write a molecule (with its conformer) to an SDF file.

    Args:
        mol: RDKit molecule to write.
        filename: Destination path of the SDF file.
    """
    writer = Chem.SDWriter(filename)
    # Close the writer even if writing fails, so the file handle is released
    try:
        writer.write(mol)
    finally:
        writer.close()


def weighted_rigid_align(
    from_coords: torch.Tensor,
    to_coords: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    mask: Optional[torch.Tensor] = None,
):
    """Compute weighted alignment.

    Finds, per batch element, the rotation and translation that best
    superimpose ``from_coords`` onto ``to_coords`` under the given weights
    (SVD of the weighted cross-covariance, with a determinant correction so a
    proper rotation is returned), and applies them to ``from_coords``.

    Parameters
    ----------
    from_coords: torch.Tensor               # B x N x D
        The coordinates to be aligned
    to_coords: torch.Tensor               # B x N x D
        The target coordinates to align to
    weights: torch.Tensor                   # B x N
        The weights for alignment (defaults to all ones)
    mask: torch.Tensor                      # B x N
        The atoms mask; multiplied into the weights

    Returns
    -------
    torch.Tensor
        Aligned coordinates (B x N x D), detached from the autograd graph.
        A batch element whose weights sum to zero yields NaN.
    """
    batch_size, num_points, dim = from_coords.shape
    underdetermined = num_points < (dim + 1)

    if weights is None:
        weights = torch.ones_like(from_coords[..., 0])
    if mask is not None:
        weights = mask * weights

    # Broadcast weights over the coordinate axis (uses dim rather than a hard-coded 3)
    weights = weights.unsqueeze(-1).expand(-1, -1, dim)
    weight_total = weights.sum(dim=1, keepdim=True)

    # Compute weighted centroids
    to_centroid = (to_coords * weights).sum(dim=1, keepdim=True) / weight_total
    from_centroid = (from_coords * weights).sum(dim=1, keepdim=True) / weight_total

    # Center the coordinates
    to_coords_centered = to_coords - to_centroid
    from_coords_centered = from_coords - from_centroid
    if underdetermined:
        print(
            "Warning: The size of one of the point clouds is <= dim+1. "
            + "`WeightedRigidAlign` cannot return a unique rotation."
        )

    # Compute the weighted covariance matrix
    cov_matrix = einsum(weights * from_coords_centered, to_coords_centered, "b n i, b n j -> b i j")
    # Compute the SVD of the covariance matrix, required float32 for svd and determinant
    original_dtype = cov_matrix.dtype
    cov_matrix_32 = cov_matrix.to(dtype=torch.float32)
    U, S, V = torch.linalg.svd(cov_matrix_32, driver="gesvd" if cov_matrix_32.is_cuda else None)
    # torch.linalg.svd returns V^H; convert back to V
    V = V.mH
    # Catch ambiguous rotation by checking the magnitude of singular values
    if (S.abs() <= 1e-15).any() and not underdetermined:
        print(
            "Warning: Excessively low rank of "
            + "cross-correlation between aligned point clouds. "
            + "`WeightedRigidAlign` cannot return a unique rotation."
        )

    # Compute the rotation matrix
    rot_matrix = torch.einsum("b i j, b k j -> b i k", U, V).to(dtype=torch.float32)
    # Ensure proper rotation matrix with determinant 1:
    # flip the last singular direction when U V^T is a reflection
    F = torch.eye(dim, dtype=cov_matrix_32.dtype, device=cov_matrix.device)[None].repeat(
        batch_size, 1, 1
    )
    F[:, -1, -1] = torch.det(rot_matrix)
    rot_matrix = einsum(U, F, V, "b i j, b j k, b l k -> b i l")
    rot_matrix = rot_matrix.to(dtype=original_dtype)

    # Apply the rotation and translation
    aligned_coords = einsum(from_coords_centered, rot_matrix, "b n i, b i j -> b n j") + to_centroid
    aligned_coords.detach_()

    return aligned_coords


def augment_coordinates(
    coords: torch.Tensor,
    coords_masks: torch.Tensor,
    *,  # Force keyword arguments for clarity
    pharm_coords: Optional[torch.Tensor] = None,
    pharm_masks: Optional[torch.Tensor] = None,
    center: bool = True,
    normalize: bool = False,
    align: bool = True,
    rotate: bool = True,
    translate: bool = True,
    translation_scale: float = 1.0,
    return_second_coords: bool = False,
    second_coords: Optional[torch.Tensor] = None,
    reference_coords: Optional[torch.Tensor] = None,
    conf_dir: str = None,
) -> Union[
    torch.Tensor,
    Tuple[torch.Tensor, torch.Tensor],
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
]:
    """Apply alignment, centering, random rotation and random translation.

    The steps run in that order, each one only if its flag is set. When
    pharmacophore coordinates are given they are appended to the atom
    coordinates along dim 1, transformed together with them, and split off
    again before returning.

    Parameters
    ----------
    coords : torch.Tensor
        Atomic coordinates, batch first, points along dim 1
    coords_masks : torch.Tensor
        Masks for ``coords``; coordinates are multiplied by them after each
        applied step
    pharm_coords : Optional[torch.Tensor]
        Pharmacophore coordinates to transform together with the atoms
    pharm_masks : Optional[torch.Tensor]
        Masks for the pharmacophores (0 = padded). If omitted, an all-zero
        mask is used, i.e. every pharmacophore is treated as padded
    center : bool
        Subtract the masked mean of the atoms (pharmacophores never
        contribute to the mean)
    normalize : bool
        Divide by ``COORDS_STD``; only takes effect when ``center`` is True
    align : bool
        Rigidly align to ``reference_coords`` with ``weighted_rigid_align``;
        ``reference_coords`` must then be provided
    rotate : bool
        Apply one random rotation per batch element
    translate : bool
        Add one random normal offset per batch element
    translation_scale : float
        Scale factor of the random offset
    return_second_coords : bool
        Whether to also return the transformed ``second_coords``
    second_coords : Optional[torch.Tensor]
        Second coordinate set receiving the same centering, normalization,
        rotation and translation (it is not aligned)
    reference_coords : Optional[torch.Tensor]
        Target coordinates for the alignment step
    conf_dir : Optional[str]
        Not used by this function

    Returns
    -------
    Without ``pharm_coords``: the augmented coordinates, or
    ``(coords, second_coords)`` if ``return_second_coords`` is set.
    With ``pharm_coords``: ``(atom_coords, pharm_coords)``, or
    ``(atom_coords, pharm_coords, atom_second, pharm_second)`` if
    ``return_second_coords`` is set and ``second_coords`` was given.
    """
    n_atoms = coords.size(1)
    has_pharm = pharm_coords is not None

    # Mask selecting the points that define the centroid (atoms only)
    centroid_masks = coords_masks
    if has_pharm:
        n_pharms = pharm_coords.size(1)
        coords = torch.cat([coords, pharm_coords], dim=1)
        if pharm_masks is None:
            pharm_masks = torch.zeros(
                coords.size(0), n_pharms,
                device=coords.device, dtype=coords_masks.dtype
            )
        centroid_masks = torch.cat([coords_masks, torch.zeros_like(pharm_masks)], dim=1)
        coords_masks = torch.cat([coords_masks, pharm_masks], dim=1)

    # Per-point mask broadcastable over the coordinate axis
    point_mask = coords_masks[:, :, None]

    if align:
        aligned = weighted_rigid_align(coords, reference_coords, mask=coords_masks)
        coords = aligned * point_mask

    if center:
        centroid_weights = centroid_masks[:, :, None]
        weighted_sum = torch.sum(coords * centroid_weights, dim=1, keepdim=True)
        atom_mean = weighted_sum / torch.sum(centroid_weights, dim=1, keepdim=True)

        coords = (coords - atom_mean) * point_mask
        if second_coords is not None:
            second_coords = second_coords - atom_mean

        if normalize:
            coords = coords / COORDS_STD
            if second_coords is not None:
                second_coords = second_coords / COORDS_STD

    if rotate:
        # The same rotation is applied to both coordinate sets
        coords, second_coords = randomly_rotate(
            coords, return_second_coords=True, second_coords=second_coords
        )
        coords = coords * point_mask
        if second_coords is not None:
            second_coords = second_coords * point_mask

    if translate:
        # One offset per batch element, shared by all of its points
        shift = torch.randn_like(coords[:, 0:1, :]) * translation_scale
        coords = (coords + shift) * point_mask
        if second_coords is not None:
            second_coords = (second_coords + shift) * point_mask

    if not has_pharm:
        if return_second_coords:
            return coords, second_coords
        return coords

    # Split atoms and pharmacophores apart again
    atom_coords = coords[:, :n_atoms, :]
    pharm_coords_out = coords[:, n_atoms:, :]
    if return_second_coords and second_coords is not None:
        atom_second = second_coords[:, :n_atoms, :]
        pharm_second = second_coords[:, n_atoms:, :]
        return atom_coords, pharm_coords_out, atom_second, pharm_second
    return atom_coords, pharm_coords_out


def randomly_rotate(coords, return_second_coords=False, second_coords=None):
    """Rotate each batch element of ``coords`` by its own random rotation.

    Args:
        coords: Batched coordinates; one rotation is drawn per entry of the
            first dimension.
        return_second_coords: If True, return a pair whose second item is the
            rotated ``second_coords`` (or None when it was not given).
        second_coords: Optional second coordinate set rotated with the same
            rotations.

    Returns:
        The rotated coordinates, or ``(rotated, rotated_second)`` when
        ``return_second_coords`` is True.
    """
    rotations = random_rotations(len(coords), coords.dtype, coords.device)
    rotated = torch.einsum("bmd,bds->bms", coords, rotations)

    if not return_second_coords:
        return rotated

    rotated_second = None
    if second_coords is not None:
        rotated_second = torch.einsum("bmd,bds->bms", second_coords, rotations)
    return rotated, rotated_second


def random_rotations(
    n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
) -> torch.Tensor:
    """
    Draw a batch of random 3x3 rotation matrices.

    The matrices are obtained by converting random quaternions
    (see ``random_quaternions``) with ``quaternion_to_matrix``.

    Args:
        n: How many rotation matrices to generate.
        dtype: Data type of the result.
        device: Device of the result. If None, the default device for the
            default tensor type is used.

    Returns:
        Tensor of shape (n, 3, 3) holding the rotation matrices.
    """
    return quaternion_to_matrix(random_quaternions(n, dtype=dtype, device=device))


def random_quaternions(
    n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
) -> torch.Tensor:
    """
    Sample random rotation quaternions, i.e. versors with nonnegative
    real part.

    Four standard-normal components are drawn per quaternion and divided by
    their norm; the norm carries the sign of the real component, which makes
    the resulting real part nonnegative.

    Args:
        n: How many quaternions to generate.
        dtype: Data type of the result.
        device: Device of the result (a string is converted to a
            ``torch.device``). If None, the default device for the default
            tensor type is used.

    Returns:
        Tensor of shape (n, 4) holding the quaternions.
    """
    target_device = torch.device(device) if isinstance(device, str) else device
    samples = torch.randn((n, 4), dtype=dtype, device=target_device)
    norms = torch.sqrt((samples * samples).sum(1))
    # Dividing by a norm signed like the real component flips negatives to positive
    signed_norms = _copysign(norms, samples[:, 0])
    return samples / signed_norms[:, None]


def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Combine the magnitudes of one tensor with the signs of another.

    Each output element has the absolute value of the corresponding element
    of a and the sign of the corresponding element of b. This mirrors the
    standard copysign floating-point operation, except that negative zero and
    NaN get no special treatment.

    Args:
        a: tensor providing the magnitudes.
        b: tensor providing the signs, of the same shape as a.

    Returns:
        Tensor shaped like a whose elements carry the signs of b.
    """
    a_negative = a < 0
    b_negative = b < 0
    # Negate exactly those elements whose sign disagrees with b
    return torch.where(a_negative ^ b_negative, -a, a)


def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Turn quaternions into the rotation matrices they represent.

    The quaternions need not be normalized: every term is scaled by
    2 / |q|^2.

    Args:
        quaternions: tensor of shape (..., 4) with the real part first.

    Returns:
        Tensor of shape (..., 3, 3) holding the rotation matrices.
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    # Build the matrix one row at a time
    row_x = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
        ),
        -1,
    )
    row_y = torch.stack(
        (
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
        ),
        -1,
    )
    row_z = torch.stack(
        (
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return torch.stack((row_x, row_y, row_z), -2)


def center_atom_coords(atom_coords: torch.Tensor, atom_mask: torch.Tensor):
    """Removes the mean of the atom coordinates, effectively centering them.

    The mean is taken over the atoms selected by ``atom_mask`` only, but it is
    subtracted from every atom, masked or not.

    Args:
        atom_coords: Tensor of shape [*, N, 3]
        atom_mask: Tensor of shape [*, N]; non-zero entries mark valid atoms
    Returns:
        Tensor of shape [*, N, 3]. If no atom is valid, the coordinates are
        returned unshifted (previously this produced NaN from a 0/0 division).
    """
    valid = atom_mask.bool().to(atom_coords.device).unsqueeze(-1)

    # Masked mean; the count is clamped to 1 so an all-masked input divides
    # a zero sum by 1 instead of by 0
    coord_sum = torch.sum(atom_coords * valid, dim=-2, keepdim=True)
    n_valid = torch.sum(valid, dim=-2, keepdim=True).clamp(min=1)
    atom_mean = coord_sum / n_valid

    # Subtract the mean from the coordinates to center them
    return atom_coords - atom_mean


def get_batch_reference_coords(
    data: List[Data], conf_dir: str, max_length: int, mask_value: float = 0.0
):
    """Collect the reference conformer coordinates for a batch of molecules.

    For every item the first stored conformer (index 0) is selected with
    ``select_conformer`` and parsed with ``xyz_to_coordinates``; the result is
    flattened and written to the front of that item's row.

    Args:
        data: Batch items; each provides ``data_index``, ``smiles``, ``x`` and
            ``edge_attr``. The output tensors live on the device of
            ``data[0].x``.
        conf_dir: Conformer location passed to ``select_conformer``.
        max_length: Number of fragment slots per row; rows hold
            ``max_length * MAX_ATOMS`` atoms.
        mask_value: Fill value for coordinates without an atom.

    Returns:
        ``(reference_coords, reference_masks)`` with shapes
        [batch, max_length * MAX_ATOMS, 3] and [batch, max_length * MAX_ATOMS]
        (bool).
    """
    batch_size = len(data)
    device = data[0].x.device
    slots_per_row = max_length * MAX_ATOMS

    # Pre-filled outputs: mask_value coordinates, all-False masks
    reference_coords = torch.full((batch_size, slots_per_row, 3), mask_value, device=device)
    reference_masks = torch.zeros((batch_size, slots_per_row), dtype=torch.bool, device=device)

    for row, item in enumerate(data):
        # Always use the 0th conformer as the reference
        reference_conf = select_conformer(item.data_index.item(), conf_dir, random=False)

        n_nodes = item.x.size(0)
        edge_tensor = item.edge_attr.reshape(n_nodes, n_nodes, -1)
        frag_coords, frag_mask = xyz_to_coordinates(
            reference_conf, item.smiles, item.x, edge_tensor, mask_value
        )

        # Flatten the fragment axis and copy into the leading part of the row
        used_slots = frag_coords.shape[0] * MAX_ATOMS
        reference_coords[row, :used_slots] = frag_coords.reshape(used_slots, 3)
        reference_masks[row, :used_slots] = frag_mask.reshape(used_slots)

    return reference_coords, reference_masks


def calc_energy(mol: Chem.Mol, per_atom: bool = False) -> Optional[float]:
    """Calculate the energy for an RDKit molecule using the MMFF forcefield
    from https://github.com/rssrwn/semla-flow
    The energy is only calculated for the first (0th index) conformer within the molecule.
    The molecule is passed to RDKit directly; no copy is made here.

    Args:
        mol (Chem.Mol): RDKit molecule
        per_atom (bool): Whether to normalise by number of atoms in mol, default False

    Returns:
        Optional[float]: Energy of the molecule or None if the energy could not be calculated
    """

    try:
        mmff_props = AllChem.MMFFGetMoleculeProperties(mol, mmffVariant="MMFF94")
        # RDKit signals missing MMFF parameters by returning None; fail with a
        # readable reason instead of a cryptic downstream error.
        if mmff_props is None:
            raise ValueError("MMFF parameters are not available for this molecule")
        ff = AllChem.MMFFGetMoleculeForceField(mol, mmff_props, confId=0)
        if ff is None:
            raise ValueError("MMFF force field could not be set up for this molecule")
        energy = ff.CalcEnergy()
        if per_atom:
            energy = energy / mol.GetNumAtoms()
    except Exception as e:
        # Best-effort: any failure is reported and signalled with None
        print(f"Failed to calculate energy: {e}")
        return None

    return energy


def inter_distances(coords1, coords2, sqrd=False, eps=1e-6):
    """Pairwise Euclidean distances between two point sets.

    Args:
        coords1: Points of shape [N, D] or [B, N, D].
        coords2: Points of shape [M, D] or [B, M, D]; must be batched exactly
            when ``coords1`` is.
        sqrd: If True, return squared distances.
        eps: Unused.

    Returns:
        Distances of shape [N, M] for unbatched input, [B, N, M] otherwise.
    """
    was_unbatched = coords1.ndim == 2
    if was_unbatched:
        # torch.cdist is applied to batched input, so add a leading batch axis
        coords1, coords2 = coords1.unsqueeze(0), coords2.unsqueeze(0)

    pairwise = torch.cdist(coords1, coords2)
    if sqrd:
        pairwise = pairwise.pow(2)

    # Drop the batch axis again if it was added here
    return pairwise.squeeze(0) if was_unbatched else pairwise


def align_and_permute(C0, C1, coords_mask, sqrd=True):
    """
    Align C1 onto C0 with a masked rigid alignment.

    Despite the name, no permutation is applied at the moment: the optimal-assignment
    step is disabled (see the commented-out code below), so this function only
    delegates to ``weighted_rigid_align(C1, C0, mask=coords_mask)``.

    Args:
        C0 (torch.Tensor): Target coordinates of shape [B, n*m, 3].
        C1 (torch.Tensor): Coordinates to be aligned of shape [B, n*m, 3].
        coords_mask (torch.Tensor): Mask of shape [B, n*m].
        sqrd (bool): Only referenced by the disabled assignment step; currently has no effect.

    Returns:
        torch.Tensor: The result of ``weighted_rigid_align`` for C1 against C0
        (presumably the aligned C1 with shape [B, n*m, 3] - confirm against that helper).
    """
    # Disabled permutation step, kept for reference:
    # B, n_atoms, _ = C0.shape
    # # Compute optimal assignment for each batch
    # for b in range(B):
    #     # Only consider valid points
    #     valid_mask = coords_mask[b]
    #     valid_C0 = C0[b][valid_mask]
    #     valid_C1 = C1[b][valid_mask]
    #     cost_matrix = inter_distances(valid_C0, valid_C1, sqrd=sqrd)
    #     row_ind, col_ind = linear_sum_assignment(cost_matrix.cpu().numpy())

    #     # Map the permutation back to the full tensor
    #     valid_indices = torch.where(valid_mask)[0]
    #     permuted_indices = valid_indices[col_ind]  # real atom indices in C1

    #     # Assign permuted atoms back into the original coordinate tensor
    #     C1[b][valid_indices] = C1[b][permuted_indices]

    # Rotation/translation only.
    return weighted_rigid_align(C1, C0, mask=coords_mask)


def bond_length_loss(C0, C1, bonds, sqrd=False):
    """
    Mean discrepancy between the bond lengths of two coordinate sets.

    Bond lengths are obtained from ``get_bond_lengths`` for each coordinate set
    and moved to the device of C0 before being compared.

    Args:
        C0: Reference coordinates tensor of shape [n_fragments, max_atoms_per_fragment, 3]
        C1: Predicted coordinates tensor of shape [n_fragments, max_atoms_per_fragment, 3]
        bonds: List of tuples containing (atom1_idx, atom2_idx, bond_type)
        sqrd: If True, average squared differences instead of absolute differences

    Returns:
        torch.Tensor: Mean bond length loss between C0 and C1
    """
    target_device = C0.device
    reference_lengths = get_bond_lengths(C0, bonds).to(target_device)
    predicted_lengths = get_bond_lengths(C1, bonds).to(target_device)

    delta = reference_lengths - predicted_lengths
    per_bond = delta ** 2 if sqrd else torch.abs(delta)
    return torch.mean(per_bond)


def pairwise_distance_loss(C0, C1, coords_mask, threshold=5.0, sqrd=False):
    """Calculate loss based on pairwise distances between atoms.

    Args:
        C0: Ground truth coordinates tensor of shape [bs, n_fragments, max_atoms_per_fragment, 3]
            (a 3-D tensor is treated as a single sample and gets a batch dimension).
        C1: Predicted coordinates tensor of shape [bs, n_fragments, max_atoms_per_fragment, 3]
        coords_mask: Binary mask indicating valid atoms of shape [bs, n_fragments, max_atoms_per_fragment]
        threshold: Ground-truth distance below which an atom pair contributes to the loss
            (default: 5.0). Always a plain, non-squared distance, also when ``sqrd`` is True.
        sqrd: If True, use squared distances instead of Euclidean distances

    Returns:
        torch.Tensor: Mean absolute difference between predicted and ground-truth pairwise
                     distances, taken over valid atom pairs that are within ``threshold``
                     of each other in the ground truth structure.
                     Returns a tensor of shape [batch_size, 1] containing per-batch losses.
    """

    # Add batch dimension if not present
    if C0.dim() == 3:
        C0 = C0.unsqueeze(0)
        C1 = C1.unsqueeze(0)
        coords_mask = coords_mask.unsqueeze(0)

    C0 = C0.reshape(C0.shape[0], -1, 3)  # [bs, n_atoms, 3]
    C1 = C1.reshape(C1.shape[0], -1, 3)  # [bs, n_atoms, 3]
    valid = coords_mask.reshape(C0.shape[0], -1).bool()  # [bs, n_atoms]

    # Batched outer product of the mask with itself: a pair counts only if
    # both of its atoms are valid.
    pairwise_mask = valid.unsqueeze(-1) & valid.unsqueeze(-2)  # [bs, n_atoms, n_atoms]

    # Pairwise distances (squared when sqrd=True)
    C0_dists = inter_distances(C0, C0, sqrd=sqrd)  # [bs, n_atoms, n_atoms]
    C1_dists = inter_distances(C1, C1, sqrd=sqrd)  # [bs, n_atoms, n_atoms]

    # `threshold` is a plain distance; when the distances are squared the
    # threshold has to be squared too, otherwise the neighbourhood silently
    # shrinks to sqrt(threshold).
    effective_threshold = threshold ** 2 if sqrd else threshold
    close_atoms_mask = (C0_dists < effective_threshold) & pairwise_mask  # [bs, n_atoms, n_atoms]

    # Absolute difference between predicted and ground-truth distances,
    # restricted to the selected pairs.
    masked_loss = torch.abs(C1_dists - C0_dists) * close_atoms_mask

    # NOTE: self-pairs (distance 0) pass the threshold for any positive
    # threshold; they add nothing to `total` but are included in `count`.
    total = masked_loss.sum(dim=(-2, -1))  # [bs]
    count = close_atoms_mask.sum(dim=(-2, -1))  # [bs]

    # The small constant keeps the division finite when no pair is selected.
    return (total / (count + 1e-6)).unsqueeze(-1)  # [bs, 1]


def smooth_lddt_loss(C0, C1, coords_mask, cutoff=15.0, sqrd=False):
    """Calculate loss based on pairwise distances between atoms using smoothLDDT weighting.

    Args:
        C0: Ground truth coordinates tensor of shape [bs, n_fragments, max_atoms_per_fragment, 3]
            (a 3-D tensor is treated as a single sample and gets a batch dimension).
        C1: Predicted coordinates tensor of shape [bs, n_fragments, max_atoms_per_fragment, 3]
        coords_mask: Binary mask indicating valid atoms of shape [bs, n_fragments, max_atoms_per_fragment]
        cutoff: Distance cutoff for considering atom pairs (default: 15.0). Always a plain,
            non-squared distance, also when ``sqrd`` is True.
        sqrd: If True, use squared distances instead of Euclidean distances
    Returns:
        torch.Tensor: Scalar SmoothLDDT loss (1 - mean per-sample score) between valid atoms.
    """
    # Add batch dimension if not present
    if C0.dim() == 3:
        C0 = C0.unsqueeze(0)
        C1 = C1.unsqueeze(0)
        coords_mask = coords_mask.unsqueeze(0)

    C0 = C0.reshape(C0.shape[0], -1, 3)  # [bs, n_atoms, 3]
    C1 = C1.reshape(C1.shape[0], -1, 3)  # [bs, n_atoms, 3]
    valid = coords_mask.reshape(C0.shape[0], -1).bool()  # [bs, n_atoms]

    # Calculate pairwise distances (squared when sqrd=True)
    true_dists = inter_distances(C0, C0, sqrd=sqrd)  # [bs, n_atoms, n_atoms]
    pred_dists = inter_distances(C1, C1, sqrd=sqrd)  # [bs, n_atoms, n_atoms]

    # Calculate distance differences
    dist_diff = torch.abs(true_dists - pred_dists)

    # SmoothLDDT thresholds
    # NOTE(review): with sqrd=True these thresholds are applied to differences
    # of squared distances - confirm that this is intended.
    lddt_thresholds = torch.tensor([0.5, 1.0, 2.0, 4.0], device=C0.device)

    # Soft "within threshold" score per pair, averaged over the thresholds
    eps = lddt_thresholds.reshape(1, 1, 1, -1) - dist_diff.unsqueeze(
        -1
    )  # [bs, n_atoms, n_atoms, n_thresholds]
    eps = torch.sigmoid(eps).mean(dim=-1)  # [bs, n_atoms, n_atoms]

    # Inclusion radius: `cutoff` is a plain distance, so it has to be squared
    # when the distances are squared; otherwise the radius silently shrinks
    # to sqrt(cutoff).
    inclusion_radius = cutoff ** 2 if sqrd else cutoff
    inclusion_mask = true_dists < inclusion_radius

    # Remove self-interactions
    not_self = ~torch.eye(C0.shape[1], dtype=torch.bool, device=C0.device)[None, :, :]

    # Batched outer product of the mask with itself: keep only pairs whose
    # two atoms are both valid.
    pairwise_mask = valid.unsqueeze(-1) & valid.unsqueeze(-2)
    mask = inclusion_mask & not_self & pairwise_mask

    # Calculate masked average; the small constant keeps the division finite
    # when a sample has no selected pair (its score is then 0).
    mask_sum = mask.sum(dim=(-2, -1))
    lddt = (eps * mask).sum(dim=(-2, -1)) / (mask_sum + 1e-6)

    return 1.0 - lddt.mean()


def move_padding_to_end(coords: torch.Tensor, coords_mask: torch.Tensor) -> torch.Tensor:
    """
    Reorder every sample so that its valid atoms come first and its padded atoms last.

    The relative order inside the valid group and inside the padded group is preserved.

    Args:
        coords: Coordinates tensor of shape [batch_size, n_atoms * max_atoms_per_fragment, 3]
        coords_mask: Boolean mask of valid atoms, shape [batch_size, n_atoms * max_atoms_per_fragment]

    Returns:
        torch.Tensor: Coordinates with padding moved to end, same shape as input
    """
    reordered = []
    for sample_idx, sample_coords in enumerate(coords.unbind(0)):
        keep = coords_mask[sample_idx]
        # Valid rows first, then the padded rows, each in their original order.
        reordered.append(torch.cat((sample_coords[keep], sample_coords[~keep]), dim=0))
    return torch.stack(reordered)


def check_batch_coord_means(C: torch.Tensor, atom_mask: torch.Tensor, atol: float = 1e-3) -> bool:
    """
    Tell whether the masked centroid of every molecule in the batch is (approximately) the origin.

    Args:
        C: Coordinates tensor of shape [batch_size, n_atoms, 3]
        atom_mask: Binary mask for valid atoms of shape [batch_size, n_atoms]
        atol: Absolute tolerance for checking if means are 0

    Returns:
        bool: True if all molecules have mean ~0, False otherwise
    """
    n_mols, n_atoms = C.shape[0], C.shape[1]

    # Sum the coordinates of the valid atoms only (masked-out atoms are zeroed).
    per_atom = C.reshape(n_mols, n_atoms, -1)
    masked_sum = (per_atom * atom_mask.unsqueeze(-1)).sum(dim=1)

    # Divide by the number of valid atoms of each molecule to get its centroid.
    valid_counts = atom_mask.sum(dim=1, keepdim=True)
    centroids = masked_sum / valid_counts

    origin = torch.zeros_like(centroids)
    return torch.allclose(centroids, origin, atol=atol)


def sdf_to_coordinates(sdf_path: str) -> torch.Tensor:
    """
    Convert an SDF file to a tensor of heavy-atom coordinates.

    The file is read with ``Chem.MolFromMolFile`` and the coordinates are taken
    from the molecule's default conformer after hydrogens have been removed.

    Args:
        sdf_path: Path to the SDF file

    Returns:
        torch.Tensor: Coordinates tensor of shape [n_atoms, 3]

    Raises:
        ValueError: If RDKit cannot parse a molecule from ``sdf_path``.
    """
    mol = Chem.MolFromMolFile(sdf_path)
    if mol is None:
        # MolFromMolFile signals a parse failure by returning None.
        raise ValueError(f"Could not parse a molecule from {sdf_path!r}")

    # RemoveHs returns a new molecule instead of editing its argument, so the
    # result has to be kept for the call to have any effect.
    mol = Chem.RemoveHs(mol)
    return torch.tensor(mol.GetConformer().GetPositions())


def json_to_pharmacophore(json_path: str, sdf_path: str, length: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load one pharmacophore profile from a JSON file and zero-pad it to a fixed size.

    The JSON document is indexed by ``sdf_path``; the selected entry must provide
    a ``"types"`` list and a ``"pos"`` list. Both are padded with zeros along the
    first dimension to ``length * MAX_ATOMS`` entries.

    Args:
        json_path: Path to the JSON file holding the pharmacophore profiles
        sdf_path: Key of the profile to load
        length: Multiplier for MAX_ATOMS that determines the padded size

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: (pharm_types, pharm_coords) with shapes
        [length * MAX_ATOMS] and [length * MAX_ATOMS, 3]
    """
    # Imported locally because `json` is not among this module's top-level imports.
    import json

    with open(json_path, "r") as f:
        pharm = json.load(f)

    pharm_profile = pharm[sdf_path]
    pharm_types = torch.tensor(pharm_profile["types"])
    pharm_coords = torch.tensor(pharm_profile["pos"])

    padded_len = length * MAX_ATOMS
    # NOTE(review): the padding is created with the default float dtype; if the
    # stored types are integers, the concatenated result depends on torch's
    # dtype promotion in `cat` - confirm the dtype downstream code expects.
    pharm_types = torch.cat([pharm_types, torch.zeros(padded_len - pharm_types.shape[0])])
    pharm_coords = torch.cat([pharm_coords, torch.zeros(padded_len - pharm_coords.shape[0], 3)])

    return pharm_types, pharm_coords


def mol_to_pharm_cond(mol, bs, n_subset, center=True, norm=True):
    """
    Build batched pharmacophore conditioning tensors from a single molecule.

    Pharmacophores are extracted with ``get_pharmacophores(mol, multi_vector=False)``.
    Every batch element receives at most ``n_subset`` of them; when the molecule has
    more, an independent random subset is drawn per batch element.

    Args:
        mol: Molecule with a conformer. NOTE: when ``center`` is True the conformer's
            atom positions are shifted in place so that their mean is at the origin.
        bs: Number of batch elements to generate
        n_subset: Number of pharmacophore slots per batch element
        center: Whether to recenter the molecule before extracting pharmacophores
        norm: Whether to divide the pharmacophore positions by COORDS_STD

    Returns:
        Tuple of three float32 tensors:
            types_padded [bs, n_subset, n_classes]: one-hot types; unused slots are one-hot on the last class
            pos_padded [bs, n_subset, pos_dim]: positions; unused slots are zero
            pharm_padding_mask [bs, n_subset]: 1.0 for filled slots, 0.0 for unused slots
    """
    if center:
        # Shift the conformer (in place) so the atom centroid sits at the origin.
        conformer = mol.GetConformer()
        xyz = np.array([list(conformer.GetAtomPosition(a)) for a in range(mol.GetNumAtoms())])
        shifted = xyz - xyz.mean(axis=0)
        for atom_idx, new_pos in enumerate(shifted):
            conformer.SetAtomPosition(atom_idx, new_pos)

    raw_types, raw_pos, _ = get_pharmacophores(mol, multi_vector=False)
    if norm:
        raw_pos = raw_pos / COORDS_STD

    type_indices = torch.tensor(raw_types, dtype=torch.long)
    positions = torch.tensor(raw_pos, dtype=torch.float32)

    onehot_all = pharm_indices_to_onehot(type_indices)
    n_total = len(type_indices)

    # Output buffers of shape (bs, n_subset, *)
    types_padded = torch.zeros((bs, n_subset, onehot_all.shape[1]), dtype=torch.float32)
    pos_padded = torch.zeros((bs, n_subset, positions.shape[1]), dtype=torch.float32)
    pharm_padding_mask = torch.zeros((bs, n_subset), dtype=torch.float32)

    # Every slot starts out as the padding class (last one-hot channel);
    # filled slots are overwritten below.
    types_padded[:, :, -1] = 1.0

    subsample = n_total > n_subset
    for b in range(bs):
        # A fresh random subset per batch element, or everything if it fits.
        keep = torch.randperm(n_total)[:n_subset] if subsample else torch.arange(n_total)
        n_keep = len(keep)

        types_padded[b, :n_keep] = onehot_all[keep]
        pos_padded[b, :n_keep] = positions[keep]
        pharm_padding_mask[b, :n_keep] = 1.0

    return types_padded, pos_padded, pharm_padding_mask



