import copy
import numpy as np
import warnings

import networkx as nx
from networkx.algorithms import isomorphism as iso
from rdkit import Chem
from Bio.PDB import PDBParser, MMCIFParser, PDBIO, Select
from Bio.PDB.PDBExceptions import PDBConstructionWarning

from flowdock.utils.preprocessing import read_molecule


def pdb2sdf(input_pdb_file, output_sdf_file, ref_mol_path):
    """Convert a ligand PDB file to SDF, taking bond orders from a reference.

    The ligand is first passed to ``template_bond_orders``; if that raises or
    returns ``None``, ``restore_atom_order`` is used as a fallback.

    NOTE(review): this module defines ``pdb2sdf`` a second time further down;
    the later definition is the one bound to the name after import. Consider
    removing one of the two.

    Args:
        input_pdb_file: Path to input PDB file
        output_sdf_file: Path to output SDF file
        ref_mol_path: Path to the reference molecule file (read with
            ``read_molecule``) that carries the correct bond orders

    Raises:
        ValueError: If ``read_molecule`` returns ``None`` for either input.
    """
    # Read the molecule
    lig_mol = read_molecule(input_pdb_file)
    if lig_mol is None:
        raise ValueError(f"Could not read ligand from {input_pdb_file}")

    ref_mol = read_molecule(ref_mol_path)
    if ref_mol is None:
        raise ValueError(f"Could not read reference molecule from {ref_mol_path}")

    # Template the bond orders from the reference molecule. Any failure here
    # is deliberately treated as "no template available" so that the
    # graph-based fallback below gets a chance to run.
    try:
        lig_mol_new = template_bond_orders(lig_mol, ref_mol)
    except Exception:
        lig_mol_new = None
    if lig_mol_new is None:
        lig_mol_new = restore_atom_order(ref_mol, lig_mol)

    # Write the molecule; always close the writer so the file handle is
    # released even when writing fails.
    writer = Chem.SDWriter(output_sdf_file)
    try:
        writer.write(lig_mol_new, confId=0)
    finally:
        writer.close()


def extract_ligand(input_pdb, output_ligand_pdb, ligand_resname="LIG"):
    """Write only the ligand residues of a structure file to a PDB file.

    A residue counts as ligand when its stripped residue name starts with
    ``ligand_resname``. Before saving, the parsed structure is normalised:

    - blank chain IDs become ``'A'`` and multi-character IDs are cut to
      their first character;
    - ligand residues are renamed to ``ligand_resname`` and given the id
      ``('H_', k, ' ')`` with ``k`` counting from 1 within each chain;
    - ligand atom names lose underscores and are cut to 4 characters, and a
      missing element is derived from the letters of the atom name.

    Args:
        input_pdb: Input structure path; parsed as mmCIF when it ends with
            ``.cif``, otherwise as PDB.
        output_ligand_pdb: Destination PDB path.
        ligand_resname: Residue-name prefix identifying ligand residues.
    """
    if input_pdb.endswith(".cif"):
        parser = MMCIFParser(QUIET=True)
    else:
        parser = PDBParser(QUIET=True)
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', PDBConstructionWarning)
        structure = parser.get_structure("aligned", input_pdb)

    def is_ligand(residue):
        return residue.get_resname().strip().startswith(ligand_resname)

    def guess_element(atom_name):
        # First two alphabetic characters of the name, or its first
        # character when the name holds no letters at all.
        alpha = ''.join(ch for ch in atom_name if ch.isalpha())
        return (alpha[:2] or atom_name[:1]).upper()

    class LigandSelect(Select):
        def accept_residue(self, residue):
            return is_ligand(residue)

    for model in structure:
        for chain in model:
            chain_id = chain.id
            if not chain_id.strip():
                chain.id = 'A'
            elif len(chain_id) > 1:
                chain.id = chain_id[0]

            ligand_count = 0
            for residue in chain:
                if not is_ligand(residue):
                    continue
                ligand_count += 1
                residue.resname = ligand_resname
                residue.id = ('H_', ligand_count, ' ')
                for atom in residue:
                    short_name = atom.get_name().replace('_', '')[:4]
                    atom.name = short_name
                    atom.fullname = f"{short_name:>4}"  # 4-char, right-justified
                    element = (atom.element or '').strip() or guess_element(short_name)
                    atom.element = element[:2].upper()

    pdb_writer = PDBIO()
    pdb_writer.set_structure(structure)
    pdb_writer.save(output_ligand_pdb, LigandSelect())


