import os
import re
import subprocess
from typing import List, Optional, Tuple

from rdkit import Chem
from rdkit.Chem import AllChem, rdMolAlign


def add_hydrogens_with_fragment_ids(mol: Chem.Mol) -> Chem.Mol:
    """Return a hydrogen-saturated copy of ``mol`` with fragment tags on the hydrogens.

    Each hydrogen that has exactly one neighbor receives that neighbor's
    ``global_frag_id`` and ``frag_order`` integer properties. Hydrogens with
    any other neighbor count are left untagged. The neighbor is expected to
    carry both properties already.
    """
    tag_names = ("global_frag_id", "frag_order")
    hydrated = Chem.AddHs(mol)
    for candidate in hydrated.GetAtoms():
        if candidate.GetAtomicNum() != 1:
            continue  # only hydrogens inherit tags
        bonded = candidate.GetNeighbors()
        if len(bonded) != 1:
            continue
        parent = bonded[0]
        # Read both tags before writing either, so a missing tag on the
        # parent leaves the hydrogen untouched.
        inherited = [parent.GetIntProp(tag) for tag in tag_names]
        for tag, value in zip(tag_names, inherited):
            candidate.SetIntProp(tag, value)
    return hydrated


def generate_conformers(
    mol: Chem.Mol, num_conformers: int = 50, prune_rms_thresh: float = 0.5
) -> List[int]:
    """Embed multiple conformers on ``mol`` using ETKDGv3 with RMS pruning.

    The conformers are added to ``mol`` in place.

    Parameters:
        mol: Molecule to embed (normally with explicit hydrogens).
        num_conformers: Number of conformers requested from the embedder.
        prune_rms_thresh: RMS threshold passed to the embedder as ``pruneRmsThresh``.

    Returns:
        The IDs of the conformers that were embedded (may be empty if
        embedding failed).
    """
    params = AllChem.ETKDGv3()
    params.pruneRmsThresh = prune_rms_thresh
    params.numThreads = 0  # Use all available cores
    params.randomSeed = 0xF00D  # Fixed seed keeps embedding reproducible
    # The per-conformer retry limit on EmbedParameters is `maxIterations`.
    # `maxAttempts` is only a keyword of the legacy EmbedMultipleConfs
    # signature; assigning it on the parameter object just creates an unused
    # Python attribute and leaves the limit at its default.
    params.maxIterations = 1000
    params.useRandomCoords = True
    ids = AllChem.EmbedMultipleConfs(mol, numConfs=num_conformers, params=params)
    return list(ids)


def optimize_conformer(mol: Chem.Mol, conf_id: int, force_field: str = "MMFF") -> float:
    """Minimize one conformer in place and return its force-field energy.

    MMFF is used when requested (case-insensitive) and the molecule has all
    MMFF parameters; in every other case UFF is used.
    """
    mmff_requested = force_field.upper() == "MMFF"
    # The parameter check only runs when MMFF was actually requested.
    if mmff_requested and AllChem.MMFFHasAllMoleculeParams(mol):
        mmff_props = AllChem.MMFFGetMoleculeProperties(mol)
        minimizer = AllChem.MMFFGetMoleculeForceField(mol, mmff_props, confId=conf_id)
    else:
        minimizer = AllChem.UFFGetMoleculeForceField(mol, confId=conf_id)
    minimizer.Minimize()
    return minimizer.CalcEnergy()


def optimize_conformers(
    mol: Chem.Mol, conf_ids: List[int], force_field: str = "MMFF"
) -> List[Tuple[int, float]]:
    """Minimize every listed conformer and pair each ID with its final energy.

    Results are returned in the same order as ``conf_ids``.
    """
    return [(cid, optimize_conformer(mol, cid, force_field)) for cid in conf_ids]


