import torch
import os
import warnings
from deli import load_json
import copy
import re
import pickle
import numpy as np
from deli import load
from collections import defaultdict
from typing import List, Optional
from torch.utils.data import Dataset
from rdkit.Chem import AllChem, RemoveAllHs
from tqdm import tqdm
from scipy.spatial.transform import Rotation as R
from torch.nn.utils.rnn import pad_sequence


from .complex_dataclasses import Ligand, Protein, Complex, LigandBatch, ProteinBatch, ComplexBatch, BondsBatch
from ..utils.preprocessing import (parse_receptor, read_mols,
                                   extract_receptor_structure_prody, lig_atom_featurizer,
                                   read_molecule, read_sdf_with_multiple_confs,
                                   extract_ligand)
from ..utils.bond_processing import get_rotatable_and_nonrotatable_bonds, split_molecule
from ..utils.transforms import (
    apply_tor_changes_to_pos, get_torsion_angles, find_rigid_alignment, get_bond_properties_for_angles)



def get_ligand_without_randomization(mol_, protein_center=None, remove_hs=True,
                                     num_new_conformations=1, parse_rotbonds=True):
    """
    Build Ligand objects for the first ``num_new_conformations`` conformers of a molecule.

    Only the fields that do not depend on random noise are filled: positions (shifted by
    ``protein_center``), atom features, the centroid as ``final_tr`` and, optionally, the
    rotatable-bond data.

    Parameters:
    ----------
    mol_ : rdkit.Chem.Mol
        The input molecule. It is deep-copied, so the caller's object is not modified.
    protein_center : numpy.ndarray, optional
        Point subtracted from every atom position (default is None, meaning no shift).
    remove_hs : bool, optional
        Whether to remove hydrogen atoms (default is True). Hydrogen removal is first tried
        with sanitization and retried without it if that raises.
    num_new_conformations : int, optional
        Number of conformers to read; conformer ids ``0 .. num_new_conformations - 1`` are
        used (default is 1).
    parse_rotbonds : bool, optional
        Whether to compute rotatable / non-rotatable bond data and fill the related fields,
        plus ``t`` and ``init_rot`` (default is True).

    Returns:
    -------
    dict[str, Ligand]
        Mapping ``'conf{i}'`` -> Ligand for every conformer id ``i``.

    Raises:
    ------
    ValueError
        If the molecule has no conformer and embedding (seed 13) fails to create one.
    """
    mol_maybe_noh = copy.deepcopy(mol_)

    if remove_hs:
        try:
            mol_maybe_noh = RemoveAllHs(mol_maybe_noh, sanitize=True)
        except Exception:
            # Some molecules cannot be sanitized; fall back to a non-sanitizing removal.
            mol_maybe_noh = RemoveAllHs(mol_maybe_noh, sanitize=False)

    # Ensure the molecule has 3D coordinates
    if not mol_maybe_noh.GetNumConformers():
        AllChem.EmbedMolecule(mol_maybe_noh, randomSeed=13)
        if not mol_maybe_noh.GetNumConformers():
            raise ValueError("Embedding of the molecule failed, unable to generate conformer.")

    if protein_center is None:
        # The documented default: keep the molecule's own coordinate frame.
        protein_center = np.zeros(3, dtype=np.float32)

    if parse_rotbonds:
        rotatable_bonds_ext, non_rotatable_bonds_ext, rotatable_bonds, mask_rotate_before_fixing, mask_rotate_after_fixing, bond_periods = get_rotatable_and_nonrotatable_bonds(mol_maybe_noh)
        print('bond_periods', np.round(bond_periods, 2))
        if len(rotatable_bonds) == 0:
            warnings.warn("No rotatable bonds found, but still using the molecule.")

    ligands = {}
    for conf_id in range(num_new_conformations):
        ligand = Ligand()
        ligand.pos = mol_maybe_noh.GetConformer(conf_id).GetPositions().astype(np.float32) - protein_center

        ligand.orig_mol = mol_maybe_noh  # original mol
        ligand.x = lig_atom_featurizer(mol_maybe_noh)  # features are conformer-invariant
        # Ground-truth translation target: the ligand centroid in the shifted frame.
        ligand.final_tr = ligand.pos.mean(0).astype(np.float32).reshape(1, 3)

        if parse_rotbonds:
            # Fill ligand properties; the extended bond info is deep-copied per conformer.
            ligand.rotatable_bonds_ext = copy.deepcopy(rotatable_bonds_ext)
            ligand.non_rotatable_bonds_ext = non_rotatable_bonds_ext
            if len(rotatable_bonds) > 0:
                ligand.rotatable_bonds = rotatable_bonds
                ligand.mask_rotate = mask_rotate_after_fixing
                ligand.mask_rotate_before_fixing = mask_rotate_before_fixing
                ligand.bond_periods = bond_periods
                ligand.init_tor = np.zeros(ligand.rotatable_bonds.shape[0], dtype=np.float32)
                assert ligand.rotatable_bonds.shape[0] == ligand.rotatable_bonds_ext.start.shape[0]
            else:
                ligand.rotatable_bonds = np.array([], dtype=np.int32)
                ligand.mask_rotate = np.array([], dtype=np.int32)
                ligand.mask_rotate_before_fixing = np.array([], dtype=np.int32)
                ligand.init_tor = np.array([], dtype=np.float32)

            ligand.t = None
            ligand.init_rot = np.eye(3, dtype=np.float32).reshape(1, 3, 3)

        ligands[f'conf{conf_id}'] = ligand
    return ligands


def add_ligand_noise(ligand: Ligand, max_attempts: int = 50):
    """
    Perturb the ligand pose in place with a small rigid rotation and small torsion changes.

    The rotation is applied about the ligand centroid, so the centroid is preserved. Both the
    Euler angles and the torsion updates are drawn from N(0, angle_std), with
    angle_std = pi / 60 (3 degrees) on the first attempt. The std is halved on every retry
    until the RMSD to the unperturbed positions is below 1 (presumably Angstrom - confirm).

    Parameters:
    ----------
    ligand : Ligand
        Reads ``pos``, ``rotatable_bonds`` and ``mask_rotate``; writes ``pos`` and ``rmsd``.
    max_attempts : int, optional
        Upper bound on the number of attempts (default 50). The noise shrinks geometrically, so
        finite inputs succeed long before this bound. If every attempt fails (e.g. ``pos``
        contains NaNs), the positions are left unchanged and ``rmsd`` is set to zero instead of
        looping forever.

    Returns:
    -------
    None
    """
    orig_pos = np.copy(ligand.pos)
    num_rotatable_bonds = ligand.rotatable_bonds.shape[0]
    # Rotate about the centroid so the rigid noise does not translate the ligand.
    pos_mean = orig_pos.mean(axis=0).reshape(1, 3)

    decrease_factor = 1.
    for _ in range(max_attempts):
        angle_std = np.pi * decrease_factor / 60  # 3 degrees on the first attempt
        angles = np.random.normal(0, angle_std, 3).astype(np.float32)
        # Compose per-axis rotations into a rotation matrix close to identity.
        Rx = R.from_euler('x', angles[0]).as_matrix()
        Ry = R.from_euler('y', angles[1]).as_matrix()
        Rz = R.from_euler('z', angles[2]).as_matrix()
        rot = (Rx @ Ry @ Rz).astype(np.float32)

        pos = (orig_pos - pos_mean) @ rot.T + pos_mean
        if num_rotatable_bonds > 0:
            torsion_updates = np.random.normal(0, angle_std, num_rotatable_bonds).astype(np.float32)
            pos, _ = apply_tor_changes_to_pos(pos, ligand.rotatable_bonds, ligand.mask_rotate,
                                              torsion_updates, is_reverse_order=True)

        rmsd = np.sqrt(((pos - orig_pos) ** 2).sum(axis=1).sum(axis=0) / orig_pos.shape[0])
        if rmsd < 1:
            break
        decrease_factor /= 2
    else:
        # No acceptable perturbation found (only possible for non-finite input): keep the pose.
        pos = orig_pos
        rmsd = 0.

    ligand.rmsd = torch.tensor([rmsd]).float()
    ligand.pos = pos


