import os
import warnings

import numpy as np
import struct
import torch
from Bio.PDB import PDBParser
from rdkit.Chem.rdchem import BondType as BT
from rdkit import Chem
from rdkit.Chem import AllChem, GetPeriodicTable
from rdkit.Geometry import Point3D
from Bio.PDB import PDBParser, MMCIFParser, PDBIO, Select

import prody
from prody import confProDy
confProDy(verbosity='none')



# Module-level Biopython PDB parser and RDKit periodic-table instances.
biopython_parser = PDBParser()
periodic_table = GetPeriodicTable()
# Vocabularies for the categorical features used by the featurizers in this module.
# A value is encoded as its position in the corresponding list; lists that end with
# 'misc' reserve that last slot for any value that is not listed (see safe_index).
allowable_features = {
    # 'possible_atomic_num_list': list(range(1, 119)) + ['misc'],
    # Restricted set of atomic numbers; every other element maps to 'misc'.
    'possible_atomic_num_list': [1, 5, 6, 7, 8, 9, 12, 14, 15, 16, 17, 26, 33, 34, 35, 44, 45, 51, 53, 75, 77, 78, 'misc'],
    # String forms of the chiral tags looked up by lig_atom_featurizer (no 'misc' slot).
    'possible_chirality_list': [
        'CHI_UNSPECIFIED',
        'CHI_TETRAHEDRAL_CW',
        'CHI_TETRAHEDRAL_CCW',
        'CHI_OTHER',
    ],
    'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],
    'possible_numring_list': [0, 1, 2, 3, 4, 5, 6, 'misc'],
    'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6, 'misc'],
    'possible_formal_charge_list': [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'],
    'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],
    'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'],
    'possible_hybridization_list': [
        'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'misc'
    ],
    # Boolean flags: index 0 encodes False, index 1 encodes True.
    'possible_is_aromatic_list': [False, True],
    'possible_is_in_ring3_list': [False, True],
    'possible_is_in_ring4_list': [False, True],
    'possible_is_in_ring5_list': [False, True],
    'possible_is_in_ring6_list': [False, True],
    'possible_is_in_ring7_list': [False, True],
    'possible_is_in_ring8_list': [False, True],
    # Three-letter residue names; the entries after 'VAL' are non-standard residue codes.
    'possible_amino_acids': ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE', 'LEU', 'LYS', 'MET',
                             'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL', 'HIP', 'HIE', 'TPO', 'HID', 'LEV', 'MEU',
                             'PTR', 'GLV', 'CYT', 'SEP', 'HIZ', 'CYM', 'GLM', 'ASQ', 'TYS', 'CYX', 'GLZ', 'misc'],
    # Receptor atom-name vocabularies (two-character and full-name variants).
    'possible_atom_type_2': ['C*', 'CA', 'CB', 'CD', 'CE', 'CG', 'CH', 'CZ', 'N*', 'ND', 'NE', 'NH', 'NZ', 'O*', 'OD',
                             'OE', 'OG', 'OH', 'OX', 'S*', 'SD', 'SG', 'misc'],
    'possible_atom_type_3': ['C', 'CA', 'CB', 'CD', 'CD1', 'CD2', 'CE', 'CE1', 'CE2', 'CE3', 'CG', 'CG1', 'CG2', 'CH2',
                             'CZ', 'CZ2', 'CZ3', 'N', 'ND1', 'ND2', 'NE', 'NE1', 'NE2', 'NH1', 'NH2', 'NZ', 'O', 'OD1',
                             'OD2', 'OE1', 'OE2', 'OG', 'OG1', 'OH', 'OXT', 'SD', 'SG', 'misc'],
}
# Integer class index for each supported RDKit bond type.
bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}

# Each *_feature_dims value is a pair:
#   (vocabulary size of every categorical feature, number of scalar features).
# The key order of lig_feature_dims matches the column order produced by lig_atom_featurizer.
lig_feature_dims = (
    [
        len(allowable_features[feature_key])
        for feature_key in (
            'possible_atomic_num_list',
            'possible_chirality_list',
            'possible_degree_list',
            'possible_formal_charge_list',
            'possible_implicit_valence_list',
            'possible_numH_list',
            'possible_number_radical_e_list',
            'possible_hybridization_list',
            'possible_is_aromatic_list',
            'possible_numring_list',
            'possible_is_in_ring3_list',
            'possible_is_in_ring4_list',
            'possible_is_in_ring5_list',
            'possible_is_in_ring6_list',
            'possible_is_in_ring7_list',
            'possible_is_in_ring8_list',
        )
    ],
    0,  # number of scalar features
)

