#import rdkit.Chem as Chem
import rdkit.Chem.AllChem as Chem


# RDKit bond types in order of increasing bond order: index 0 is SINGLE,
# index 1 is DOUBLE, index 2 is TRIPLE.
# NOTE(review): no use of this list is visible in this part of the file;
# presumably it maps an integer index to a bond type -- confirm against callers.
BOND_LIST = [
    Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
]

#RDLogger.DisableLog("rdApp.*")


def smiles_to_mol(smiles, kekulize=False, only_explicit_H=True):
    """
    Parses a SMILES string into an RDKit mol object and applies optional refinements.

    Args:
        smiles: SMILES string to parse.
        kekulize: if True, run Chem.Kekulize followed by Chem.SanitizeMol on the mol.
        only_explicit_H: if True, strip stereochemistry, rebuild the mol from its
            own SMILES, remove all H atoms and add back the explicit ones only.

    Returns:
        mol object if reading was successful, else None
    """
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is None:
        # Parsing failed: report it and hand None back to the caller.
        print(f'Failed to create molecule from smiles string: {smiles} ')
        return None

    if kekulize:
        # Kekulization replaces aromatic bonds with alternating single/double bonds.
        Chem.Kekulize(parsed)

        # Sanitization: kekulize, check valencies, set aromaticity, conjugation
        # and hybridization.
        # NOTE(review): the original author was unsure whether this call is needed;
        # since it sets aromaticity again it may undo the Kekulize above -- confirm.
        Chem.SanitizeMol(parsed)

    if not only_explicit_H:
        return parsed

    # Stereochemistry is currently not modelled. According to the original author
    # it usually accounts for explicit H atoms, so it is dropped before the
    # hydrogens are normalised.
    Chem.RemoveStereochemistry(parsed)
    round_trip_smiles = Chem.MolToSmiles(parsed)
    rebuilt = Chem.MolFromSmiles(round_trip_smiles)
    without_hs = Chem.RemoveAllHs(rebuilt)
    return Chem.AddHs(without_hs, explicitOnly=True)


def mol_to_smiles(mol):
    """
    Converts an RDKit mol object to a SMILES string, requesting Kekule form
    (kekuleSmiles=True).

    Returns:
        SMILES string produced by Chem.MolToSmiles
    """
    kekule_smiles = Chem.MolToSmiles(mol, kekuleSmiles=True)
    return kekule_smiles


def write_xyz_with_formal_charges(molecule, smiles, filename):
    """
    Writes the molecule to an XYZ-style text file with an extra formal-charge column.

    File layout:
        line 1:  number of atoms
        line 2:  the given smiles string (comment line)
        then one line per atom: "<symbol> <x> <y> <z> <formal charge>",
        with coordinates printed to six decimal places.

    Args:
        molecule: mol object providing GetAtoms(), GetNumAtoms() and GetConformer().
        smiles: text written verbatim on the second line of the file.
        filename: path of the file to create (an existing file is overwritten).

    Raises:
        Whatever molecule.GetConformer() raises when no conformer is available.
        The conformer is fetched before the file is opened, so in that case no
        file is created or truncated.
    """
    # Fetch the conformer once, up front: it is loop-invariant, and failing here
    # avoids leaving a half-written file behind.
    conformer = molecule.GetConformer()

    with open(filename, 'w', encoding='utf-8') as f:
        f.write(f"{molecule.GetNumAtoms()}\n")
        f.write(f"{smiles}\n")

        for atom in molecule.GetAtoms():
            symbol = atom.GetSymbol()
            formal_charge = atom.GetFormalCharge()
            coords = conformer.GetAtomPosition(atom.GetIdx())
            f.write(f"{symbol} {coords.x:.6f} {coords.y:.6f} {coords.z:.6f} {formal_charge}\n")


def extract_atom_symbols(all_smiles):
    """
    Reads a list of SMILES strings (extracted from all.txt) and extracts all atoms present in that dataset

    SMILES strings that cannot be converted into a molecule are skipped.

    Args:
        all_smiles: iterable of SMILES strings.

    Returns: set of atom symbols; it always contains 'H'
    """
    atom_set = {'H'}  # start with H as it is not explicitly present
    for smiles in all_smiles:
        try:
            mol = Chem.MolFromSmiles(smiles)
        except Exception:
            # Best effort: an entry that cannot even be handed to the parser
            # (e.g. wrong type) is skipped rather than aborting the whole scan.
            continue
        if mol is None:
            # The parser signals an invalid SMILES string by returning None.
            continue
        atom_set.update(atom.GetSymbol() for atom in mol.GetAtoms())
    return atom_set


def elem_to_charge_dict(atom_set):
    """
    Maps each element symbol to its atomic number.

    The "charge" in the function name is the value returned by
    Chem.Atom(symbol).GetAtomicNum().

    Args:
        atom_set: iterable of element symbols accepted by Chem.Atom, e.g. {'H', 'C'}.
            It is traversed only once, so one-shot iterables work as well.

    Returns:
        dict mapping symbol -> atomic number
    """
    return {symbol: Chem.Atom(symbol).GetAtomicNum() for symbol in atom_set}
    