# # small rot tor noise
def randomize_ligand_3stages(ligand: Ligand, tr_std: float, tr_mean: Optional[np.ndarray] = None,
                             stage_num: Optional[int] = None):
    """
    Randomize the position, rotation, and torsion of a ligand according to a noise stage.

    * Stage 1: translation ~ N(tr_mean, tr_std), uniformly random rotation, torsion updates
      uniform in [-bond_periods / 2, bond_periods / 2).
    * Stage 2: translation = ``ligand.final_tr`` + N(0, 2) jitter, uniformly random rotation,
      uniform torsion updates as in stage 1.
    * Any other stage (3): translation = ``ligand.final_tr`` + N(0, 1) jitter; Euler angles
      and torsion updates ~ N(0, angle_std), where angle_std is drawn between 15 and
      150 degrees.

    Parameters:
    ----------
    ligand : Ligand
        The input ligand to be randomized (modified in place).
    tr_std : float
        The standard deviation for the stage-1 translation.
    tr_mean : numpy.ndarray, optional
        The mean for the stage-1 translation (default is None, meaning 0).
    stage_num : int, optional
        Noise stage. If None (default), one of 1, 2, 3 is drawn uniformly.

    Returns:
    -------
    None. Writes ``pos``, ``init_tr``, ``final_rot`` (inverse of the applied rotation),
    ``final_tor`` (negated torsion updates), ``t``, ``stage_num`` and, if unset, ``rmsd``.
    """
    pos = np.copy(ligand.pos)

    if stage_num is None:
        stage_num = np.random.randint(1, 4)

    # Tr:
    if tr_mean is None:
        tr_mean = 0.

    if stage_num == 1:
        tr = np.random.normal(tr_mean, tr_std, 3).astype(np.float32).reshape(1, 3)
    elif stage_num == 2:
        tr_shift = np.random.normal(0, 2., 3).astype(np.float32).reshape(1, 3)
        tr = ligand.final_tr.reshape(1, 3) + tr_shift
    else:
        tr_shift = np.random.normal(0, 1., 3).astype(np.float32).reshape(1, 3)
        tr = ligand.final_tr.reshape(1, 3) + tr_shift

    # Rot:
    if stage_num == 3:
        noise_scale = 5.
        small_noise = np.random.rand()
        # small_noise in [0, 1) maps angle_std into [5*pi/60, 5*pi/6), i.e. 15 to 150 degrees.
        angle_std = noise_scale * np.pi / (-54 * small_noise + 60)
        angles = np.random.normal(0, angle_std, 3).astype(np.float32)
        Rx = R.from_euler('x', angles[0]).as_matrix()
        Ry = R.from_euler('y', angles[1]).as_matrix()
        Rz = R.from_euler('z', angles[2]).as_matrix()
        rot = (Rx @ Ry @ Rz).astype(np.float32)
    else:
        rot = R.random().as_matrix().astype(np.float32)

    # Rotate about the centroid, then move the centroid to tr.
    pos_mean = pos.mean(axis=0).reshape(1, 3)
    pos = (pos - pos_mean) @ rot.T + tr.reshape(1, 3)

    # Tor:
    num_rotatable_bonds = ligand.rotatable_bonds.shape[0]
    if num_rotatable_bonds > 0:
        if stage_num == 3:
            torsion_updates = np.random.normal(0, angle_std, num_rotatable_bonds).astype(np.float32)
        else:
            # Cast to float32 so final_tor has the same dtype on every branch.
            torsion_updates = np.asarray(
                np.random.uniform(-ligand.bond_periods / 2, ligand.bond_periods / 2),
                dtype=np.float32)
        pos, _ = apply_tor_changes_to_pos(pos, ligand.rotatable_bonds, ligand.mask_rotate,
                                          torsion_updates, is_reverse_order=True)
    else:
        torsion_updates = np.empty(0).astype(np.float32)

    ligand.init_tr = tr.reshape(1, 3)
    ligand.final_rot = (rot.T)[None, :, :]
    ligand.final_tor = -torsion_updates
    ligand.pos = np.copy(pos)

    # Time is randomized from Uniform[0, 1]:
    ligand.t = torch.rand(1)
    if ligand.rmsd is None:
        ligand.rmsd = torch.zeros(1)
    ligand.stage_num = torch.tensor([stage_num])


def randomize_ligand_with_preds(ligand: Ligand):
    """
    Build a noisy ligand input centred on the predicted translation ``ligand.pred_tr``.

    The reference pose ``ligand.orig_pos`` is centred, rotated by a uniformly random rotation
    and moved to ``pred_tr``. Its torsions are then changed by amounts drawn uniformly from
    [-bond_periods / 2, bond_periods / 2).

    Parameters:
    ----------
    ligand : Ligand
        The input ligand (modified in place). Reads ``orig_pos``, ``pred_tr``,
        ``rotatable_bonds``, ``mask_rotate`` and ``bond_periods``.

    Returns:
    -------
    None. Writes ``pos``, ``init_tr``, ``final_rot`` (inverse of the applied rotation),
    ``final_tor`` (negated torsion updates), ``t``, ``stage_num`` (always 2) and, if unset,
    ``rmsd``.
    """
    pos = np.copy(ligand.orig_pos)

    # Tr:
    tr = ligand.pred_tr.reshape(1, 3)

    # Rot:
    rot = R.random().as_matrix().astype(np.float32)

    # Rotate about the centroid, then move the centroid to the predicted translation.
    pos = (pos - pos.mean(axis=0).reshape(1, 3)) @ rot.T + tr.reshape(1, 3)

    # Tor (only touched when there is something to rotate, as in randomize_ligand_3stages):
    num_rotatable_bonds = ligand.rotatable_bonds.shape[0]
    if num_rotatable_bonds > 0:
        torsion_updates = np.asarray(
            np.random.uniform(-ligand.bond_periods / 2, ligand.bond_periods / 2),
            dtype=np.float32)
        pos, _ = apply_tor_changes_to_pos(pos, ligand.rotatable_bonds, ligand.mask_rotate,
                                          torsion_updates, is_reverse_order=True)
    else:
        torsion_updates = np.empty(0).astype(np.float32)

    ligand.init_tr = tr.reshape(1, 3)
    ligand.final_rot = rot.T[None, :, :]
    ligand.final_tor = -torsion_updates
    ligand.pos = np.copy(pos)

    # Time is randomized from Uniform[0, 1]:
    ligand.t = torch.rand(1)
    if ligand.rmsd is None:
        ligand.rmsd = torch.zeros(1)
    ligand.stage_num = torch.tensor([2])


def set_ligand_data_from_preds(ligand: Ligand, name: str):
    """
    Use a model-predicted pose as the ligand input and derive correction values from it.

    The torsion angles of the predicted pose (``ligand.predicted_pos``) and the reference pose
    (``ligand.orig_pos``) are measured with every bond period overridden to 2*pi. Their
    difference is stored in ``final_tor``. That difference is then applied to the prediction,
    and the rigid alignment of the result onto the reference pose gives ``final_rot``.

    Parameters:
    ----------
    ligand : Ligand
        The ligand to fill (modified in place). Reads ``orig_pos``, ``predicted_pos``,
        ``pred_tr``, ``rotatable_bonds``, ``rotatable_bonds_ext`` and
        ``mask_rotate_before_fixing``.
    name : str
        Complex name, used only in diagnostic messages.

    Returns:
    -------
    None. Writes ``pos`` (the predicted pose), ``init_tr``, ``final_rot``, ``final_tor``,
    ``t``, ``stage_num`` (always 3) and, if unset, ``rmsd``.

    Raises:
    ------
    ValueError
        If the predicted pose has fewer atoms than the reference pose.
    """
    true_pos = np.copy(ligand.orig_pos)
    pred_pos = np.copy(ligand.predicted_pos)

    if len(pred_pos) < len(true_pos):
        raise ValueError(f'Predicted pose of {name} has fewer atoms ({len(pred_pos)}) '
                         f'than the reference pose ({len(true_pos)}).')
    if len(pred_pos) > len(true_pos):
        # NOTE(review): assumes the surplus atoms are trailing hydrogens - confirm.
        print('Cut H from pred_pos', name, len(pred_pos), len(true_pos))
        pred_pos = pred_pos[:len(true_pos)]

    # Tor:
    num_rotatable_bonds = ligand.rotatable_bonds.shape[0]
    if num_rotatable_bonds > 0:
        bond_properties_for_angles = get_bond_properties_for_angles(ligand.rotatable_bonds_ext)
        true_bond_periods = bond_properties_for_angles['bond_periods']
        # Override every bond period with 2*pi before measuring angles.
        bond_properties_for_angles['bond_periods'] = np.ones_like(true_bond_periods) * 2 * np.pi
        angles_true = get_torsion_angles(np.copy(true_pos), bond_atoms_for_angles=bond_properties_for_angles)
        angles_pred = get_torsion_angles(np.copy(pred_pos), bond_atoms_for_angles=bond_properties_for_angles)
        torsion_updates = angles_pred - angles_true

        pos_new, _ = apply_tor_changes_to_pos(np.copy(pred_pos), ligand.rotatable_bonds, ligand.mask_rotate_before_fixing,
                                              torsion_updates, is_reverse_order=True)
    else:
        torsion_updates = np.empty(0).astype(np.float32)
        pos_new = np.copy(pred_pos)

    # Only the rotation of the rigid alignment is used; the translation comes from pred_tr.
    rot_align, _ = find_rigid_alignment(pos_new, true_pos)

    ligand.init_tr = ligand.pred_tr.reshape(1, 3)
    ligand.final_rot = rot_align[None, :, :]
    ligand.final_tor = torsion_updates
    ligand.pos = pred_pos

    # Time is randomized from Uniform[0, 1]:
    ligand.t = torch.rand(1)
    if ligand.rmsd is None:
        ligand.rmsd = torch.zeros(1)
    ligand.stage_num = torch.tensor([3])