def cluster_conformers(
    mol: Chem.Mol,
    energies: List[Tuple[int, float]],
    energy_cutoff: float = 5.0,
    rmsd_threshold: float = 1.5,
) -> List[int]:
    """Select representative conformers by energy window and RMSD.

    Conformers are ranked by energy. Those more than ``energy_cutoff`` above
    the minimum are dropped, and the remainder is filtered greedily: a
    conformer is kept only if its best-fit RMSD to every already kept
    conformer is at least ``rmsd_threshold``.

    Parameters:
        mol: Molecule holding the conformers referenced in ``energies``.
            ``GetBestRMS`` is called on it, so its conformers may be
            re-aligned in place.
        energies: ``(conformer_id, energy)`` pairs. The list is not modified.
        energy_cutoff: Maximum energy above the lowest-energy conformer.
        rmsd_threshold: Minimum RMSD between two kept conformers.

    Returns:
        IDs of the kept conformers, lowest energy first. Empty when
        ``energies`` is empty.
    """
    if not energies:
        return []

    # Rank on a copy so the caller's list keeps its original order.
    ranked = sorted(energies, key=lambda item: item[1])
    lowest_energy = ranked[0][1]

    candidates = [
        conf_id for conf_id, energy in ranked if energy - lowest_energy <= energy_cutoff
    ]
    if not candidates:
        # Only reachable with a negative (or NaN) cutoff; still keep the best one.
        candidates = [ranked[0][0]]

    kept = [candidates[0]]  # The lowest-energy conformer is always kept
    for conf_id in candidates[1:]:
        is_duplicate = any(
            rdMolAlign.GetBestRMS(mol, mol, prbId=conf_id, refId=kept_id) < rmsd_threshold
            for kept_id in kept
        )
        if not is_duplicate:
            kept.append(conf_id)
    return kept


def write_xyz_with_fragment_ids(
    mol: Chem.Mol, conf_id: int, filename: str, energy: Optional[float] = None
) -> Tuple[List[int], List[int]]:
    """Write one conformer to an XYZ file with fragment-tagged atom labels.

    The comment line holds the SMILES with and without hydrogens plus the
    energy (``None`` when not supplied). Each atom label has the form
    ``<symbol>_<global_frag_id>_<frag_order>``.

    Parameters:
        mol: Molecule whose atoms all carry the ``global_frag_id`` and
            ``frag_order`` integer properties.
        conf_id: ID of the conformer to write.
        filename: Destination path; overwritten if it exists.
        energy: Optional energy recorded in the comment line.

    Returns:
        ``(global_frag_ids, frag_orders)`` in atom-index order.
    """
    # Resolve everything that can fail (conformer lookup, SMILES generation,
    # property access) before touching the file, so an error does not leave a
    # truncated XYZ behind.
    conf = mol.GetConformer(conf_id)
    smiles_Hs = Chem.MolToSmiles(mol)
    smiles_noHs = Chem.MolToSmiles(Chem.RemoveHs(mol))

    global_frag_ids = []
    frag_orders = []
    atom_lines = []
    for i, atom in enumerate(mol.GetAtoms()):
        pos = conf.GetAtomPosition(i)
        global_frag_id = atom.GetIntProp("global_frag_id")
        frag_order = atom.GetIntProp("frag_order")

        global_frag_ids.append(global_frag_id)
        frag_orders.append(frag_order)
        atom_lines.append(
            f"{atom.GetSymbol()}_{global_frag_id}_{frag_order} {pos.x:.8f} {pos.y:.8f} {pos.z:.8f}\n"
        )

    with open(filename, "w") as f:
        f.write(f"{mol.GetNumAtoms()}\n")
        f.write(f"SMILES_Hs: {smiles_Hs}; SMILES_noHs: {smiles_noHs}; Energy: {energy}\n")
        f.writelines(atom_lines)

    return global_frag_ids, frag_orders