rec_atom_feature_dims = (
    [
        len(allowable_features[feature_key])
        for feature_key in (
            'possible_amino_acids',
            'possible_atomic_num_list',
            'possible_atom_type_2',
            'possible_atom_type_3',
        )
    ],
    0,
)

rec_residue_feature_dims = ([len(allowable_features['possible_amino_acids'])], 0)

# Heavy-atom names per one-letter residue code, in the slot order used by get_coords.
# Every entry starts with the backbone atoms N, CA, C, O; the longest entry (W) has 14 atoms.
atom_order = {
    'G': ['N', 'CA', 'C', 'O'],
    'A': ['N', 'CA', 'C', 'O', 'CB'],
    'S': ['N', 'CA', 'C', 'O', 'CB', 'OG'],
    'C': ['N', 'CA', 'C', 'O', 'CB', 'SG'],
    'T': ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2'],
    'P': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD'],
    'V': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2'],
    'M': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE'],
    'N': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2'],
    'I': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1'],
    'L': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2'],
    'D': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2'],
    'E': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2'],
    'K': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ'],
    'Q': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2'],
    'H': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2'],
    'F': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ'],
    'R': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2'],
    'Y': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH'],
    'W': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE2', 'CE3', 'NE1', 'CZ2', 'CZ3', 'CH2'],
    'X': ['N', 'CA', 'C', 'O'],  # unknown amino acid
}

# One-letter -> three-letter residue codes for the 20 standard amino acids.
aa_short2long = {
    'C': 'CYS', 'D': 'ASP', 'S': 'SER', 'Q': 'GLN', 'K': 'LYS', 'I': 'ILE',
    'P': 'PRO', 'T': 'THR', 'F': 'PHE', 'N': 'ASN', 'G': 'GLY', 'H': 'HIS',
    'L': 'LEU', 'R': 'ARG', 'W': 'TRP', 'A': 'ALA', 'V': 'VAL', 'E': 'GLU',
    'Y': 'TYR', 'M': 'MET',
}

# Inverse mapping (three-letter -> one-letter); MSE is additionally mapped to 'M'.
aa_long2short = dict(zip(aa_short2long.values(), aa_short2long.keys()))
aa_long2short['MSE'] = 'M'


def lig_atom_featurizer(mol):
    """
    Encode every atom of a ligand as a row of categorical feature indices.

    Columns follow the same order as ``lig_feature_dims``: atomic number, chirality, total degree,
    formal charge, implicit valence, total number of hydrogens, radical electrons, hybridization,
    aromaticity, number of rings the atom belongs to, and membership in rings of size 3 to 8.

    Parameters:
    mol: Molecule exposing ``GetRingInfo()`` and ``GetAtoms()`` (an RDKit molecule).

    Returns:
    np.ndarray: One row per atom. Every index is shifted by +1 so that 0 stays free as the padding index.
    """
    vocab = allowable_features
    ring_info = mol.GetRingInfo()
    rows = []
    for atom_idx, atom in enumerate(mol.GetAtoms()):
        row = [
            safe_index(vocab['possible_atomic_num_list'], atom.GetAtomicNum()),
            vocab['possible_chirality_list'].index(str(atom.GetChiralTag())),
            safe_index(vocab['possible_degree_list'], atom.GetTotalDegree()),
            safe_index(vocab['possible_formal_charge_list'], atom.GetFormalCharge()),
            safe_index(vocab['possible_implicit_valence_list'], atom.GetValence(Chem.ValenceType.IMPLICIT)),
            safe_index(vocab['possible_numH_list'], atom.GetTotalNumHs()),
            safe_index(vocab['possible_number_radical_e_list'], atom.GetNumRadicalElectrons()),
            safe_index(vocab['possible_hybridization_list'], str(atom.GetHybridization())),
            vocab['possible_is_aromatic_list'].index(atom.GetIsAromatic()),
            safe_index(vocab['possible_numring_list'], ring_info.NumAtomRings(atom_idx)),
        ]
        # One boolean flag per ring size 3..8.
        for ring_size in range(3, 9):
            in_ring = ring_info.IsAtomInRingOfSize(atom_idx, ring_size)
            row.append(vocab[f'possible_is_in_ring{ring_size}_list'].index(in_ring))
        rows.append(row)
    # +1 because 0 is the padding index, needed for nn.Embedding
    return np.array(rows) + 1