def randomize_complex(complex: Complex, augm_ligand_transforms: bool, std_protein_pos: float,
                      std_lig_pos: float, ligand_mask_ratio: float, protein_mask_ratio: float,
                      tr_mean: float, tr_std: float, use_pred_ligand_transforms: bool=False,
                      use_predicted_tr_only: bool=True, randomize_bond_neighbors: bool=True,
                      stage_num: Optional[int] = None):
    """
    Apply the full randomization pipeline to a complex, in place.

    Steps:
    1. global random rotation;
    2. optional small ligand pose noise, then Gaussian coordinate noise on protein and ligand;
    3. shift to the protein centre;
    4. ground-truth computation;
    5. random masking;
    6. sampling of rotatable-bond neighbours (plus their torsion angles);
    7. construction of the noisy ligand input.

    Parameters:
    ----------
    complex : Complex
        The complex to randomize (mutated).
    augm_ligand_transforms : bool
        If True, perturb the ligand pose with ``add_ligand_noise`` before coordinate noise.
    std_protein_pos, std_lig_pos : float
        Std of the zero-mean Gaussian noise added to protein / ligand positions.
    ligand_mask_ratio, protein_mask_ratio : float
        Forwarded to ``complex.randomly_mask_complex``.
    tr_mean, tr_std : float
        Forwarded to ``randomize_ligand_3stages``.
    use_pred_ligand_transforms : bool, optional
        Build the ligand input from predictions instead of ``randomize_ligand_3stages``.
    use_predicted_tr_only : bool, optional
        With predictions: only use the predicted translation (``randomize_ligand_with_preds``)
        rather than the full predicted pose (``set_ligand_data_from_preds``).
    randomize_bond_neighbors : bool, optional
        Sample random neighbours of rotatable bonds instead of the first neighbour.
    stage_num : int, optional
        Forwarded to ``randomize_ligand_3stages`` (None means random stage).

    Returns:
    -------
    Complex
        The same object that was passed in.
    """
    # 1. Rotate the whole complex
    apply_random_rotation_inplace(complex)

    # 2. Optionally perturb the ligand pose, then add per-atom Gaussian coordinate noise
    if augm_ligand_transforms:
        add_ligand_noise(complex.ligand)

    complex.protein.pos += np.random.normal(0, std_protein_pos, complex.protein.pos.shape)
    complex.ligand.pos += np.random.normal(0, std_lig_pos, complex.ligand.pos.shape)

    # 3. Shift to the protein center
    complex.shift_to_protein_center()

    # 4. Compute ligand ground-truth values
    complex.set_ground_truth_values()

    # 5. Randomly mask protein and ligand atoms
    complex.randomly_mask_complex(ligand_mask_ratio=ligand_mask_ratio, protein_mask_ratio=protein_mask_ratio)
    # (Chain swapping via complex.protein.randomly_swap_chains() is currently disabled.)

    # 6. Sample neighbours of rotatable bonds and record their current torsion angles
    if randomize_bond_neighbors:
        complex.ligand.randomly_sample_neighbors_of_rotatable_bonds()
    else:
        complex.ligand.sample_first_neighbor_of_rotatable_bonds()
    if complex.ligand.rotatable_bonds_ext.start.shape[0] > 0:
        bond_properties_for_angles = get_bond_properties_for_angles(complex.ligand.rotatable_bonds_ext)
        complex.ligand.rotatable_bonds_ext.angles = get_torsion_angles(np.copy(complex.ligand.pos),
                                                                       bond_atoms_for_angles=bond_properties_for_angles)

    # 7. Randomize ligand for NN input
    if use_pred_ligand_transforms:
        if use_predicted_tr_only:
            randomize_ligand_with_preds(complex.ligand)
        else:
            set_ligand_data_from_preds(complex.ligand, complex.name)
    else:
        randomize_ligand_3stages(complex.ligand, tr_mean=tr_mean, tr_std=tr_std, stage_num=stage_num)
    return complex