def read_xyz_with_fragment_ids(
    xyz_file: str,
    mol_file: str,
    global_frag_ids: Optional[List[int]] = None,
    frag_ids: Optional[List[int]] = None,
) -> Tuple[Chem.Mol, int]:
    """
    Read a molecule's conformer from an XYZ file, preserving connectivity from a MOL file.

    Atoms are matched purely by position: atom ``i`` of the MOL file must be
    line ``i`` of the XYZ coordinate block.

    Parameters:
        xyz_file (str): Path to the XYZ file containing the coordinates.
        mol_file (str): Path to the MOL file containing connectivity.
        global_frag_ids (Optional[List[int]]): Per-atom values stored as the
            ``global_frag_id`` property; skipped when None.
        frag_ids (Optional[List[int]]): Per-atom values stored as the
            ``frag_order`` property; skipped when None.

    Returns:
        Tuple[Chem.Mol, int]: The molecule carrying a single conformer built
        from the XYZ coordinates, and the ID of that conformer.

    Raises:
        ValueError: If the MOL file cannot be read, the XYZ file is empty,
            truncated or malformed, the atom counts or element symbols
            disagree, or a fragment-ID list is shorter than the atom count.
    """
    # Read the molecule with connectivity from the MOL file
    mol = Chem.MolFromMolFile(mol_file, removeHs=False)

    if mol is None:
        raise ValueError("Failed to read the MOL file.")

    # Read the XYZ file
    with open(xyz_file, "r") as f:
        lines = f.readlines()

    if not lines:
        raise ValueError(f"XYZ file {xyz_file} is empty.")

    num_atoms = int(lines[0])
    if num_atoms != mol.GetNumAtoms():
        raise ValueError("Number of atoms in XYZ and MOL files do not match.")

    # Line 0 is the atom count and line 1 the comment; coordinates follow.
    atom_lines = lines[2 : 2 + num_atoms]
    if len(atom_lines) != num_atoms:
        # Without this check the missing atoms would silently stay at the origin.
        raise ValueError(
            f"XYZ file {xyz_file} declares {num_atoms} atoms but contains only "
            f"{len(atom_lines)} coordinate lines."
        )

    for label, values in (("global_frag_ids", global_frag_ids), ("frag_ids", frag_ids)):
        if values is not None and len(values) < num_atoms:
            raise ValueError(
                f"{label} has {len(values)} entries but the molecule has {num_atoms} atoms."
            )

    conf = Chem.Conformer(num_atoms)

    # Parse the XYZ file and update atom positions and fragment IDs
    for i, line in enumerate(atom_lines):
        tokens = line.strip().split()
        if len(tokens) < 4:
            raise ValueError(f"Malformed XYZ atom line at index {i}: {line.rstrip()!r}")

        # The label may carry fragment suffixes (e.g. "C_3_1"); keep only the
        # leading element symbol.
        symbol_match = re.match(r"([A-Za-z]{1,2})", tokens[0])
        if symbol_match is None:
            raise ValueError(f"Cannot parse an element symbol from {tokens[0]!r} at index {i}.")
        symbol = symbol_match.group(1)
        x, y, z = map(float, tokens[1:4])  # Coordinates are tokens 1..3

        # Ensure atom order matches: TODO: CREST and certain optimization algorithms WILL mess up the order, but XTB is fine
        atom = mol.GetAtomWithIdx(i)
        if atom.GetSymbol() != symbol:
            raise ValueError(
                f"Atom mismatch: MOL file has {atom.GetSymbol()} but XYZ file has {symbol} at index {i}."
            )

        # Set both fragment IDs as atom properties
        # Assume the order is the same (XTB atom order matches)
        if global_frag_ids is not None:
            atom.SetIntProp("global_frag_id", global_frag_ids[i])
        if frag_ids is not None:
            atom.SetIntProp("frag_order", frag_ids[i])

        # Set position in the conformer
        conf.SetAtomPosition(i, Chem.rdGeometry.Point3D(x, y, z))

    # Replace whatever coordinates came with the MOL file by the XYZ ones
    mol.RemoveAllConformers()
    conf_id = mol.AddConformer(conf)

    return mol, conf_id


def run_xtb_optimization(
    xyz_filename: str,
    output_dir: str,
    mol_name: str,
    xtb_path: str = "xtb",
    xtb_version: str = "2",
    num_threads: int = 1,
):
    """Run an XTB geometry optimization inside ``output_dir``.

    The executable's stdout and stderr are appended to
    ``<output_dir>/<mol_name>.out``. A non-zero exit status raises
    ``subprocess.CalledProcessError``.

    Returns:
        Absolute path ``<output_dir>/<mol_name>.xtbopt.xyz``. The path is
        built from the names only; its existence is not checked here.
    """
    os.makedirs(output_dir, exist_ok=True)

    # The process runs with cwd=output_dir, so the input path must be absolute.
    input_path = os.path.abspath(xyz_filename)
    log_path = os.path.join(output_dir, f"{mol_name}.out")
    result_path = os.path.join(output_dir, f"{mol_name}.xtbopt.xyz")

    command = [xtb_path, input_path, "--opt"]
    command += ["--namespace", mol_name]
    command += ["--gfn", xtb_version]
    command += ["--parallel", str(num_threads)]

    with open(log_path, "a") as log_handle:
        subprocess.run(
            command,
            check=True,
            cwd=output_dir,
            stdout=log_handle,
            stderr=subprocess.STDOUT,
        )

    return os.path.abspath(result_path)