def template_bond_orders(mol, ref_mol):
    """Copy bond orders from a reference molecule onto a target molecule.

    Both molecules are sanitized in place (kekulization skipped). A
    substructure match is looked up first with ``ref_mol`` as the query
    against ``mol`` and, if that finds nothing, with ``mol`` as the query
    against ``ref_mol``. The first match is used to map reference atoms to
    target atoms, and every reference bond whose two atoms are mapped has its
    bond type copied onto the corresponding target bond.

    Args:
        mol: Target RDKit molecule; modified in place.
        ref_mol: Reference RDKit molecule carrying the desired bond orders;
            sanitized in place.

    Returns:
        ``mol`` after aromaticity/stereo perception and a best-effort final
        sanitization, or ``None`` when neither match direction succeeds.
    """
    # First sanitize the molecules (but skip kekulization)
    sanitize_ops = Chem.SANITIZE_ALL ^ Chem.SANITIZE_KEKULIZE
    for m in (mol, ref_mol):
        Chem.SanitizeMol(m, sanitizeOps=sanitize_ops)

    # Build a {ref atom index -> mol atom index} mapping. The tuple returned
    # by GetSubstructMatches is indexed by *query* atom, so its direction
    # depends on which molecule was used as the query.
    ref_to_mol = None
    matches = mol.GetSubstructMatches(ref_mol, uniquify=False)
    if matches:
        match = matches[0]  # match[ref_idx] -> mol_idx
        # Only template when the reference covers every atom of the target.
        # NOTE(review): a partial match skips the copy but still returns
        # `mol` below (unchanged from the previous behavior) -- confirm that
        # callers do not expect None in that case.
        if len(match) == mol.GetNumAtoms():
            ref_to_mol = dict(enumerate(match))
    else:
        matches = ref_mol.GetSubstructMatches(mol, uniquify=False)
        if not matches:
            return None
        # Here the query was `mol`, so match[mol_idx] -> ref_idx; invert it.
        ref_to_mol = {ref_idx: mol_idx for mol_idx, ref_idx in enumerate(matches[0])}

    if ref_to_mol is not None:
        # Copy bond orders from reference
        for bond in ref_mol.GetBonds():
            begin = ref_to_mol.get(bond.GetBeginAtomIdx())
            end = ref_to_mol.get(bond.GetEndAtomIdx())
            if begin is None or end is None:
                # Reference atom has no counterpart in the target.
                continue
            mol_bond = mol.GetBondBetweenAtoms(begin, end)
            if mol_bond is not None:
                mol_bond.SetBondType(bond.GetBondType())

    # Final cleanup and aromaticity perception
    Chem.SetAromaticity(mol)
    Chem.AssignStereochemistry(mol, cleanIt=True, force=True)
    try:
        Chem.SanitizeMol(mol)
    except Exception:
        # Best effort: keep the templated molecule even if it does not fully
        # sanitize, but do not swallow KeyboardInterrupt/SystemExit.
        print("Warning: Final sanitization failed, but bond orders were assigned")
    return mol


def restore_atom_order(ref_mol, pred_mol):
    """Return a copy of ``ref_mol`` carrying the coordinates of ``pred_mol``.

    The two molecules are converted to graphs with ``mol_to_nx`` and tested
    for isomorphism, comparing nodes by atomic number only. When they are
    isomorphic, ``pred_mol`` is renumbered to follow the reference atom order
    and its conformer positions are written onto a deep copy of ``ref_mol``.
    Otherwise ``pred_mol`` is returned unchanged.
    """
    def same_element(ref_attrs, pred_attrs):
        return ref_attrs["atomic_num"] == pred_attrs["atomic_num"]

    matcher = iso.GraphMatcher(
        mol_to_nx(ref_mol),
        mol_to_nx(pred_mol),
        node_match=same_element,
    )
    if not matcher.is_isomorphic():
        return pred_mol

    # Mapping from reference node index to predicted node index.
    ref_to_pred = matcher.mapping
    new_order = [ref_to_pred[ref_idx] for ref_idx in range(ref_mol.GetNumAtoms())]
    reordered_pred = Chem.rdmolops.RenumberAtoms(pred_mol, new_order)

    posed_ref = copy.deepcopy(ref_mol)
    predicted_positions = reordered_pred.GetConformer().GetPositions()
    posed_ref.GetConformer().SetPositions(predicted_positions)
    return posed_ref