class PDBBind(Dataset):
    def __init__(self, data_dir, split_path, esm_embeddings_path, sequences_path, max_lig_size,
                 tr_std=1., tr_mean=None, augm_ligand_transforms=False,
                 no_cache=False, cache_path='data/cache', limit_complexes=0, num_dataset_workers=1,
                 remove_hs=True, std_protein_pos=0.1, std_lig_pos=0.1,
                 num_new_conformations=0, ligand_mask_ratio=0., protein_mask_ratio=0.,
                 predicted_ligand_transforms_path=None, dataset_type='pdbbind',
                 chain_mapping_path=None, inverse_crop_ids_path=None, add_all_atom_pos=False,
                 use_predicted_tr_only=True, randomize_bond_neighbors=True,
                 data_dir_conf=None, is_train_dataset=True,
                 n_preds_to_use=1, use_all_chains=False,
                 min_lig_size=7, stage_num=None):
        """
        Protein-ligand complex dataset backed by a pickle cache.

        Complexes listed in ``split_path`` are loaded from the derived cache folder when it
        exists and ``no_cache`` is False; otherwise they are preprocessed and the cache is
        written. ``__getitem__`` returns a randomized deep copy of a stored complex.

        Notable parameters (the rest are stored as attributes and forwarded to the
        preprocessing / randomization code):
        split_path : str or tuple
            Text file with one complex name per line. For a tuple only the first element is used.
        dataset_type : str
            'moad' enables per-protein cluster sampling. Types ending in '_conf' share the
            '_conf0' protein across all conformers of a ligand.
        is_train_dataset : bool
            If True, drop complexes whose ligand has fewer than ``min_lig_size`` or at least
            150 atoms, or whose protein + ligand position count is at least 2000.
        min_lig_size : int
            Minimum ligand atom count kept for training (None falls back to 7).
        n_preds_to_use : int
            Copies (or conformers, for '_conf' datasets) kept per complex.
        predicted_ligand_transforms_path : str, optional
            If given, complexes are expanded with model-predicted ligand transforms.
        """
        self.data_dir = data_dir
        self.data_dir_conf = data_dir_conf
        self.limit_complexes = limit_complexes
        self.esm_embeddings_path = esm_embeddings_path
        self.sequences_path = sequences_path
        self.cache_path = cache_path
        self.max_lig_size = max_lig_size
        self.no_cache = no_cache
        self.num_dataset_workers = num_dataset_workers
        self.num_new_conformations = num_new_conformations
        self.dataset_type = dataset_type
        self.tr_std = tr_std
        self.tr_mean = tr_mean
        self.remove_hs = remove_hs
        self.std_protein_pos = std_protein_pos
        self.std_lig_pos = std_lig_pos
        self.augm_ligand_transforms = augm_ligand_transforms
        self.ligand_mask_ratio = ligand_mask_ratio
        self.protein_mask_ratio = protein_mask_ratio
        self.chain_mapping_path = chain_mapping_path
        self.inverse_crop_ids_path = inverse_crop_ids_path
        self.use_pred_ligand_transforms = predicted_ligand_transforms_path is not None
        self.add_all_atom_pos = add_all_atom_pos
        self.use_predicted_tr_only = use_predicted_tr_only
        self.randomize_bond_neighbors = randomize_bond_neighbors
        self.is_train_dataset = is_train_dataset
        self.use_all_chains = use_all_chains
        self.min_lig_size = min_lig_size
        self.stage_num = stage_num

        # A (split, filtered_split) tuple is accepted; the second element is currently unused.
        self.split_path = split_path[0] if isinstance(split_path, tuple) else split_path

        self.full_cache_path = self._get_cache_folder_path()

        # TODO keep 0 for padding
        aa_list = ['-', 'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q',
                   'R', 'S', 'T', 'V', 'W', 'Y']
        self.aa_mapping = {aa: i for i, aa in enumerate(aa_list)}

        # loads data to self.complexes list:
        if not self.no_cache and os.path.exists(self.full_cache_path):
            self._load_from_cache()
        else:
            os.makedirs(self.full_cache_path, exist_ok=True)
            self._preprocess_and_save_to_cache()

        if self.dataset_type.endswith('_conf'):
            self._set_all_conformer_proteins()

        # save orig_pos_before_augm for each ligand
        complexes = []
        for complex in self.complexes:
            complex.ligand.true_pos = np.copy(complex.ligand.pos)

            complex.ligand.orig_pos_before_augm = np.copy(complex.ligand.pos)
            complex.protein.full_pos = np.copy(complex.protein.pos)
            complexes.append(complex)
        self.complexes = complexes

        # Initialize chain_ids and aa_ids for proteins
        for complex in self.complexes:
            complex.protein.init_chain_ids_and_aa_ids()

        if self.is_train_dataset:
            min_lig_size = 7 if self.min_lig_size is None else self.min_lig_size
            print('before', len(self.complexes))
            self.complexes = [compl for compl in self.complexes
                              if min_lig_size <= compl.ligand.pos.shape[0] < 150
                              and compl.protein.pos.shape[0] + compl.ligand.pos.shape[0] < 2000]
            print('after', len(self.complexes))

        if self.dataset_type == 'moad':
            self.name2index = {complex.name: i for i, complex in enumerate(self.complexes)}
            complex_names_all = set(self.name2index.keys())

            # group complexes by proteins
            self.cluster_to_ligands = defaultdict(list)
            for name in complex_names_all:
                self.cluster_to_ligands[name.split('_superlig')[0]].append(name)

            # remove complexes with cucurbituril (6f7w_1)
            self.cluster_to_ligands.pop('6f7w_1', None)

            self.cluster_to_ligands = list(self.cluster_to_ligands.values())

        if self.dataset_type.endswith('_conf'):
            self._explode_ligand_conformers(n_preds_to_use)
        else:
            self.complexes = self.complexes * n_preds_to_use

        if self.use_pred_ligand_transforms:
            # NOTE(review): the list above was already repeated n_preds_to_use times, and
            # _set_predicted_ligand_transforms expands every entry again per prediction, so a
            # (complex, prediction) pair can appear up to n_preds_to_use times. Confirm intended.
            self._set_predicted_ligand_transforms(predicted_ligand_transforms_path, n_preds_to_use)


    def _explode_ligand_conformers(self, n_preds_to_use):

        name2complexes = defaultdict(list)
        for complex in self.complexes:
            name2complexes[complex.name.split('_conf')[0]].append(complex)

        new_complexes = []
        for name, conformers in name2complexes.items():
            while len(conformers) < n_preds_to_use:
                new_conformers = copy.deepcopy(conformers)
                for i, conformer in enumerate(new_conformers):
                    conformer.name = conformer.name.split('_conf')[0] + f'_conf{len(conformers)+i}'
                conformers = conformers + new_conformers
            if len(conformers) > n_preds_to_use:
                conformers = conformers[:n_preds_to_use]
            new_complexes.extend(conformers)

        self.complexes = new_complexes

    def _set_all_conformer_proteins(self):
        name2protein = {}
        for complex in self.complexes:
            if complex.name.endswith('_conf0'):
                name2protein[complex.name.split('_conf')[0]] = complex.protein # copy.deepcopy(complex.protein)
        
        new_complexes = []
        for complex in self.complexes:
            if not complex.name.endswith('_conf0'):
                complex.protein = name2protein[complex.name.split('_conf')[0]] #copy.deepcopy(name2protein[complex.name.split('_conf')[0]])
            new_complexes.append(complex)
        self.complexes = new_complexes

    # def _set_predicted_ligand_transforms(self, predicted_ligand_transforms_path, n_preds_to_use):
    #     self.predicted_ligand_transforms = np.load(predicted_ligand_transforms_path, allow_pickle=True)[0]
    #     self.n_repeats = 1
    #     self.n_preds_to_use = min(n_preds_to_use, len(self.predicted_ligand_transforms[self.complexes[0].name]))
    #     self.complexes = [complex for complex in self.complexes if complex.name in self.predicted_ligand_transforms]

    #     # initialize extended complexes
    #     extended_complexes = []
    #     for complex in tqdm(self.complexes, desc='Setting predicted ligand transforms...'):
    #         for i in range(self.n_preds_to_use):
    #             extended_complex = copy.deepcopy(complex)
    #             pred_data = self.predicted_ligand_transforms[complex.name][i]
    #             extended_complex.ligand.pred_tr = pred_data['tr_pred_init'] + pred_data['full_protein_center'] - extended_complex.protein.full_protein_center
    #             if not self.use_predicted_tr_only:
    #                 extended_complex.ligand.predicted_pos = pred_data['transformed_orig']
    #             extended_complexes.append(extended_complex)
    #     self.complexes = extended_complexes

    
    def _set_predicted_ligand_transforms(self, predicted_ligand_transforms_path, n_preds_to_use):
        """
        Replace ``self.complexes`` with one deep copy per (complex, prediction) pair.

        The file is read with ``np.load(..., allow_pickle=True)[0]`` and indexed by complex
        name, then by prediction index. Each prediction provides 'tr_pred_init' and
        'full_protein_center'; 'transformed_orig' is read only when ``use_predicted_tr_only``
        is False. Predicted values are re-expressed relative to each complex's
        ``protein.full_protein_center``. Complexes without predictions are dropped.

        ``self.n_preds_to_use`` is ``min(n_preds_to_use, <#preds of the first kept complex>)``.
        A complex with fewer predictions contributes only the ones it has, instead of raising
        ``IndexError``.
        """
        # NOTE: allow_pickle=True unpickles arbitrary objects; only load trusted files.
        self.predicted_ligand_transforms = np.load(predicted_ligand_transforms_path, allow_pickle=True)[0]
        self.n_repeats = 1
        # Filter first so the reference complex below is guaranteed to have predictions.
        self.complexes = [complex for complex in self.complexes if complex.name in self.predicted_ligand_transforms]
        if not self.complexes:
            warnings.warn(f'No complexes have predicted ligand transforms in {predicted_ligand_transforms_path}.')
            self.n_preds_to_use = 0
            return
        self.n_preds_to_use = min(n_preds_to_use, len(self.predicted_ligand_transforms[self.complexes[0].name]))

        # initialize extended complexes
        extended_complexes = []
        for complex in tqdm(self.complexes, desc='Setting predicted ligand transforms...'):
            complex_preds = self.predicted_ligand_transforms[complex.name]
            for i in range(min(self.n_preds_to_use, len(complex_preds))):
                extended_complex = copy.deepcopy(complex)
                pred_data = complex_preds[i]
                extended_complex.ligand.pred_tr = pred_data['tr_pred_init'] + pred_data['full_protein_center'] - extended_complex.protein.full_protein_center

                if not self.use_predicted_tr_only:
                    extended_complex.ligand.predicted_pos = pred_data['transformed_orig'] + pred_data['full_protein_center'] - extended_complex.protein.full_protein_center

                extended_complexes.append(extended_complex)
        self.complexes = extended_complexes


    def __len__(self):
        if self.dataset_type == 'moad':
            return len(self.cluster_to_ligands)
        return len(self.complexes)
    

    def __get_nonrand_item__(self, idx):
        if self.dataset_type == 'moad':
            complex_name = np.random.choice(self.cluster_to_ligands[idx])
            complex_idx = self.name2index[complex_name]
        else:
            complex_idx = idx

        complex = copy.deepcopy(self.complexes[complex_idx])
        return complex


    def __getitem__(self, idx):
        """Return a randomized (rotated, noised, masked) deep copy of sample ``idx``."""
        sample = self.__get_nonrand_item__(idx)
        randomization_kwargs = dict(
            augm_ligand_transforms=self.augm_ligand_transforms,
            std_protein_pos=self.std_protein_pos,
            std_lig_pos=self.std_lig_pos,
            ligand_mask_ratio=self.ligand_mask_ratio,
            protein_mask_ratio=self.protein_mask_ratio,
            tr_mean=self.tr_mean,
            tr_std=self.tr_std,
            use_pred_ligand_transforms=self.use_pred_ligand_transforms,
            use_predicted_tr_only=self.use_predicted_tr_only,
            randomize_bond_neighbors=self.randomize_bond_neighbors,
            stage_num=self.stage_num,
        )
        return randomize_complex(complex=sample, **randomization_kwargs)


    def _get_cache_folder_path(self):
        values_for_cache_path = [self.dataset_type, self.limit_complexes, self.max_lig_size,
                                 os.path.basename(self.esm_embeddings_path),
                                 os.path.basename(self.split_path)]
        str_for_cache_path = map(str, values_for_cache_path)
        args_str = '_'.join(str_for_cache_path)
        # replace any unsafe characters:
        pattern = r'[^A-Za-z0-9\-_]'
        safe_args_str = re.sub(pattern, '_', args_str)
        if self.num_new_conformations > 0:
            safe_args_str = f'conf{self.num_new_conformations}_' + safe_args_str
        if self.use_all_chains:
            safe_args_str = f'allchains_' + safe_args_str
        cache_folder_path = os.path.join(self.cache_path, safe_args_str)
        return cache_folder_path


    def _load_embeddings(self, embeddings_path, sequences_path, complex_names, protein_to_complex_names=None):
        """
        Load per-chain language-model embeddings and sequences for the requested complexes.

        Embedding keys are expected to end in two '_'-separated chain suffixes (e.g.
        '<name>_chain_<i>'); stripping them gives the complex / protein name. When
        ``chain_mapping_path`` is set, each embedding is reused for every original chain key
        mapped onto it. When ``inverse_crop_ids_path`` lists residue indices for a name, those
        positions are removed from that chain's embedding and sequences.

        Returns
        -------
        For ``dataset_type == 'moad'``: three ``defaultdict(list)`` keyed by name (embeddings,
        raw sequences, tokenized sequences). Otherwise: three lists aligned with
        ``complex_names``, each entry being that complex's per-chain list (empty if none).

        Raises
        ------
        ValueError
            If any input file is missing or cannot be loaded.
        """
        try:
            # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
            id_to_embeddings = torch.load(embeddings_path, weights_only=False)
            id_to_sequence = load(sequences_path)
            chain_mapping = None
            inverse_chain_mapping = None
            if self.chain_mapping_path is not None:
                chain_mapping = load(self.chain_mapping_path)
                inverse_chain_mapping = defaultdict(list)
                for k, v in chain_mapping.items():
                    inverse_chain_mapping[v].append(k)
            inverse_crop_ids = None
            if self.inverse_crop_ids_path is not None:
                inverse_crop_ids = load(self.inverse_crop_ids_path)
        except FileNotFoundError as e:
            raise ValueError(f"Embeddings file not found at {embeddings_path} or sequences file not found at {sequences_path}") from e
        except Exception as e:
            raise ValueError(f"An error occurred while loading embeddings: {e}") from e

        chain_embeddings_dictlist = defaultdict(list)
        chain_sequences_dictlist = defaultdict(list)
        tokenized_chain_sequences_dictlist = defaultdict(list)

        complex_names_set = set(complex_names)
        for key_base, embedding in id_to_embeddings.items():
            # map mulan chains to original esm chains
            keys_all = inverse_chain_mapping[key_base] if chain_mapping is not None else [key_base]
            for key in keys_all:
                key_name = '_'.join(key.split('_')[:-2])  # cut _chain_i

                is_wanted = ((self.dataset_type == 'moad' and key_name in protein_to_complex_names)
                             or key_name in complex_names_set)
                if not is_wanted:
                    continue

                # Per-key copy of the reference: cropping must not leak into other keys
                # that share the same key_base.
                chain_embedding = embedding
                tokenized_aa_sequence = np.array([self.aa_mapping.get(aa, 0) for aa in id_to_sequence[key]])[:, None]
                aa_sequence = np.array([aa for aa in id_to_sequence[key]])

                if inverse_crop_ids is not None and key_name in inverse_crop_ids:
                    indices_to_crop = inverse_crop_ids[key_name]
                    print('CROP', key_name, indices_to_crop)
                    indices_to_keep = sorted(set(range(len(chain_embedding))) - set(indices_to_crop))
                    chain_embedding = chain_embedding[indices_to_keep]
                    aa_sequence = aa_sequence[indices_to_keep]
                    tokenized_aa_sequence = tokenized_aa_sequence[indices_to_keep]

                chain_embeddings_dictlist[key_name].append(chain_embedding)
                chain_sequences_dictlist[key_name].append(aa_sequence)
                tokenized_chain_sequences_dictlist[key_name].append(tokenized_aa_sequence)

        if self.dataset_type == 'moad':
            return chain_embeddings_dictlist, chain_sequences_dictlist, tokenized_chain_sequences_dictlist

        lm_embeddings_chains_all = [chain_embeddings_dictlist.get(name, []) for name in complex_names]
        sequence_chains_all = [chain_sequences_dictlist.get(name, []) for name in complex_names]
        tokenized_sequence_chains_all = [tokenized_chain_sequences_dictlist.get(name, []) for name in complex_names]
        print('LLM embeddings are loaded.', flush=True)
        return lm_embeddings_chains_all, sequence_chains_all, tokenized_sequence_chains_all


    def _process_complex(self, complex_names, sequences_to_embeddings):
        try:
            return self._get_complex(complex_names, sequences_to_embeddings)
        except Exception as e:
            print(f"Error processing {complex_names}: {e}")
            return None


    def _preprocess_and_save_to_cache(self):
        """
        Build ``self.complexes`` from the raw data and pickle them into the cache folder.

        Complex names are read from ``self.split_path`` (one per line; blank lines are
        skipped) and optionally truncated to ``self.limit_complexes``. Embeddings are loaded
        once. Then each complex is processed via ``_process_complex``; for 'moad', each protein
        is processed together with all of its ligands. Failed entries (``None``) are skipped,
        complexes missing a ligand or protein are dropped, and the result is written to
        ``<full_cache_path>/complexes.pkl``. A failed write is reported but does not raise.
        """
        print(f'Processing complexes from [{self.split_path}] and saving it to [{self.full_cache_path}]')

        # Get names of complexes (blank lines would otherwise become empty complex names):
        with open(self.split_path, 'r') as file:
            complex_names_all = [line.rstrip() for line in file if line.strip()]

        if self.limit_complexes is not None and self.limit_complexes != 0:
            complex_names_all = complex_names_all[:self.limit_complexes]
        print(f'Loading {len(complex_names_all)} complexes.')

        # Load embeddings:
        if self.dataset_type == 'moad':
            protein_to_complex_names = defaultdict(list)
            for name in complex_names_all:
                protein_to_complex_names[name.split('_superlig')[0]].append(name)
        else:
            protein_to_complex_names = None
        lm_embeddings_chains_all, sequence_chains_all, tokenized_sequence_chains_all = self._load_embeddings(
            self.esm_embeddings_path,
            self.sequences_path,
            complex_names_all,
            protein_to_complex_names=protein_to_complex_names)

        self.complexes = []

        if self.dataset_type == 'moad':
            for protein_name, protein_complexes in tqdm(protein_to_complex_names.items(), desc='Loading complexes'):
                lm_embeddings = lm_embeddings_chains_all[protein_name]
                sequence_chains = sequence_chains_all[protein_name]
                tokenized_sequence_chains = tokenized_sequence_chains_all[protein_name]
                # Key each chain's embedding by its full sequence string.
                sequences_to_embeddings = {''.join(seq): (emb, tokenized_seq) for seq, emb, tokenized_seq in zip(sequence_chains, lm_embeddings,
                                                                                                                    tokenized_sequence_chains)}
                processed_complexes = self._process_complex(protein_complexes, sequences_to_embeddings)
                if processed_complexes is not None:
                    self.complexes += processed_complexes

        else:
            with tqdm(total=len(complex_names_all), desc='Loading complexes') as pbar:
                for complex_name, lm_embeddings, sequence_chains, tokenized_sequence_chains in zip(complex_names_all,
                                                                                                    lm_embeddings_chains_all,
                                                                                                    sequence_chains_all,
                                                                                                    tokenized_sequence_chains_all):
                    sequences_to_embeddings = {''.join(seq): (emb, tokenized_seq) for seq, emb, tokenized_seq in zip(sequence_chains, lm_embeddings,
                                                                                                                        tokenized_sequence_chains)}
                    processed_complexes = self._process_complex([complex_name], sequences_to_embeddings)
                    if processed_complexes is not None:
                        self.complexes += processed_complexes
                    pbar.update()

        # Filter out empty complexes:
        self.complexes = [complex for complex in self.complexes if (complex.ligand is not None) and (complex.protein is not None)]

        # Save:
        filepath = os.path.join(self.full_cache_path, 'complexes.pkl')
        try:
            with open(filepath, 'wb') as f:
                pickle.dump(self.complexes, f)
            print(f"Data successfully saved to {filepath}!")
        except IOError as e:
            print(f"Error saving data to {filepath}: {e}!")


    def _load_from_cache(self):
        """Restore ``self.complexes`` from the pickle file ``complexes.pkl`` in ``self.full_cache_path``.

        I/O errors (``OSError``, e.g. a missing file) are reported with a printed message
        and leave ``self.complexes`` unchanged; other exceptions propagate.
        NOTE(review): uses pickle, so the cache directory must be trusted.
        """
        cache_file = os.path.join(self.full_cache_path, 'complexes.pkl')
        try:
            with open(cache_file, 'rb') as handle:
                self.complexes = pickle.load(handle)
            print(f"Data successfully loaded from {cache_file}!")
        except OSError as err:
            print(f"Error loading data from {cache_file}: {err}!")


    def _get_complex(self, complex_names, sequences_to_embeddings):
        """
        Build Complex objects for a group of complex names that share one receptor.

        The receptor is parsed once from ``complex_names[0]``. For every name, ligands are
        read according to ``self.dataset_type`` and split with ``split_molecule``. For every
        ligand (and every conformer, when several are available), a Complex is created that
        holds a Protein centered on the mean of its residue positions and a Ligand built by
        ``get_ligand_without_randomization``.

        Parameters:
        ----------
        complex_names : list of str
            Complex names; the receptor of ``complex_names[0]`` is used for all of them.
        sequences_to_embeddings : dict
            Maps a joined chain sequence to ``(embedding, tokenized_sequence)``; forwarded to
            ``extract_receptor_structure_prody``.

        Returns:
        -------
        list of Complex
            The processed complexes. Empty if the receptor could not be parsed. Ligands that
            are too large or fail processing are skipped with a printed message.

        Raises:
        ------
        ValueError
            If ``self.dataset_type`` is not recognized.
        """
        print('complex_names', complex_names)
        try:
            rec_model = parse_receptor(complex_names[0], self.data_dir, self.dataset_type)
        except Exception as e:
            print(f'Skipping {complex_names[0]} because of the error:')
            print(e)
            # Return the same type as the success path (a list of complexes).
            return []

        # Dataset types whose ligands come from a separate multi-conformer SDF file and
        # whose reference ligand (orig_ligs) is read from the regular data directory.
        conf_dataset_types = ('pdbbind_conf', 'dockgen_full_conf', 'posebusters_conf', 'astex_conf')

        complexes = []
        for name in complex_names:
            print('complex', name)

            orig_ligs = None
            if self.dataset_type in ('pdbbind', 'lpce'):
                ligs = read_mols(self.data_dir, name, remove_hs=False)
            elif self.dataset_type in conf_dataset_types:
                ligs = [read_sdf_with_multiple_confs(os.path.join(self.data_dir_conf, f'{name}_conf.sdf'), remove_hs=False, sanitize=True)]
                if self.dataset_type == 'pdbbind_conf':
                    orig_ligs = read_mols(self.data_dir, name, remove_hs=False)
                elif self.dataset_type in ('posebusters_conf', 'astex_conf'):
                    orig_ligs = [read_molecule(os.path.join(self.data_dir, name, f'{name}_ligand.sdf'), remove_hs=False, sanitize=True)]
                else:
                    orig_ligs = [read_molecule(os.path.join(self.data_dir, name, f'{name}_ligand.pdb'), remove_hs=False, sanitize=True)]
            elif self.dataset_type == 'moad':
                ligs = [read_molecule(os.path.join(self.data_dir, 'pdb_superligand', f'{name}.pdb'), remove_hs=False, sanitize=True)]
            elif self.dataset_type in ('dockgen', 'dockgen_full'):
                ligs = [read_molecule(os.path.join(self.data_dir, name, f'{name}_ligand.pdb'), remove_hs=False, sanitize=True)]
            elif self.dataset_type in ('astex', 'posebusters'):
                ligs = [read_molecule(os.path.join(self.data_dir, name, f'{name}_ligand.sdf'), remove_hs=False, sanitize=True)]
            else:
                raise ValueError(f'Unknown dataset type: {self.dataset_type}')

            if len(ligs) > 0 and isinstance(ligs[0], list):
                # ligs[0] is itself a list of molecules (presumably the conformers read from
                # the multi-conformer SDF): split and flatten them, then keep them grouped as
                # a single multi-conformation entry.
                ligs = [split_molecule(lig_mol, min_lig_size=self.min_lig_size) for lig_mol in ligs[0]]
                ligs = [lig_mol for cur_lig_mol_list in ligs for lig_mol in cur_lig_mol_list if lig_mol is not None]
                ligs = [ligs]
            else:
                ligs = [split_molecule(lig_mol, min_lig_size=self.min_lig_size) for lig_mol in ligs]
                ligs = [lig_mol for lig_mol_list in ligs for lig_mol in lig_mol_list if lig_mol is not None]

            if orig_ligs is not None:
                orig_ligs = [split_molecule(lig_mol, min_lig_size=self.min_lig_size) for lig_mol in orig_ligs]
                orig_ligs = [lig_mol for lig_mol_list in orig_ligs for lig_mol in lig_mol_list if lig_mol is not None]

            for lig_idx, lig_mol in enumerate(ligs):
                if isinstance(lig_mol, list):  # multiple conformations
                    lig_mol_list = lig_mol
                    lig_mol = lig_mol[0]
                else:
                    lig_mol_list = [lig_mol]

                if self.max_lig_size is not None and lig_mol.GetNumHeavyAtoms() > self.max_lig_size:
                    print(f'Ligand with {lig_mol.GetNumHeavyAtoms()} heavy atoms is larger than max_lig_size {self.max_lig_size}. Not including {name} in preprocessed data.')
                    continue

                try:
                    ### Process protein:
                    # Ligand handed to extract_receptor_structure_prody: None when all chains
                    # are used, the reference ligand for *_conf datasets, otherwise the ligand
                    # itself. orig_ligs is only indexed when it is actually needed.
                    if self.use_all_chains:
                        receptor_lig = None
                    elif self.dataset_type in conf_dataset_types:
                        receptor_lig = orig_ligs[lig_idx]
                    else:
                        receptor_lig = lig_mol
                    c_alpha_coords_list, lm_embeddings_list, sequences_list, chain_lengths, full_coords, full_atom_names = extract_receptor_structure_prody(
                        copy.deepcopy(rec_model), receptor_lig, sequences_to_embeddings)

                    # positions are positions of C-alpha, other positions are not used
                    if not self.add_all_atom_pos:
                        full_coords = None
                        full_atom_names = None
                    protein = Protein(x=lm_embeddings_list, pos=c_alpha_coords_list, seq=sequences_list, all_atom_pos=full_coords, all_atom_names=full_atom_names)
                    protein_center = protein.pos.mean(axis=0).reshape(1, 3)
                    protein.pos -= protein_center
                    protein.full_protein_center = protein_center
                    protein.chain_lengths = chain_lengths

                    ### Process ligand:
                    # Rotatable bonds are parsed only for the first conformer; later conformers
                    # reuse that bond annotation and only take their own pos/x/final_tr/orig_mol.
                    parse_rotbonds = True
                    for conf_id, conf_mol in enumerate(lig_mol_list):
                        ligand = get_ligand_without_randomization(conf_mol, protein_center,
                                                                  remove_hs=self.remove_hs,
                                                                  num_new_conformations=1,
                                                                  parse_rotbonds=parse_rotbonds)['conf0']
                        if parse_rotbonds:
                            ligand_with_bonds = copy.deepcopy(ligand)
                            cur_ligand = ligand
                        else:
                            cur_ligand = copy.deepcopy(ligand_with_bonds)
                            cur_ligand.pos = copy.deepcopy(ligand.pos)
                            cur_ligand.x = copy.deepcopy(ligand.x)
                            cur_ligand.final_tr = copy.deepcopy(ligand.final_tr)
                            cur_ligand.orig_mol = copy.deepcopy(ligand.orig_mol)

                        parse_rotbonds = False
                        cur_complex = Complex()
                        cur_complex.ligand = cur_ligand
                        if conf_id == 0:
                            cur_complex.protein = copy.deepcopy(protein)
                        else:
                            cur_complex.protein = []  # avoid copying protein in cache

                        if self.dataset_type.endswith('_conf'):
                            cur_complex.name = f'{name}_mol{lig_idx}_conf{conf_id}'
                        else:
                            cur_complex.name = f'{name}_mol{lig_idx}'
                        complexes.append(cur_complex)

                except Exception as e:
                    print(f'Skipping {name} because of the error:')
                    print(e)

        return complexes