def generate_conformer(mol):
    """
    Embed a 3D conformer into ``mol`` in place using ETKDGv3.

    Up to three regular embedding attempts are made. If all of them fail, the embedding is repeated
    with random starting coordinates and conformer 0 is then optimized with MMFF.

    Parameters:
    mol: RDKit molecule to embed; it is modified in place.

    Returns:
    bool: True if the random-coordinates fallback was used, False if a regular attempt succeeded.
    """
    params = AllChem.ETKDGv3()
    for attempt in range(3):
        if attempt > 0:
            print(f'rdkit coords could not be generated. trying again {attempt}.')
        conf_id = AllChem.EmbedMolecule(mol, params)
        if conf_id != -1:
            # -1 is the failure value checked for here; anything else means a conformer was embedded.
            return False

    print('rdkit coords could not be generated without using random coords. using random coords now.')
    params.useRandomCoords = True
    AllChem.EmbedMolecule(mol, params)
    AllChem.MMFFOptimizeMolecule(mol, confId=0)
    return True


def generate_multiple_conformers(mol, num_conformers):
    """
    Embed several 3D conformers into ``mol`` in place using ETKDGv3.

    Up to three regular attempts are made; an attempt counts as successful when at least one
    conformer id different from -1 is returned. If every attempt fails, embedding is repeated with
    random starting coordinates for at most 10 conformers, and every resulting conformer is
    optimized with MMFF.

    Parameters:
    mol: RDKit molecule to embed; it is modified in place.
    num_conformers (int): Number of conformers requested.

    Returns:
    bool: True if the random-coordinates fallback was used, False otherwise.
    """
    params = AllChem.ETKDGv3()
    for attempt in range(3):
        if attempt > 0:
            print(f'rdkit coords could not be generated. trying again {attempt}.')
        embedded = AllChem.EmbedMultipleConfs(mol, num_conformers, params)
        conf_ids = [conf_id for conf_id in embedded if conf_id != -1]
        print(conf_ids)
        if len(conf_ids) != 0:
            return False

    print('rdkit coords could not be generated without using random coords. using random coords now.')
    params.useRandomCoords = True
    # AllChem.EmbedMolecule(mol, ps)
    AllChem.EmbedMultipleConfs(mol, min(num_conformers, 10), params)
    for conf_idx in range(mol.GetNumConformers()):
        AllChem.MMFFOptimizeMolecule(mol, confId=conf_idx)
    return True


def rec_residue_featurizer(rec):
    """
    Encode each residue of a receptor by the index of its residue name.

    Parameters:
    rec: Object exposing ``get_residues()`` whose items provide ``get_resname()``.

    Returns:
    torch.Tensor: float32 tensor with one single-element row per residue, holding the index of the
    residue name in ``allowable_features['possible_amino_acids']`` (last index for unknown names).
    """
    amino_acids = allowable_features['possible_amino_acids']
    rows = [[safe_index(amino_acids, residue.get_resname())] for residue in rec.get_residues()]
    return torch.tensor(rows, dtype=torch.float32)  # (N_res, 1)


def safe_index(l, e):
    """
    Return the index of element ``e`` in list ``l``, or the last index if ``e`` is not present.

    The vocabularies in this module keep a catch-all entry (``'misc'``) in their last slot, so an
    unknown value is mapped onto that slot.

    Parameters:
    l (list): List to search.
    e: Element to look up.

    Returns:
    int: Position of ``e`` in ``l``, or ``len(l) - 1`` when ``l.index(e)`` raises ValueError.
    """
    try:
        return l.index(e)
    except ValueError:
        # Only "not found" is treated as the fallback case; any other error
        # (including KeyboardInterrupt) is no longer silently swallowed.
        return len(l) - 1


def parse_receptor(pdbid, pdbbind_dir, dataset_type):
    """Load the receptor of complex ``pdbid``; thin alias that forwards all arguments to ``parsePDB``."""
    return parsePDB(pdbid, pdbbind_dir, dataset_type)


def parsePDB(pdbid, pdbbind_dir, dataset_type):
    """
    Build the receptor file path for a dataset layout and parse it.

    Parameters:
    pdbid (str): Complex identifier, used as directory and/or file-name stem.
    pdbbind_dir (str): Root directory of the dataset.
    dataset_type (str): Dataset layout name; selects the file-naming scheme.

    Returns:
    The structure returned by ``parse_pdb_from_path`` for the resolved file.

    Raises:
    ValueError: If ``dataset_type`` is not one of the known layouts.
    """
    processed_layouts = ('pdbbind', 'pdbbind_conf', 'dockgen', 'dockgen_full', 'dockgen_full_conf')
    plain_layouts = ('posebusters', 'astex', 'posebusters_conf', 'astex_conf')

    if dataset_type == 'lpce':
        path_parts = (pdbid, f'{pdbid}.pdb')
    elif dataset_type in processed_layouts:
        path_parts = (pdbid, f'{pdbid}_protein_processed.pdb')
    elif dataset_type == 'moad':
        # All ligands of one MOAD entry share a receptor file, so the "_superlig..." suffix is dropped.
        path_parts = ('pdb_protein', f'{pdbid.split("_superlig")[0]}_protein.pdb')
    elif dataset_type in plain_layouts:
        path_parts = (pdbid, f'{pdbid}_protein.pdb')
    else:
        raise ValueError(f'Unknown dataset type: {dataset_type}')

    return parse_pdb_from_path(os.path.join(pdbbind_dir, *path_parts))