def mol_to_nx(mol):
    """Build an undirected networkx graph from a molecule.

    Each atom becomes a node keyed by its index, with the attributes
    ``atomic_num`` (atomic number) and ``atom`` (the atom object). Each bond
    becomes an edge between its begin and end atom indices.
    """
    graph = nx.Graph()
    graph.add_nodes_from(
        (atom.GetIdx(), {"atomic_num": atom.GetAtomicNum(), "atom": atom})
        for atom in mol.GetAtoms()
    )
    graph.add_edges_from(
        (bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())
        for bond in mol.GetBonds()
    )
    return graph
    

def extract_protein(input_pdb, output_protein_pdb):
    """Write every residue whose name does not start with ``LIG`` to a PDB file.

    Args:
        input_pdb: Input structure path; parsed as mmCIF when it ends with
            ``.cif``, otherwise as PDB.
        output_protein_pdb: Destination PDB path.
    """
    parser_cls = MMCIFParser if input_pdb.endswith(".cif") else PDBParser
    structure = parser_cls(QUIET=True).get_structure("aligned", input_pdb)

    class ProteinSelect(Select):
        def accept_residue(self, residue):
            resname = residue.get_resname().strip()
            return not resname.startswith("LIG")

    pdb_writer = PDBIO()
    pdb_writer.set_structure(structure)
    pdb_writer.save(output_protein_pdb, ProteinSelect())


def filter_protein_chains_by_ligand_distance(reference_protein_pdb_path, reference_ligand_path,
                                             output_protein_pdb_path, distance_cutoff=10, return_all=False):
    """Save the protein chain(s) closest to a ligand and return their IDs.

    For every chain, the number of non-hydrogen atoms lying closer than
    ``distance_cutoff`` to any ligand atom is counted.

    Args:
        reference_protein_pdb_path: Protein structure path; parsed as PDB when
            it ends with ``.pdb``, otherwise as mmCIF.
        reference_ligand_path: Ligand file path, read with ``read_molecule``.
        output_protein_pdb_path: Destination PDB path. With
            ``return_all=True`` one file per kept chain is written and
            ``_<chain_id>`` is inserted before the file extension.
        distance_cutoff: Distance threshold, in the coordinate units of the
            input files.
        return_all: If True, keep every chain with at least one atom within
            the cutoff. If False, keep only the chain with the highest count
            (even when that count is zero).

    Returns:
        List of kept chain IDs; empty when nothing qualified.

    Raises:
        ValueError: If the ligand cannot be read.
    """
    import os  # function-scope import: only needed for the per-chain output names

    def _is_hydrogen(atom):
        # Decide by element so that all hydrogens (HA, HB2, ...) are skipped,
        # not only atoms literally named 'H'; fall back to the name check
        # when no element is available.
        element = (getattr(atom, "element", None) or "").strip().upper()
        if element:
            return element == "H"
        return atom.get_name() == 'H'

    # Read the reference ligand
    ligand_mol = read_molecule(reference_ligand_path)
    if ligand_mol is None:
        raise ValueError(f"Could not read ligand from {reference_ligand_path}")

    # Get ligand coordinates
    ligand_coords = ligand_mol.GetConformer().GetPositions()

    # Read the reference protein structure
    if reference_protein_pdb_path.endswith('.pdb'):
        parser = PDBParser(QUIET=True)
    else:
        parser = MMCIFParser(QUIET=True)
    protein_structure = parser.get_structure('protein', reference_protein_pdb_path)

    # Calculate distances between ligand and protein chains
    chain_distances = {}
    chain_atom_counts = {}

    for chain in protein_structure.get_chains():
        chain_id = chain.get_id()

        # Collect heavy-atom coordinates from this chain
        chain_coords = [
            atom.get_coord()
            for residue in chain
            for atom in residue
            if not _is_hydrogen(atom)
        ]

        if len(chain_coords) == 0:
            print(f"Chain {chain_id}: No atoms found")
            continue

        chain_coords = np.array(chain_coords)

        # Pairwise (ligand atom x chain atom) distances, reduced to the
        # closest ligand atom for each chain atom.
        distances = np.linalg.norm(ligand_coords[:, None] - chain_coords[None, :], axis=-1)
        nearest_ligand_distance = distances.min(axis=0)
        chain_distances[chain_id] = nearest_ligand_distance.min()
        chain_atom_counts[chain_id] = int(np.sum(nearest_ligand_distance < distance_cutoff))

    if return_all:
        kept_chains = [chain_id for chain_id, count in chain_atom_counts.items() if count > 0]
    elif chain_atom_counts:
        kept_chains = [max(chain_atom_counts, key=chain_atom_counts.get)]
    else:
        # No chain had any usable atoms; max() on an empty mapping would raise.
        kept_chains = []

    # Create a filtered structure with only the kept chains
    class ChainSelector(Select):
        def __init__(self, kept_chains):
            self.kept_chains = kept_chains

        def accept_chain(self, chain):
            return chain.get_id() in self.kept_chains

    # Split only the real extension so a '.pdb' elsewhere in the path is left
    # untouched and extension-less paths still get distinct per-chain names.
    output_root, output_ext = os.path.splitext(output_protein_pdb_path)

    io = PDBIO()
    io.set_structure(protein_structure)
    for chain_id in kept_chains:
        if return_all:
            chain_output_protein_pdb_path = f"{output_root}_{chain_id}{output_ext}"
        else:
            chain_output_protein_pdb_path = output_protein_pdb_path
        # Save the filtered protein structure
        io.save(chain_output_protein_pdb_path, ChainSelector([chain_id]))

    if len(kept_chains) == 0:
        print("WARNING: No chains were kept! Consider increasing the distance cutoff.")
        print("Chain distances to ligand:")
        for chain_id, distance in sorted(chain_distances.items()):
            print(f"  {chain_id}: {distance:.2f} Å")

    return kept_chains