class SameComplexPDBBind(Dataset):
    """Dataset wrapper whose every item is a batch built from a single complex.

    Item ``idx`` fetches the non-randomized complex ``idx`` from the wrapped dataset
    ``batch_size`` times, randomizes each copy independently with ``randomize_complex``
    (using the wrapped dataset's augmentation settings) and collates the copies with
    ``data_collator``.
    """

    def __init__(self, dataset, batch_size, data_collator):
        self.dataset = dataset
        self.batch_size = batch_size
        self.data_collator = data_collator

    def __len__(self):
        # One batch per complex of the wrapped dataset.
        return len(self.dataset)

    def __getitem__(self, idx):
        dataset = self.dataset
        batch_complexes = []
        for _ in range(self.batch_size):
            cur_complex = dataset.__get_nonrand_item__(idx)
            cur_complex = randomize_complex(complex=cur_complex, augm_ligand_transforms=dataset.augm_ligand_transforms,
                                            std_protein_pos=dataset.std_protein_pos, std_lig_pos=dataset.std_lig_pos,
                                            ligand_mask_ratio=dataset.ligand_mask_ratio,
                                            protein_mask_ratio=dataset.protein_mask_ratio,
                                            tr_mean=dataset.tr_mean, tr_std=dataset.tr_std,
                                            use_pred_ligand_transforms=dataset.use_pred_ligand_transforms,
                                            # Default to True when the wrapped dataset lacks the flag.
                                            use_predicted_tr_only=getattr(dataset, 'use_predicted_tr_only', True),
                                            randomize_bond_neighbors=dataset.randomize_bond_neighbors,
                                            stage_num=dataset.stage_num)
            batch_complexes.append(cur_complex)
        return self.data_collator(batch_complexes)
    