def parse_pdb_from_path(path):
    """Parse the PDB file at ``path`` with ProDy and return the resulting structure object."""
    return prody.parsePDB(path)


def get_coords(prody_pdb):
    """
    Collect per-residue atom coordinates in the fixed slot order given by ``atom_order``.

    Only residues that have a C-alpha atom (``prody_pdb.ca``) are considered. Each residue gets
    14 atom slots; slots beyond the residue's atom list, and atoms missing from the structure,
    keep NaN coordinates.

    Parameters:
    prody_pdb: ProDy structure/selection supporting ``.ca`` and ``.select(...)``.

    Returns:
    tuple:
        - coords (np.ndarray): Array of shape (n_residues, 14, 3), NaN where no atom is present.
        - atom_names (np.ndarray): Object array of shape (n_residues, 14). Despite the name it holds
          the *element* of each atom found, 'X' for an expected atom that is missing, and NaN for
          slots the residue type does not use.
    """
    residue_indices = sorted(set(prody_pdb.ca.getResindices()))
    n_residues = len(residue_indices)
    coords = np.full((n_residues, 14, 3), np.nan)
    atom_names = np.full((n_residues, 14), np.nan).astype(object)

    for row, resindex in enumerate(residue_indices):
        residue = prody_pdb.select(f'resindex {resindex}')
        # Residue names without a one-letter code fall back to the backbone-only 'X' template.
        short_code = aa_long2short.get(residue.getResnames()[0], 'X')
        for col, atom_name in enumerate(atom_order[short_code]):
            atom_sel = residue.select(f'name {atom_name}')
            if atom_sel is None:
                coords[row, col, :] = [np.nan, np.nan, np.nan]
                atom_names[row, col] = 'X'
                continue
            coords[row, col, :] = atom_sel.getCoords()[0]
            atom_names[row, col] = atom_sel.getElements()[0]
    return coords, atom_names


def read_mols(pdbbind_dir, name, remove_hs=False):
    """
    Read every ligand stored as an ``.sdf`` file in the directory of one complex.

    Files whose name contains ``rdkit`` are skipped. When an ``.sdf`` file cannot be read and a
    ``.mol2`` file with the same stem exists, that file is tried instead.

    Parameters:
    pdbbind_dir (str): Root directory of the dataset.
    name (str): Name of the complex sub-directory.
    remove_hs (bool): Forwarded to ``read_molecule`` (default: False).

    Returns:
    list: The molecules that could be read, in directory-listing order.
    """
    complex_dir = os.path.join(pdbbind_dir, name)
    molecules = []
    for filename in os.listdir(complex_dir):
        if not filename.endswith(".sdf") or 'rdkit' in filename:
            continue
        mol = read_molecule(os.path.join(complex_dir, filename), remove_hs=remove_hs, sanitize=True)
        if mol is None:
            # read mol2 file if sdf file cannot be sanitized
            mol2_path = os.path.join(complex_dir, filename[:-4] + ".mol2")
            if os.path.exists(mol2_path):
                print('Using the .sdf file failed. We found a .mol2 file instead and are trying to use that.')
                mol = read_molecule(mol2_path, remove_hs=remove_hs, sanitize=True)
        if mol is not None:
            molecules.append(mol)
    return molecules