def cif2pdb(input_cif, output_pdb, *, protein_only=False):
    """
    Convert an mmCIF file to PDB format.

    Steps applied to the parsed structure before writing:
    - chain IDs that are empty or a single space become 'A'; IDs longer than
      one character are cut to their first character
    - residues are renumbered 1..N within each chain (het flag and insertion
      code are kept, with missing values replaced by a single space)
    - atom serial numbers are reset to 1..M across the whole structure
    - with protein_only=True, only residues whose het flag is a single space
      are written
    """
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', PDBConstructionWarning)
        cif_parser = MMCIFParser(QUIET=True)
        structure = cif_parser.get_structure("cif_in", input_cif)

    # Pass 1: normalize chain IDs
    for model in structure:
        for chain in model:
            chain_id = chain.id
            if not chain_id or chain_id == " ":
                chain.id = "A"
            elif len(chain_id) > 1:
                chain.id = chain_id[0]

    # Pass 2: renumber residues per chain
    for model in structure:
        for chain in model:
            for position, residue in enumerate(chain.get_residues(), start=1):
                het_flag, _, insertion_code = residue.id
                if het_flag is None:
                    het_flag = " "
                if not insertion_code:
                    insertion_code = " "
                residue.id = (het_flag, position, insertion_code)

    # Pass 3: reset atom serial numbers
    for serial, atom in enumerate(structure.get_atoms(), start=1):
        atom.serial_number = serial

    class ProteinOnly(Select):
        def accept_residue(self, residue):
            return residue.id[0] == " "

    selector = ProteinOnly() if protein_only else Select()
    pdb_writer = PDBIO()
    pdb_writer.set_structure(structure)
    pdb_writer.save(output_pdb, selector)