def process_molecules(
    molecule_list: List[Tuple[Chem.Mol, str]],
    do_xtb: bool = True,
    xtb_path: str = "xtb",
    xtb_version: str = "2",
    out_path: Optional[str] = None,
    num_conformers: int = 50,
    rmsd_threshold: float = 1.5,
    energy_cutoff: float = 10.0,
    num_threads: int = 4,
) -> List[List[Tuple[str, float]]]:
    """Generate, optimize, cluster and save conformers for a list of molecules.

    For each molecule: add hydrogens, embed conformers, minimize them with
    MMFF (UFF fallback), cluster on heavy atoms, optionally re-optimize the
    survivors with XTB and cluster again, then save the final conformers under
    ``<out_path>/final_conformers``.

    Parameters:
        molecule_list: Molecules to process. Each entry may be a bare
            ``Chem.Mol`` or a ``(mol, label)`` pair; only the molecule is used.
            Heavy atoms must carry ``global_frag_id`` and ``frag_order``.
        do_xtb: Re-optimize the selected conformers with XTB when True.
        xtb_path: XTB executable.
        xtb_version: Value passed to XTB's ``--gfn`` option.
        out_path: Directory for all output files; created if missing. Defaults
            to the current working directory when None.
        num_conformers: Number of conformers to embed per molecule.
        rmsd_threshold: RMSD threshold used for embedding pruning and clustering.
        energy_cutoff: Energy window above the minimum used when clustering.
        num_threads: Thread count passed to XTB.

    Returns:
        One list per input molecule (same order) of ``(xyz_path, energy)``
        pairs for the saved conformers. The list is empty for a molecule
        with no conformers.
    """
    if out_path is None:
        # os.path.join(None, ...) would raise TypeError; fall back to the cwd.
        out_path = os.getcwd()
    os.makedirs(out_path, exist_ok=True)

    optimized_molecules = []

    for mol_idx, entry in enumerate(molecule_list):
        print(f"Processing molecule {mol_idx}")
        # The annotation promises (mol, label) pairs; accept bare molecules too.
        labeled_mol = entry[0] if isinstance(entry, (tuple, list)) else entry

        # Add hydrogens and assign fragment IDs to hydrogens
        combined_mol = add_hydrogens_with_fragment_ids(labeled_mol)

        # Generate multiple conformers
        conf_ids = generate_conformers(
            combined_mol, num_conformers=num_conformers, prune_rms_thresh=rmsd_threshold
        )
        if not conf_ids:
            # Nothing to optimize or cluster; keep the result list aligned with the input.
            print(f"No conformers could be generated for molecule {mol_idx}; skipping")
            optimized_molecules.append([])
            continue

        # Optimize each conformer and get energies
        energies = optimize_conformers(combined_mol, conf_ids, force_field="MMFF")
        force_field_energy = dict(energies)

        # Remove hydrogens for RMSD clustering
        mol_no_h = Chem.RemoveHs(combined_mol)

        # Cluster conformers based on RMSD and energy cutoff
        selected_conf_ids = cluster_conformers(
            mol_no_h, energies, energy_cutoff=energy_cutoff, rmsd_threshold=rmsd_threshold
        )

        # Run XTB optimization on selected conformers
        optimized_conformers_per_mol = []
        for conf_id in selected_conf_ids:
            # Write conformer to XYZ file
            mol_name = f"mol_{mol_idx}_conf_{conf_id}"
            xyz_filename = os.path.join(out_path, f"{mol_name}.xyz")
            global_frag_ids, frag_ids = write_xyz_with_fragment_ids(
                combined_mol, conf_id, xyz_filename
            )

            # Keep a MOL-format copy (stored with an .sdf extension); it provides
            # the connectivity when the XTB result is read back.
            mol_filename = os.path.abspath(os.path.join(out_path, f"{mol_name}.sdf"))
            Chem.MolToMolFile(combined_mol, mol_filename, confId=conf_id)

            if do_xtb:
                # Run XTB optimization
                output_dir = os.path.join(out_path, f"xtb_output_mol_{mol_idx}_conf_{conf_id}")
                optimized_xyz = run_xtb_optimization(
                    xyz_filename,
                    output_dir,
                    mol_name=mol_name,
                    xtb_path=xtb_path,
                    xtb_version=xtb_version,
                    num_threads=num_threads,
                )

                # Only conformers for which XTB produced an optimized geometry are kept
                if os.path.isfile(optimized_xyz):
                    # Read optimized molecule
                    optimized_mol, optimized_conf_id = read_xyz_with_fragment_ids(
                        optimized_xyz, mol_filename, global_frag_ids, frag_ids
                    )

                    # Parse XTB energy from the optimized XYZ comment line
                    xtb_energy = parse_xtb_energy(optimized_xyz)
                    optimized_conformers_per_mol.append(
                        (optimized_mol, xtb_energy, optimized_conf_id)
                    )

            else:
                # Single-conformer copy that keeps the original conformer ID
                conf_mol = Chem.Mol(combined_mol)
                conf_mol.RemoveAllConformers()
                conf_mol.AddConformer(combined_mol.GetConformer(conf_id))
                # Add energy from the force-field optimization
                optimized_conformers_per_mol.append(
                    (conf_mol, force_field_energy[conf_id], conf_id)
                )

        print(f"Finished geometry optimization for molecule {mol_idx}")

        # Cluster XTB-optimized conformers
        if do_xtb and len(optimized_conformers_per_mol) > 0:
            optimized_conformers_per_mol = cluster_xtb_optimized_conformers(
                optimized_conformers_per_mol,
                energy_cutoff=energy_cutoff,
                rmsd_threshold=rmsd_threshold,
            )

        # Save conformers into files for later use
        # NOTE: all conformers have Hs
        saved_info = save_conformers(
            optimized_conformers_per_mol, mol_idx, os.path.join(out_path, "final_conformers")
        )

        # Save final conformers: path and energies
        optimized_molecules.append(saved_info)

    return optimized_molecules