class PDBBindWithSortedBatching(Dataset):
    """Dataset wrapper that packs complexes into length-bounded batches.

    Complexes are sorted by protein length, then ligand length, and packed greedily:
    a complex joins the current batch while ``(batch size after adding) * (its protein
    length + ligand length) <= batch_limit``. A complex that alone exceeds the limit
    gets a batch of its own. Each item is one collated, randomized batch.
    NOTE(review): the check uses the newest member's total length, not the batch's
    maximum protein/ligand lengths, so padded batches can still exceed the limit.
    """

    def __init__(self, dataset, batch_limit, data_collator):
        self.dataset = dataset
        self.batch_limit = batch_limit
        self.data_collator = data_collator

        if self.dataset.dataset_type == 'moad':
            raise ValueError('use_sorted_batching is not supported for moad dataset')
        self._form_batches(batch_limit)

    def _init_sorted_indices(self):
        """Return (protein + ligand lengths, indices sorted by protein then ligand length)."""
        protein_lengths = np.array([cplx.protein.pos.shape[0] for cplx in self.dataset.complexes])
        ligand_lengths = np.array([cplx.ligand.pos.shape[0] for cplx in self.dataset.complexes])
        # np.lexsort uses the last key as the primary one.
        sorted_indices = np.lexsort((ligand_lengths, protein_lengths))
        return protein_lengths + ligand_lengths, sorted_indices

    def _get_sorted_batches(self, lengths, sorted_indices, batch_limit):
        """Greedily group ``sorted_indices`` into non-empty batches under ``batch_limit``."""
        batch_indices = []
        cur_batch = []
        for real_ind, cur_len in zip(sorted_indices, lengths[sorted_indices]):
            if (len(cur_batch) + 1) * cur_len <= batch_limit:
                cur_batch.append(int(real_ind))
            else:
                # Never emit an empty batch (happens when the very first complex is
                # already larger than batch_limit).
                if cur_batch:
                    batch_indices.append(cur_batch)
                cur_batch = [int(real_ind)]

        if cur_batch:
            batch_indices.append(cur_batch)
        return batch_indices

    def _form_batches(self, batch_limit):
        lengths, sorted_indices = self._init_sorted_indices()
        self.batch_indices = self._get_sorted_batches(lengths, sorted_indices, batch_limit)

    def __len__(self):
        return len(self.batch_indices)

    def __getitem__(self, idx):
        dataset = self.dataset
        batch_complexes = []
        for i in self.batch_indices[idx]:
            cur_complex = dataset.__get_nonrand_item__(i)
            cur_complex = randomize_complex(complex=cur_complex, augm_ligand_transforms=dataset.augm_ligand_transforms,
                                            std_protein_pos=dataset.std_protein_pos, std_lig_pos=dataset.std_lig_pos,
                                            ligand_mask_ratio=dataset.ligand_mask_ratio,
                                            protein_mask_ratio=dataset.protein_mask_ratio,
                                            tr_mean=dataset.tr_mean, tr_std=dataset.tr_std,
                                            use_pred_ligand_transforms=dataset.use_pred_ligand_transforms,
                                            # Default to True when the wrapped dataset lacks the flag.
                                            use_predicted_tr_only=getattr(dataset, 'use_predicted_tr_only', True),
                                            randomize_bond_neighbors=dataset.randomize_bond_neighbors,
                                            stage_num=dataset.stage_num)
            batch_complexes.append(cur_complex)
        return self.data_collator(batch_complexes)