def align_to_binding_site(
    predicted_protein: str,
    predicted_ligand: str,
    reference_protein: str,
    reference_ligand: str,
    aligned_ligand_path: str,
    aligned_protein_path: str,
    cutoff: float = 10.0,
):
    """Align a predicted protein-ligand complex onto the reference binding site.

    The predicted protein (plus the predicted ligand, when given) is grouped
    into one PyMOL object and aligned with ``cmd.align`` onto the reference
    binding-site selection. The resulting transformation is then copied back
    onto the individual predicted objects, which are saved.

    This function resets the global PyMOL session (``delete all`` and
    ``reinitialize``) on entry and deletes all objects again on exit.

    :param predicted_protein: File path to the predicted protein.
    :param predicted_ligand: File path to the predicted ligand, or ``None``
        to align and save the protein only.
    :param reference_protein: File path to the reference protein.
    :param reference_ligand: File path to the reference ligand; it defines
        the binding site.
    :param aligned_ligand_path: Output path for the aligned predicted ligand
        (not written when ``predicted_ligand`` is ``None``).
    :param aligned_protein_path: Output path for the aligned predicted
        protein.
    :param cutoff: Distance cutoff in Å to define the binding site
        (default 10.0).
    :return: First element of the value returned by ``cmd.align`` (the RMSD
        after refinement, per PyMOL's documentation).
    """
    from pymol import cmd

    has_pred_ligand = predicted_ligand is not None

    # Initialize PyMOL
    cmd.delete("all")
    cmd.reinitialize()

    # Load structures
    cmd.load(reference_protein, "ref_protein")
    cmd.load(predicted_protein, "pred_protein")

    cmd.load(reference_ligand, "ref_ligand")
    if has_pred_ligand:
        cmd.load(predicted_ligand, "pred_ligand")

    # Group predicted protein and ligand(s) together for alignment
    cmd.create(
        "pred_complex",
        "pred_protein or pred_ligand" if has_pred_ligand else "pred_protein",
    )

    # Select heavy atoms in the reference protein
    cmd.select("ref_protein_heavy", "ref_protein and not elem H")

    # Select heavy atoms in the reference ligand(s)
    cmd.select("ref_ligand_heavy", "ref_ligand and not elem H")

    # Define the reference binding site(s) based on the reference ligand(s):
    # intended as backbone heavy atoms of the reference protein within
    # `cutoff` of the reference ligand's heavy atoms.
    cmd.select("binding_site", f"(backbone) and ref_protein_heavy within {cutoff} of ref_ligand_heavy")

    # Align the predicted complex to the reference binding site(s)
    alignment_result = cmd.align("pred_complex", "binding_site")

    # Apply the transformation to the individual objects
    cmd.matrix_copy("pred_complex", "pred_protein")
    if has_pred_ligand:
        cmd.matrix_copy("pred_complex", "pred_ligand")

    # Save the aligned ligand (if any) and protein
    if has_pred_ligand:
        cmd.save(aligned_ligand_path, "pred_ligand")
    cmd.save(aligned_protein_path, "pred_protein")

    # Clean up
    cmd.delete("all")

    return alignment_result[0]