# Conversion factor from Hartree to kcal/mol.
_HARTREE_TO_KCAL_PER_MOL = 627.509

# Matches "energy:" (any case) followed by a decimal or scientific-notation number.
_XTB_ENERGY_PATTERN = re.compile(
    r"energy:\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)", re.IGNORECASE
)


def parse_xtb_energy(output_file: str) -> Optional[float]:
    """Parse the total energy from an XTB-written file and convert it to kcal/mol.

    The file is scanned line by line for the first ``energy: <number>`` entry
    (case-insensitive), e.g. the comment line of an ``xtbopt.xyz`` file. The
    number is read from directly after the ``energy:`` tag, so the tag does
    not have to be the first token, and lines without a numeric value are
    skipped.

    Parameters:
        output_file: Path to the file to scan.

    Returns:
        The energy, taken to be in Hartree, multiplied by 627.509, or None
        when no energy entry is found.
    """
    with open(output_file, "r") as f:
        for line in f:
            match = _XTB_ENERGY_PATTERN.search(line)
            if match:
                return float(match.group(1)) * _HARTREE_TO_KCAL_PER_MOL
    return None


def cluster_xtb_optimized_conformers(
    optimized_conformers: List[Tuple[Chem.Mol, float, int]],
    energy_cutoff: float = 5.0,
    rmsd_threshold: float = 1.0,
) -> List[Tuple[Chem.Mol, float, int]]:
    """Pick representative XTB-optimized conformers by energy window and heavy-atom RMSD.

    Entries whose energy is None are dropped first. The rest are ranked by
    energy, limited to those within ``energy_cutoff`` of the minimum, and
    filtered greedily so that no kept conformer lies within
    ``rmsd_threshold`` of an earlier kept one. RMSD is computed on
    hydrogen-stripped copies.

    Returns:
        The surviving ``(mol, energy, conf_id)`` entries, lowest energy
        first, holding the original (hydrogen-bearing) molecules.
    """
    scored = [
        (mol, energy, conf_id)
        for mol, energy, conf_id in optimized_conformers
        if energy is not None
    ]
    if not scored:
        return []

    ranked = sorted(scored, key=lambda entry: entry[1])
    best_energy = ranked[0][1]

    # Keep the energy window; fall back to the best conformer if it is empty.
    in_window = [entry for entry in ranked if entry[1] - best_energy <= energy_cutoff] or [
        ranked[0]
    ]

    # Pair each molecule with a hydrogen-free copy used only for the RMSD test.
    paired = [
        (Chem.RemoveHs(mol), mol, energy, conf_id) for mol, energy, conf_id in in_window
    ]

    representatives = paired[:1]  # The lowest-energy conformer is always kept
    for candidate in paired[1:]:
        probe = candidate[0]
        redundant = any(
            rdMolAlign.GetBestRMS(probe, kept[0]) < rmsd_threshold for kept in representatives
        )
        if not redundant:
            representatives.append(candidate)

    # NOTE: GetBestRMS leaves the probe aligned onto the reference as a side
    # effect, so the stripped copies no longer hold their original coordinates.
    # That is why the untouched full molecules are returned instead.
    return [(mol, energy, conf_id) for _, mol, energy, conf_id in representatives]