def apply_random_rotation_inplace(complex):
    """Rotate a complex by one uniformly random rotation, modifying it in place.

    The same float32 3x3 matrix is applied (as ``pos @ rot.T``) to the ligand positions,
    the optional ligand ``pred_tr`` / ``predicted_pos``, the protein positions and the
    optional protein ``full_pos``. Optional fields that are None are left as None.
    The matrix is stored on ``complex.original_augm_rot``.
    """
    rot = R.random().as_matrix().astype(np.float32)
    rot_t = rot.T
    ligand = complex.ligand
    protein = complex.protein

    ligand.pos = ligand.pos @ rot_t
    for optional_field in ('pred_tr', 'predicted_pos'):
        current = getattr(ligand, optional_field)
        if current is not None:
            setattr(ligand, optional_field, current @ rot_t)

    protein.pos = protein.pos @ rot_t
    if protein.full_pos is not None:
        protein.full_pos = protein.full_pos @ rot_t

    complex.original_augm_rot = rot


def complex_collate_fn(batch: List[Complex]) -> dict:
    """
    Collate a list of Complex objects into a padded ComplexBatch.

    Parameters:
    batch (List[Complex]): Complex objects to collate. Per-complex numpy arrays are
        converted to torch tensors; variable-length per-atom, per-residue and per-bond
        fields are right-padded to the longest entry in the batch.

    Returns:
    dict: ``{"batch": ComplexBatch, "labels": <concatenated ligand rmsd>}``.
    Optional fields (predictions, original/true positions, all-atom data, centers) are
    None in the ComplexBatch when any complex in the batch lacks them.
    """

    def pad_optional_ligand_field(attr_name):
        """Pad an optional per-atom ligand field; return None with a warning if unavailable."""
        # The tensors are built inside the try: a missing (None) field makes
        # torch.from_numpy raise, which must fall back to None instead of aborting collation.
        try:
            tensors = [torch.from_numpy(getattr(cplx.ligand, attr_name)) for cplx in batch]
            return pad_sequence(tensors, batch_first=True, padding_value=0.0)
        except Exception:
            print(f'Warning: {attr_name} are not defined!')
            return None

    def cat_optional(get_array):
        """Concatenate one array per complex; return None if any is missing or they cannot be joined."""
        try:
            return torch.cat([torch.from_numpy(get_array(cplx)) for cplx in batch])
        except Exception:
            return None

    # Extract components from the batch
    lig_xs = [torch.from_numpy(cplx.ligand.x) for cplx in batch]
    lig_positions = [torch.from_numpy(cplx.ligand.pos) for cplx in batch]
    orig_mols = [cplx.ligand.orig_mol for cplx in batch]
    mask_rotate = [torch.from_numpy(cplx.ligand.mask_rotate) for cplx in batch]
    protein_xs = [torch.from_numpy(cplx.protein.x) if isinstance(cplx.protein.x, np.ndarray) else cplx.protein.x
                  for cplx in batch]
    protein_positions = [torch.from_numpy(cplx.protein.pos) for cplx in batch]
    protein_sequences = [torch.from_numpy(cplx.protein.seq) for cplx in batch]
    protein_chain_ids = [torch.from_numpy(cplx.protein.chain_ids) for cplx in batch]
    protein_aa_ids = [torch.from_numpy(cplx.protein.aa_ids) for cplx in batch]
    init_tr = torch.cat([torch.from_numpy(cplx.ligand.init_tr) for cplx in batch])
    init_rot = torch.cat([torch.from_numpy(cplx.ligand.init_rot) for cplx in batch])
    init_tor = torch.cat([torch.from_numpy(cplx.ligand.init_tor) for cplx in batch])
    final_tr = torch.cat([torch.from_numpy(cplx.ligand.final_tr) for cplx in batch])
    final_rot = torch.cat([torch.from_numpy(cplx.ligand.final_rot) for cplx in batch])
    final_tor = torch.cat([torch.from_numpy(cplx.ligand.final_tor) for cplx in batch])
    pred_tor_angles = cat_optional(lambda cplx: cplx.ligand.pred_tor_angles)
    pred_tor_mask = cat_optional(lambda cplx: cplx.ligand.pred_tor_mask)

    try:
        all_atom_pos = [torch.from_numpy(cplx.protein.all_atom_pos) for cplx in batch]
        all_atom_names = [cplx.protein.all_atom_names for cplx in batch]
    except Exception:
        all_atom_pos = None
        all_atom_names = None

    num_rotatable_bonds = torch.tensor([len(cplx.ligand.final_tor) for cplx in batch], dtype=torch.long)
    t = torch.cat([cplx.ligand.t for cplx in batch])
    rmsd = torch.cat([cplx.ligand.rmsd for cplx in batch])
    # Complexes without a stage number default to stage 0.
    stage_num = torch.cat([cplx.ligand.stage_num if cplx.ligand.stage_num is not None else torch.tensor([0])
                           for cplx in batch])
    names = [cplx.name for cplx in batch]
    # Add a leading axis to each rotation matrix so they concatenate along a batch dimension.
    orig_augm_rot = torch.cat([torch.from_numpy(cplx.original_augm_rot[None, :]) for cplx in batch])
    orig_pocket_center = torch.cat([torch.from_numpy(cplx.original_pocket_center) for cplx in batch])
    full_protein_center = cat_optional(lambda cplx: cplx.protein.full_protein_center)
    pred_tr = cat_optional(lambda cplx: cplx.ligand.pred_tr)

    # Pad ligand sequences
    lig_x_padded = pad_sequence(lig_xs, batch_first=True, padding_value=0.0)
    lig_pos_padded = pad_sequence(lig_positions, batch_first=True, padding_value=0.0)
    lig_orig_pos_padded = pad_optional_ligand_field('orig_pos')
    lig_orig_pos_before_augm_padded = pad_optional_ligand_field('orig_pos_before_augm')
    lig_true_pos_padded = pad_optional_ligand_field('true_pos')

    # Pad protein sequences
    protein_x_padded = pad_sequence(protein_xs, batch_first=True, padding_value=0.0)
    protein_pos_padded = pad_sequence(protein_positions, batch_first=True, padding_value=0.0)
    protein_seq_padded = pad_sequence(protein_sequences, batch_first=True, padding_value=0.0)
    protein_chain_ids_padded = pad_sequence(protein_chain_ids, batch_first=True, padding_value=0.0)
    protein_aa_ids_padded = pad_sequence(protein_aa_ids, batch_first=True, padding_value=0.0)

    # Rotatable bonds of all ligands, concatenated (ligands without any are skipped).
    rotatable_bonds_list = [torch.from_numpy(cplx.ligand.rotatable_bonds)
                            for cplx in batch if len(cplx.ligand.rotatable_bonds) > 0]
    if rotatable_bonds_list:
        rotatable_bonds = torch.cat(rotatable_bonds_list)
    else:
        rotatable_bonds = torch.empty((0, 2))

    bond_periods_list = [torch.from_numpy(cplx.ligand.bond_periods) for cplx in batch
                         if cplx.ligand.bond_periods is not None]
    if bond_periods_list:
        bond_periods = torch.cat(bond_periods_list)
    else:
        bond_periods = torch.empty((0,))

    # Extract and pad rotatable and non-rotatable bonds
    rotatable_bonds_ext = [cplx.ligand.rotatable_bonds_ext for cplx in batch]
    non_rotatable_bonds_ext = [cplx.ligand.non_rotatable_bonds_ext for cplx in batch]

    def pad_bonds(bonds_list, max_num_bonds, max_num_atoms, is_rotatable_bonds=False):
        """Pad per-ligand bond fields to ``max_num_bonds`` and stack them into tensors.

        A field is returned as None when any ligand in ``bonds_list`` lacks it.
        For rotatable bonds, ``mask_rotate`` is additionally padded to
        ``(max_num_bonds, max_num_atoms)``.
        """
        bond_keys = ['bond_type', 'is_conjugated', 'is_in_ring', 'is_aromatic', 'is_rotatable',
                     'start', 'end', 'neighbor_of_start', 'neighbor_of_end', 'length',
                     'bond_periods', 'angles', 'angle_histograms']
        index_keys = ('start', 'end', 'neighbor_of_start', 'neighbor_of_end')
        float_keys = ('length', 'bond_periods', 'angles')
        padded_bonds = {key: [] for key in bond_keys}
        for bonds in bonds_list:
            for key in bond_keys:
                if padded_bonds[key] is None:
                    # Already missing for an earlier ligand: the whole field stays None.
                    continue
                value = getattr(bonds, key)
                if value is None:
                    padded_bonds[key] = None
                    continue
                if len(value) == 0 and key in index_keys:
                    # np.pad of an empty array returns float; keep index fields integral.
                    padded_value = np.zeros(max_num_bonds, dtype=np.int32)
                else:
                    if key in float_keys:
                        value = value.astype(np.float32)
                    # Padded bond periods are filled with 2*pi, everything else with 0.
                    constant_value = 2 * np.pi if key == 'bond_periods' else 0
                    if key == 'angle_histograms':
                        pad_width = ((0, max_num_bonds - len(value)), (0, 0))
                    else:
                        pad_width = (0, max_num_bonds - len(value))
                    padded_value = np.pad(value, pad_width, 'constant', constant_values=constant_value)
                padded_bonds[key].append(torch.from_numpy(padded_value))

        res = {key: (torch.stack(value) if value is not None else None)
               for key, value in padded_bonds.items()}

        res['num_rotatable_bonds'] = torch.tensor([len(bonds.start) for bonds in bonds_list])

        # mask_rotate is a special case because it is a 2D tensor
        if is_rotatable_bonds:
            mask_rotate_padded = []
            for bonds in bonds_list:
                if len(bonds.start):
                    value = bonds.mask_rotate
                    padded_value = np.pad(value, ((0, max_num_bonds - value.shape[0]),
                                                  (0, max_num_atoms - value.shape[1])),
                                          'constant', constant_values=0)
                    mask_rotate_padded.append(torch.from_numpy(padded_value))
                else:
                    mask_rotate_padded.append(torch.zeros(max_num_bonds, max_num_atoms, dtype=torch.bool))
            res['mask_rotate'] = torch.stack(mask_rotate_padded)

        # True marks padding slots beyond each ligand's own number of bonds.
        res['is_padded_mask'] = torch.ones(len(bonds_list), max_num_bonds, dtype=torch.bool)
        for idx, bonds in enumerate(bonds_list):
            res['is_padded_mask'][idx, :bonds.start.shape[0]] = False

        return res

    max_rotatable_bonds_ext = max(len(bond.start) for bond in rotatable_bonds_ext)
    max_non_rotatable_bonds_ext = max(len(bond.start) for bond in non_rotatable_bonds_ext)
    max_num_atoms = lig_pos_padded.shape[1]

    padded_rotatable_bonds_ext = pad_bonds(rotatable_bonds_ext, max_rotatable_bonds_ext, max_num_atoms, is_rotatable_bonds=True)
    padded_non_rotatable_bonds_ext = pad_bonds(non_rotatable_bonds_ext, max_non_rotatable_bonds_ext, max_num_atoms)

    # Create BondsBatch objects
    rotatable_bonds_batch_ext = BondsBatch(**padded_rotatable_bonds_ext)
    non_rotatable_bonds_batch_ext = BondsBatch(**padded_non_rotatable_bonds_ext)

    # Number of rotatable and non-rotatable bonds in each ligand
    num_rotatable_bonds_ext = torch.tensor([len(bond.start) for bond in rotatable_bonds_ext], dtype=torch.long)
    num_non_rotatable_bonds_ext = torch.tensor([len(bond.start) for bond in non_rotatable_bonds_ext], dtype=torch.long)

    # is_padded_mask_*: batch_size x max_seq_len, True for padding. Real positions get the
    # complex's own is_masked_mask when it has one, otherwise False.
    batch_size, max_lig_seq_len = lig_pos_padded.shape[0], lig_pos_padded.shape[1]
    max_protein_seq_len = protein_pos_padded.shape[1]
    is_padded_mask_ligand = torch.ones(batch_size, max_lig_seq_len, dtype=torch.bool)
    is_padded_mask_protein = torch.ones(batch_size, max_protein_seq_len, dtype=torch.bool)

    for idx, cplx in enumerate(batch):
        lig_mask = cplx.ligand.is_masked_mask
        prot_mask = cplx.protein.is_masked_mask
        is_padded_mask_ligand[idx, :cplx.ligand.pos.shape[0]] = torch.from_numpy(lig_mask) if lig_mask is not None else False
        is_padded_mask_protein[idx, :cplx.protein.pos.shape[0]] = torch.from_numpy(prot_mask) if prot_mask is not None else False

    # Atom counts per ligand and offsets of each ligand's torsions in the concatenated list
    num_atoms = torch.tensor([x.shape[0] for x in lig_positions], dtype=torch.long)
    tor_ptr = [0] + list(np.cumsum([cplx.ligand.rotatable_bonds.shape[0] for cplx in batch]))

    # Per-ligand init/final torsions, padded with zeros
    tor_padded_init_ext = pad_sequence(
        [torch.from_numpy(cplx.ligand.init_tor) for cplx in batch],
        batch_first=True,
        padding_value=0.0
    )
    tor_padded_final_ext = pad_sequence(
        [torch.from_numpy(cplx.ligand.final_tor) for cplx in batch],
        batch_first=True,
        padding_value=0.0
    )

    complex_batch = ComplexBatch(
        ligand=LigandBatch(
            x=lig_x_padded,
            pos=lig_pos_padded,
            orig_pos=lig_orig_pos_padded,
            orig_pos_before_augm=lig_orig_pos_before_augm_padded,
            true_pos=lig_true_pos_padded,
            random_pos=lig_pos_padded.clone(),
            mask_rotate=mask_rotate,
            init_tr=init_tr,
            init_rot=init_rot,
            init_tor=init_tor,
            final_tr=final_tr,
            final_rot=final_rot,
            final_tor=final_tor,
            pred_tor_angles=pred_tor_angles,
            pred_tor_mask=pred_tor_mask,
            pred_tr=pred_tr,
            num_atoms=num_atoms,
            bond_periods=bond_periods,
            tor_ptr=tor_ptr,
            rotatable_bonds=rotatable_bonds,
            num_rotatable_bonds=num_rotatable_bonds,
            t=t,
            rmsd=rmsd,
            stage_num=stage_num,
            is_padded_mask=is_padded_mask_ligand,
            orig_mols=orig_mols,

            rotatable_bonds_ext=rotatable_bonds_batch_ext,
            non_rotatable_bonds_ext=non_rotatable_bonds_batch_ext,
            num_rotatable_bonds_ext=num_rotatable_bonds_ext,
            num_non_rotatable_bonds_ext=num_non_rotatable_bonds_ext,
            init_tor_ext=tor_padded_init_ext,
            final_tor_ext=tor_padded_final_ext,
        ),
        protein=ProteinBatch(x=protein_x_padded, pos=protein_pos_padded,
                             seq=protein_seq_padded,
                             is_padded_mask=is_padded_mask_protein,
                             full_protein_center=full_protein_center,
                             all_atom_pos=all_atom_pos,
                             all_atom_names=all_atom_names,
                             chain_ids=protein_chain_ids_padded,
                             aa_ids=protein_aa_ids_padded),
        names=names,
        original_pocket_center=orig_pocket_center,
        original_augm_rot=orig_augm_rot,
    )

    return {"batch": complex_batch, "labels": complex_batch.ligand.rmsd}