def align_to_binding_site_by_pocket(
    predicted_protein: str,
    predicted_ligand: str,
    reference_protein: str,
    reference_ligand: str,
    aligned_ligand_path: str,
    aligned_protein_path: str,
    cutoff: float = 10.0,
):
    """Align a predicted complex to the reference using pocket-to-pocket alignment.

    A binding-site selection is built for each complex from its own ligand
    (reference site from the reference ligand, predicted site from the
    predicted ligand), and the predicted site is aligned onto the reference
    site with ``cmd.align``. The object matrix of the predicted protein is
    then applied to the predicted ligand, and both are saved.

    This function resets the global PyMOL session (``delete all`` and
    ``reinitialize``) on entry and deletes all objects again on exit.

    :param predicted_protein: File path to the predicted protein.
    :param predicted_ligand: File path to the predicted ligand. It is loaded
        unconditionally, so it must be a valid path.
    :param reference_protein: File path to the reference protein.
    :param reference_ligand: File path to the reference ligand.
    :param aligned_ligand_path: Output path for the aligned predicted ligand.
    :param aligned_protein_path: Output path for the aligned predicted
        protein.
    :param cutoff: Distance cutoff in Å to define the binding site
        (default 10.0).
    :return: First element of the value returned by ``cmd.align`` (the RMSD
        after refinement, per PyMOL's documentation).
    """
    from pymol import cmd

    # Initialize PyMOL (this wipes any objects already in the session)
    cmd.delete("all")
    cmd.reinitialize()

    # Load structures
    cmd.load(reference_protein, "ref_protein")
    cmd.load(predicted_protein, "pred_protein")
    cmd.load(reference_ligand, "ref_ligand")
    cmd.load(predicted_ligand, "pred_ligand")

    # Select heavy atoms in the reference and predicted proteins
    cmd.select("ref_protein_heavy", "ref_protein and not elem H")
    cmd.select("pred_protein_heavy", "pred_protein and not elem H")

    # Select heavy atoms in the reference and predicted ligand(s)
    cmd.select("ref_ligand_heavy", "ref_ligand and not elem H")
    cmd.select("pred_ligand_heavy", "pred_ligand and not elem H")

    # Define the reference binding site(s) based on the reference ligand(s)
    cmd.select("ref_binding_site", f"(backbone) and ref_protein_heavy within {cutoff} of ref_ligand_heavy")

    # Define the predicted binding site(s) based on the predicted ligand(s)
    cmd.select("pred_binding_site", f"(backbone) and pred_protein_heavy within {cutoff} of pred_ligand_heavy")

    # Alternative selection syntax kept for reference (disabled):
    # cmd.select("ref_binding_site", f"ref_protein and backbone and not elem H within {cutoff} of ref_ligand")
    # cmd.select("pred_binding_site", f"pred_protein and backbone and not elem H within {cutoff} of pred_ligand")

    # Align the predicted binding site onto the reference binding site.
    # cycles=0 is passed through to cmd.align (in PyMOL this is the number of
    # outlier-rejection refinement cycles).
    alignment_result = cmd.align("pred_binding_site", "ref_binding_site", cycles=0)

    # Get the transformation matrix from the aligned protein
    transformation_matrix = cmd.get_object_matrix("pred_protein")

    # Apply the same transformation to the predicted ligand.
    # NOTE(review): transform_selection is called with its default matrix
    # convention -- confirm it matches what get_object_matrix returns.
    cmd.transform_selection("pred_ligand", transformation_matrix)

    # Save the aligned ligand and protein
    cmd.save(aligned_ligand_path, "pred_ligand")
    cmd.save(aligned_protein_path, "pred_protein")

    # Clean up
    cmd.delete("all")

    return alignment_result[0]


def rmsd_func(pos, true_pos, mol_true, mol_pred):
    """Return the plain RMSD between two coordinate sets.

    The squared differences are summed over all coordinates, divided by the
    number of rows in ``pos`` and square-rooted. No symmetry correction is
    applied; ``mol_true`` and ``mol_pred`` are accepted but not used.
    """
    num_points = pos.shape[0]
    squared_error = (pos - true_pos) ** 2
    total_squared_error = squared_error.sum(axis=1).sum(axis=0)
    return np.sqrt(total_squared_error / num_points)


def pdb2sdf(input_pdb_file, output_sdf_file, ref_mol_path):
    """Convert a ligand PDB file to SDF, taking bond orders from a reference.

    The ligand is first passed to ``template_bond_orders``; if that raises or
    returns ``None``, ``restore_atom_order`` is used as a fallback.

    NOTE(review): an earlier definition of ``pdb2sdf`` exists near the top of
    this module; this later one overrides it. Consider removing one of the
    two.

    Args:
        input_pdb_file: Path to input PDB file
        output_sdf_file: Path to output SDF file
        ref_mol_path: Path to the reference molecule file (read with
            ``read_molecule``) that carries the correct bond orders

    Raises:
        ValueError: If ``read_molecule`` returns ``None`` for either input.
    """
    # Read the molecule
    lig_mol = read_molecule(input_pdb_file)
    if lig_mol is None:
        raise ValueError(f"Could not read ligand from {input_pdb_file}")

    ref_mol = read_molecule(ref_mol_path)
    if ref_mol is None:
        raise ValueError(f"Could not read reference molecule from {ref_mol_path}")

    # Template the bond orders from the reference molecule. Any failure here
    # is deliberately treated as "no template available" so that the
    # graph-based fallback below gets a chance to run.
    try:
        lig_mol_new = template_bond_orders(lig_mol, ref_mol)
    except Exception:
        lig_mol_new = None
    if lig_mol_new is None:
        lig_mol_new = restore_atom_order(ref_mol, lig_mol)

    # Write the molecule; always close the writer so the file handle is
    # released even when writing fails.
    writer = Chem.SDWriter(output_sdf_file)
    try:
        writer.write(lig_mol_new, confId=0)
    finally:
        writer.close()