def read_molecule(molecule_file, sanitize=False, calc_charges=False, remove_hs=False):
    """
    Read a molecular structure from a file and optionally process it.

    This function reads a molecular structure from various file formats and provides options to sanitize the molecule,
    calculate Gasteiger charges, and remove hydrogen atoms.

    Parameters:
    molecule_file (str): Path to the molecular structure file. Supported formats are .mol2, .sdf, .pdbqt and .pdb.
        A path ending in .gninatypes is also accepted; it must have the form
        ``<reference_pdb_path>@<gninatypes_path>``: the reference file is parsed as PDB and its atom
        positions are overwritten with the coordinates stored in the binary .gninatypes file
        (four native floats per atom, of which the first three are used as x, y, z).
    sanitize (bool): If True, sanitize the molecule (default: False).
    calc_charges (bool): If True, calculate Gasteiger charges for the molecule; this also triggers
        sanitization (default: False).
    remove_hs (bool): If True, remove hydrogen atoms from the molecule (default: False).

    Returns:
    RDKit.Chem.Mol or None: The RDKit molecule object if the molecule is successfully read and processed, None otherwise.

    Raises:
    ValueError: If the file format is not supported.

    Notes:
    - Sanitization ensures the molecule's valence states are correct and that the structure is reasonable.
      A sanitization failure is only reported; the (unsanitized) molecule is still returned.
    - Gasteiger charges are partial charges used for computational chemistry methods.
    - Removing hydrogen atoms can be useful for simplifying the molecule, though it may lose information.

    Example:
    >>> from rdkit import Chem
    >>> mol = read_molecule('molecule.mol2', sanitize=True, calc_charges=True, remove_hs=True)
    >>> if mol:
    >>>     print(Chem.MolToSmiles(mol))
    """
    if molecule_file.endswith('.mol2'):
        mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)
    elif molecule_file.endswith('.sdf'):
        supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)
        mol = supplier[0]
    elif molecule_file.endswith('.gninatypes'):
        sdf_base_file, gnina_mol_file = molecule_file.split('@')
        # read reference molecule file
        mol = Chem.MolFromPDBFile(sdf_base_file, sanitize=False, removeHs=False)

        # read correct docked coordinates (context manager so the handle is always closed)
        with open(gnina_mol_file, 'rb') as gnina_file:
            buf = gnina_file.read()
        n = len(buf) // 4  # number of 4-byte floats in the file
        vals = np.array(struct.unpack('f' * n, buf)).reshape(n // 4, 4)

        # replace mol coordinates with correct ones
        if mol is not None:
            conf = mol.GetConformer()
            for i in range(mol.GetNumAtoms()):
                x, y, z, _ = vals[i]
                conf.SetAtomPosition(i, Point3D(x, y, z))
    elif molecule_file.endswith('.pdbqt'):
        with open(molecule_file) as file:
            pdbqt_data = file.readlines()
        # Keep only the first 66 columns of every line so the block parses as plain PDB.
        pdb_block = ''.join('{}\n'.format(line[:66]) for line in pdbqt_data)
        mol = Chem.MolFromPDBBlock(pdb_block, sanitize=False, removeHs=False)
    elif molecule_file.endswith('.pdb'):
        mol = Chem.MolFromPDBFile(molecule_file, sanitize=False, removeHs=False)
    else:
        raise ValueError('Expect the format of the molecule_file to be '
                         'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))

    if mol is None:
        # The parser could not build a molecule; report it once instead of
        # letting every post-processing step fail on None.
        print("RDKit was unable to read the molecule.")
        return None

    try:
        if sanitize or calc_charges:
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                # Best effort: continue with the unsanitized molecule.
                print("RDKit was unable to sanitize the molecule.")

        if calc_charges:
            # Compute Gasteiger charges on the molecule.
            try:
                AllChem.ComputeGasteigerCharges(mol)
            except Exception:
                warnings.warn('Unable to compute charges for the molecule.')

        if remove_hs:
            try:
                mol = Chem.RemoveAllHs(mol, sanitize=sanitize)
            except Exception:
                print("RDKit was unable to remove hydrogen atoms from the molecule.")
                # Retry without sanitization; a second failure is handled by the outer except.
                mol = Chem.RemoveAllHs(mol, sanitize=False)

    except Exception as e:
        print(e)
        print("RDKit was unable to read the molecule.")
        return None

    return mol


def read_sdf_with_multiple_confs(molecule_file, sanitize=False, calc_charges=False, remove_hs=False):
    """
    Read every record of an SDF file as a separate molecule and optionally process each one.

    Unlike ``read_molecule``, a record whose sanitization or hydrogen removal fails is dropped
    from the result instead of being returned unprocessed.

    Parameters:
    molecule_file (str): Path to the .sdf file.
    sanitize (bool): If True, sanitize each molecule (default: False).
    calc_charges (bool): If True, compute Gasteiger charges; this also triggers sanitization (default: False).
    remove_hs (bool): If True, remove hydrogen atoms from each molecule (default: False).

    Returns:
    list: The molecules that were read and processed successfully, in file order.
    """
    supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)
    mols = []
    for mol in supplier:
        if mol is None:
            # The supplier yields None for a record it cannot parse; skip it explicitly.
            print("RDKit was unable to read the molecule.")
            continue
        try:
            if sanitize or calc_charges:
                Chem.SanitizeMol(mol)

            if calc_charges:
                # Compute Gasteiger charges on the molecule.
                try:
                    AllChem.ComputeGasteigerCharges(mol)
                except Exception:
                    warnings.warn('Unable to compute charges for the molecule.')

            if remove_hs:
                mol = Chem.RemoveAllHs(mol, sanitize=sanitize)
        except Exception as e:
            print(e)
            print("RDKit was unable to read the molecule.")
            continue

        mols.append(mol)
    return mols


def extract_receptor_structure_prody(rec, lig, sequences_to_embeddings):
    """
    Select the receptor chains that are close to the ligand and gather their coordinates,
    language-model embeddings and tokenized sequences.

    Per-residue atom coordinates come from ``get_coords(rec)``. Residues are grouped into chains by
    the concatenation of segment name and chain id. A chain is kept when the minimum distance
    between any of its (non-NaN) atoms and any ligand atom is below 4.5; when ``lig`` is None that
    distance defaults to 0, so every chain is kept. After chain selection the ligand must be
    sufficiently buried: at least 30% of the ligand atoms need a kept-chain atom within 5.0.
    (Distances are in the coordinate units of the inputs, presumably Angstrom.)

    Parameters:
    rec: Receptor object exposing the ProDy-style accessors used here (``.ca.getSequence()``,
        ``.ca.getChids()``, ``.ca.getSegnames()``) and accepted by ``get_coords`` - presumably the
        structure returned by ``prody.parsePDB``.
    lig: Ligand exposing ``GetConformer().GetPositions()`` (presumably an RDKit molecule), or None
        to skip all ligand-distance filtering.
    sequences_to_embeddings: Mapping from a chain's one-letter sequence string to a pair
        ``(embeddings, tokenized_seq)``. A kept chain whose sequence is missing raises KeyError.

    Returns:
    tuple: ``(c_alpha_coords, lm_embeddings, sequences, chain_lengths, full_coords, full_atom_names)``
        - c_alpha_coords (np.ndarray): float32 C-alpha coordinates of the kept chains, concatenated.
        - lm_embeddings (np.ndarray): Embeddings of the kept chains, concatenated along axis 0.
        - sequences (np.ndarray): Tokenized sequences of the kept chains, concatenated along axis 0.
        - chain_lengths (list of int): Length of the tokenized sequence of each kept chain.
        - full_coords (np.ndarray): Coordinates of all present atoms of the kept chains.
        - full_atom_names (np.ndarray): Per-atom labels from ``get_coords`` aligned with ``full_coords``.
        All six entries are None when no chain passes the distance filter or the ligand is not
        buried enough.
    """
    if lig is not None:
        conf = lig.GetConformer()
        lig_coords = conf.GetPositions()
    seq = rec.ca.getSequence()
    coords, atom_names = get_coords(rec)

    res_chain_ids = rec.ca.getChids()
    res_seg_ids = rec.ca.getSegnames()
    # A chain is identified by segment name + chain id.
    res_chain_ids = np.asarray([s + c for s, c in zip(res_seg_ids, res_chain_ids)])
    chain_ids = np.unique(res_chain_ids)
    # Per-residue array of one-letter codes so it can be indexed with the chain mask.
    seq = np.array([s for s in seq])

    sequences = []
    lm_embeddings = []
    c_alpha_coords = []
    full_coords = []
    full_atom_names = []
    min_distances_to_lig = []
    chain_distances = {}
    for i, chain_id in enumerate(chain_ids):
        chain_mask = res_chain_ids == chain_id
        chain_seq = ''.join(seq[chain_mask])
        chain_coords = coords[chain_mask]
        
        chain_atom_names = atom_names[chain_mask]

        # Flatten (residue, slot) into a list of atoms and drop the NaN (absent) slots.
        nonempty_coords = chain_coords.reshape(-1, 3)
        notnan_mask = np.isnan(nonempty_coords).sum(axis=1) == 0
        nonempty_coords = nonempty_coords[notnan_mask]
        
        chain_atom_names = chain_atom_names.reshape(-1)
        chain_atom_names = chain_atom_names[notnan_mask]

        # Default of 0 keeps every chain when no ligand is given.
        min_dist_to_lig = 0
        if lig is not None:
            # Pairwise distances, one row per chain atom and one column per ligand atom.
            distances = np.linalg.norm(lig_coords[None] - nonempty_coords[:, None], axis=-1)
            # Per ligand atom: distance to the closest atom of this chain.
            min_dist_arr = distances.min(axis=0)
            min_dist_to_lig = distances.min()
            chain_distances[chain_id] = min_dist_to_lig

        if min_dist_to_lig < 4.5:
        # if min_dist_to_lig < 10:
            embeddings, tokenized_seq = sequences_to_embeddings[chain_seq]
            sequences.append(tokenized_seq)
            lm_embeddings.append(embeddings)
            # Slot 1 is CA in every atom_order template.
            c_alpha_coords.append(chain_coords[:, 1].astype(np.float32))
            full_coords.append(nonempty_coords)
            full_atom_names.append(chain_atom_names)
            if lig is not None:
                min_distances_to_lig.append(min_dist_arr)

    if len(c_alpha_coords) == 0:
        print('NO VALID CHAIN!!!')
        print(chain_distances)
        return None, None, None, None, None, None

    chain_lengths = [len(seq) for seq in sequences]
    c_alpha_coords = np.concatenate(c_alpha_coords, axis=0)  # [n_residues, 3]
    full_coords = np.concatenate(full_coords, axis=0) # [n_protein_atoms, 3]
    full_atom_names = np.concatenate(full_atom_names, axis=0)
    lm_embeddings = np.concatenate(lm_embeddings, axis=0)
    sequences = np.concatenate(sequences, axis=0)

    if lig is not None:
        # Per ligand atom: distance to the closest atom over all kept chains.
        min_distances_to_lig = np.stack(min_distances_to_lig)
        min_distances_to_lig = min_distances_to_lig.min(axis=0)

        distance_cutoff = 5.
        is_buried_threshold = 0.3 # -100
        # Fraction of ligand atoms that have a receptor atom within the cutoff.
        buried_atoms_mask = min_distances_to_lig <= distance_cutoff
        fraction_buried = buried_atoms_mask.mean()

        if fraction_buried < is_buried_threshold:
            print(f'Ligand is not buried (fraction_buried = {fraction_buried})')
            return None, None, None, None, None, None
        
    return c_alpha_coords, lm_embeddings, sequences, chain_lengths, full_coords, full_atom_names


def extract_chain_structure(rec, lm_embedding_chains=None, sequence_chains=None):
    """
    Extract the C-alpha coordinates of all valid residues of a receptor, chain by chain.

    A residue is valid when it is not water ('HOH') and has all three backbone atoms CA, N and C.
    Invalid residues are detached from their chain (``rec`` is modified in place). A chain is valid
    when it keeps at least one valid residue; only valid chains contribute to the outputs.

    Parameters:
    rec: Re-iterable collection of chains (it is iterated twice). Each chain must provide ``get_id()``,
        ``detach_child(res_id)`` and iteration over residues; residues provide ``get_resname()``,
        ``get_id()`` and iteration over atoms with ``name`` and ``get_vector()`` - presumably a
        Bio.PDB model.
    lm_embedding_chains (list of np.ndarray, optional): Language-model embeddings, one entry per chain
        in iteration order of ``rec``.
    sequence_chains (list of np.ndarray, optional): Sequences, one entry per chain in iteration order
        of ``rec``. May be given independently of ``lm_embedding_chains``.

    Returns:
    tuple:
        - c_alpha_coords (np.ndarray): Array of shape (n_residues, 3) with the C-alpha coordinates of
          the valid residues of all valid chains.
        - lm_embeddings (np.ndarray or None): Embeddings of the valid chains concatenated along axis 0,
          or None if ``lm_embedding_chains`` is None.
        - sequence (np.ndarray or None): Sequences of the valid chains concatenated along axis 0,
          or None if ``sequence_chains`` is None.

    Raises:
    ValueError: If a valid chain has no entry in ``lm_embedding_chains``, or if no chain is valid
        (raised by ``np.concatenate`` on an empty list).
    """
    c_alpha_coords = []
    valid_chain_ids = []
    lengths = []
    for chain in rec:
        chain_c_alpha_coords = []
        invalid_res_ids = []
        for residue in chain:
            if residue.get_resname() == 'HOH':
                invalid_res_ids.append(residue.get_id())
                continue
            c_alpha, n, c = None, None, None
            for atom in residue:
                if atom.name == 'CA':
                    c_alpha = list(atom.get_vector())
                if atom.name == 'N':
                    n = list(atom.get_vector())
                if atom.name == 'C':
                    c = list(atom.get_vector())

            if c_alpha is not None and n is not None and c is not None:
                # only keep the residue if it is an amino acid and not some weird molecule that is part of the complex
                chain_c_alpha_coords.append(c_alpha)
            else:
                invalid_res_ids.append(residue.get_id())
        # Detach only after the residue loop so the chain is not mutated while it is iterated.
        for res_id in invalid_res_ids:
            chain.detach_child(res_id)

        lengths.append(len(chain_c_alpha_coords))
        c_alpha_coords.append(np.array(chain_c_alpha_coords))
        if chain_c_alpha_coords:
            valid_chain_ids.append(chain.get_id())

    valid_c_alpha_coords = []
    valid_lengths = []
    valid_lm_embeddings = []
    valid_sequences = []
    for i, chain in enumerate(rec):
        if chain.get_id() not in valid_chain_ids:
            continue
        valid_c_alpha_coords.append(c_alpha_coords[i])
        valid_lengths.append(lengths[i])
        if lm_embedding_chains is not None:
            if i >= len(lm_embedding_chains):
                raise ValueError('Encountered valid chain id that was not present in the LM embeddings')
            valid_lm_embeddings.append(lm_embedding_chains[i])
        # Collected independently of the embeddings so either optional input can be used alone.
        if sequence_chains is not None:
            valid_sequences.append(sequence_chains[i])

    c_alpha_coords = np.concatenate(valid_c_alpha_coords, axis=0)  # [n_residues, 3]
    lm_embeddings = np.concatenate(valid_lm_embeddings, axis=0) if lm_embedding_chains is not None else None
    sequence = np.concatenate(valid_sequences, axis=0) if sequence_chains is not None else None

    assert sum(valid_lengths) == len(c_alpha_coords)
    return c_alpha_coords, lm_embeddings, sequence


def extract_ligand(input_pdb, output_ligand_pdb, ligand_resname="LIG"):
    """
    Write only the ligand residues of a structure file to a new PDB file.

    A residue counts as ligand when its stripped residue name starts with ``ligand_resname``.
    The names of all such residues are first normalized to exactly ``ligand_resname``.

    Parameters:
    input_pdb (str): Input structure; parsed as mmCIF if the path ends in ".cif", otherwise as PDB.
    output_ligand_pdb (str): Path of the PDB file to write.
    ligand_resname (str): Residue-name prefix that identifies the ligand (default: "LIG").
    """
    parser_cls = MMCIFParser if input_pdb.endswith(".cif") else PDBParser
    structure = parser_cls(QUIET=True).get_structure("aligned", input_pdb)

    def is_ligand(residue):
        return residue.get_resname().strip().startswith(ligand_resname)

    # First change the residue names
    for model in structure:
        for chain in model:
            for residue in chain:
                if is_ligand(residue):
                    residue.resname = ligand_resname

    class LigandSelect(Select):
        def accept_residue(self, residue):
            return is_ligand(residue)

    writer = PDBIO()
    writer.set_structure(structure)
    writer.save(output_ligand_pdb, LigandSelect())


def extract_protein(input_pdb, output_protein_pdb, ligand_resname="LIG"):
    """
    Write a structure file without its ligand residues to a new PDB file.

    Every residue whose stripped residue name starts with ``ligand_resname`` is left out;
    all other residues are written unchanged.

    Parameters:
    input_pdb (str): Input structure; parsed as mmCIF if the path ends in ".cif", otherwise as PDB.
    output_protein_pdb (str): Path of the PDB file to write.
    ligand_resname (str): Residue-name prefix that identifies the ligand to drop (default: "LIG",
        the same default as ``extract_ligand``).
    """
    if input_pdb.endswith(".cif"):
        parser = MMCIFParser(QUIET=True)
    else:
        parser = PDBParser(QUIET=True)
    structure = parser.get_structure("aligned", input_pdb)

    class ProteinSelect(Select):
        def accept_residue(self, residue):
            return not residue.get_resname().strip().startswith(ligand_resname)

    io = PDBIO()
    io.set_structure(structure)
    io.save(output_protein_pdb, ProteinSelect())

    
def remove_edges_and_get_smiles(molecule, edges_to_remove):
    """
    Delete the given bonds from a molecule and describe the resulting fragments.

    Args:
    molecule (rdkit.Chem.Mol): The molecule from which bonds will be removed; it is not modified.
    edges_to_remove (sequence of pairs): Each item holds, at positions 0 and 1, the indices of the
                                         two atoms whose bond is removed.

    Returns:
    list of tuple: One ``(smiles, atom_indices)`` pair per connected fragment, where ``smiles`` is the
    fragment's SMILES string and ``atom_indices`` the indices of its atoms in the edited molecule.
    """
    editor = Chem.EditableMol(molecule)

    # Remove specified bonds
    if len(edges_to_remove) > 0:
        for edge in edges_to_remove:
            editor.RemoveBond(int(edge[0]), int(edge[1]))

    # Get the new molecule after all removals
    fragmented = editor.GetMol()

    # Split the molecule into components, once as molecules and once as atom-index tuples
    fragment_mols = Chem.GetMolFrags(fragmented, asMols=True, sanitizeFrags=False)
    fragment_atom_ids = Chem.GetMolFrags(fragmented, asMols=False, sanitizeFrags=False)

    # Pair each component's SMILES with its atom indices
    results_list = []
    for fragment, atom_ids in zip(fragment_mols, fragment_atom_ids):
        results_list.append((Chem.MolToSmiles(fragment), atom_ids))

    return results_list