def save_conformers(
    conformers: List[Tuple[Chem.Mol, float, int]], mol_idx: int, out_path: str
) -> List[Tuple[str, float]]:
    """Save each conformer as an SDF and an XYZ file with fragment IDs.

    Files are named ``mol_<mol_idx>_final_conf_<n>.sdf`` / ``.xyz``, where
    ``n`` is the position of the conformer in ``conformers``.

    Parameters:
        conformers: ``(mol, energy, conf_id)`` entries. ``conf_id`` selects
            which conformer of ``mol`` is written to both files.
        mol_idx: Index of the parent molecule, used in the file names.
        out_path: Output directory; created if missing.

    Returns:
        ``(xyz_path, energy)`` for every saved conformer, in input order.
    """
    os.makedirs(out_path, exist_ok=True)
    saved_info = []
    for idx, (mol, energy, conf_id) in enumerate(conformers):
        stem = os.path.join(out_path, f"mol_{mol_idx}_final_conf_{idx}")
        sdf_filename = f"{stem}.sdf"
        xyz_filename = f"{stem}.xyz"

        # Properties whose names start with "_" are private in RDKit and are
        # not emitted by SDWriter, so the energy is also stored under a public
        # name. "_Energy" is still set for code that reads it from the object.
        mol.SetProp("_Energy", str(energy))
        mol.SetProp("Energy", str(energy))
        with Chem.SDWriter(sdf_filename) as writer:
            # Write the same conformer that goes into the XYZ file.
            writer.write(mol, confId=conf_id)

        # Write XYZ file with fragment IDs using the original conf_id
        write_xyz_with_fragment_ids(mol, conf_id=conf_id, filename=xyz_filename, energy=energy)

        saved_info.append((xyz_filename, energy))
    return saved_info


def _read_saved_energy(mol: Chem.Mol, xyz_filename: str) -> float:
    """Recover a saved conformer energy from molecule properties or the XYZ comment line."""
    for prop_name in ("_Energy", "Energy"):
        if mol.HasProp(prop_name):
            return float(mol.GetProp(prop_name))

    # Fall back to the companion XYZ file, whose second line ends in "Energy: <value>".
    if os.path.exists(xyz_filename):
        with open(xyz_filename, "r") as f:
            f.readline()  # atom count
            comment = f.readline()
        match = re.search(r"Energy:\s*(\S+)", comment)
        if match:
            try:
                return float(match.group(1))
            except ValueError as err:
                raise ValueError(
                    f"Energy recorded in {xyz_filename} is not a number: {match.group(1)!r}"
                ) from err

    raise ValueError(f"No energy found for conformer with XYZ file {xyz_filename}.")


def load_optimized_conformers(mol_idx: int, path: str) -> List[Tuple[Chem.Mol, float]]:
    """Load the saved final conformers of one molecule.

    Files ``mol_<mol_idx>_final_conf_<n>.sdf`` are read for n = 0, 1, 2, ...
    until the first missing index.

    Parameters:
        mol_idx: Index of the molecule whose conformers are loaded.
        path: Directory containing the saved ``.sdf`` / ``.xyz`` files.

    Returns:
        ``(mol, energy)`` pairs in file-index order, hydrogens retained.

    Raises:
        ValueError: If an SDF file cannot be parsed or no energy can be
            recovered for a conformer.
    """
    conformers = []
    idx = 0
    while True:
        stem = os.path.join(path, f"mol_{mol_idx}_final_conf_{idx}")
        sdf_filename = f"{stem}.sdf"
        if not os.path.exists(sdf_filename):
            break

        # Chem.MolFromMolFile only parses the mol block and discards SD data
        # fields, so the file is read through an SD supplier instead.
        supplier = Chem.SDMolSupplier(sdf_filename, removeHs=False)
        mol = supplier[0] if len(supplier) > 0 else None
        if mol is None:
            raise ValueError(f"Failed to read conformer from {sdf_filename}.")

        energy = _read_saved_energy(mol, f"{stem}.xyz")
        conformers.append((mol, energy))
        idx += 1
    return conformers
