import pickle
from tqdm import tqdm
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from posecheck import PoseCheck
from open_biomed.datasets.molecule_protein_dataset import CrossDocked
from open_biomed.utils.config import Config

def get_molecule_force_field(mol, conf_id=None, force_field='mmff', **kwargs):
    """
    Get a force field for a molecule.

    Parameters
    ----------
    mol : RDKit Mol
        Molecule. For MMFF force fields it is sanitized in place with
        ``AllChem.MMFFSanitizeMolecule`` before atom typing.
    conf_id : int, optional
        ID of the conformer to associate with the force field. ``None``
        selects RDKit's default conformer (``confId=-1``).
    force_field : str, optional
        Force field name, case-insensitive: ``'uff'``, or any name starting
        with ``'mmff'``. ``'mmff94s'`` selects the MMFF94s variant; every
        other ``'mmff*'`` name selects MMFF94.
    kwargs : dict, optional
        Keyword arguments for the force field constructor.

    Returns
    -------
    ff : RDKit ForceField or None
        The force field. ``None`` is returned for MMFF when
        ``MMFFGetMoleculeProperties`` cannot type the molecule.

    Raises
    ------
    ValueError
        If ``force_field`` is not a supported name.
    """
    # RDKit's bindings take an int confId where -1 means "default conformer";
    # None cannot be passed through to the C++ signature.
    if conf_id is None:
        conf_id = -1
    name = force_field.lower()
    if name == 'uff':
        return AllChem.UFFGetMoleculeForceField(mol, confId=conf_id, **kwargs)
    if name.startswith('mmff'):
        AllChem.MMFFSanitizeMolecule(mol)
        # RDKit treats only the exact string 'MMFF94s' as the s-variant and
        # silently falls back to MMFF94 otherwise, so normalise the spelling.
        variant = 'MMFF94s' if name == 'mmff94s' else 'MMFF94'
        mmff_props = AllChem.MMFFGetMoleculeProperties(mol, mmffVariant=variant)
        if mmff_props is None:
            # Atom typing failed; callers treat None as "no force field".
            return None
        return AllChem.MMFFGetMoleculeForceField(
            mol, mmff_props, confId=conf_id, **kwargs)
    raise ValueError("Invalid force_field {}".format(force_field))

def get_conformer_energies(mol, force_field='mmff'):
    """
    Calculate conformer energies.

    Computes a single-point energy (``ForceField.CalcEnergy``) for each
    conformer at its current coordinates. No geometry minimization is
    performed.

    Parameters
    ----------
    mol : RDKit Mol
        Molecule whose conformers are evaluated. With an MMFF force field it
        may be modified in place by ``get_molecule_force_field``.
    force_field : str, optional
        Force Field name, forwarded to ``get_molecule_force_field``.

    Returns
    -------
    energies : numpy.ndarray
        1-D float array of energies in conformer iteration order. Conformers
        for which no force field could be built (``None`` or an exception)
        are skipped, so the array can be shorter than the conformer count.
    """
    energies = []
    for conf in mol.GetConformers():
        conf_id = conf.GetId()
        try:
            ff = get_molecule_force_field(mol, conf_id=conf_id, force_field=force_field)
        except Exception as e:
            # Best-effort: report and skip conformers the force field rejects.
            print(f"conformer {conf_id}: {e}")
            continue
        if ff is None:
            continue
        energies.append(ff.CalcEnergy())
    return np.asarray(energies, dtype=float)

if __name__ == "__main__":
    import os

    # Directory holding the sampled poses (preds.pkl) to evaluate.
    result_dir = "../data/sample_results/test/molcraft_Mixed_CG_CFG_weighted_success/20/"
    # result_dir = "../data/sample_results/test/molcraft_reference_weighted_success/0/"

    # Scratch file each ligand is written to before PoseCheck reads it back.
    tmp_sdf = os.path.join("./tmp", "tmp_mol.sdf")
    os.makedirs(os.path.dirname(tmp_sdf), exist_ok=True)

    # NOTE: pickle.load can execute arbitrary code; only load result files
    # produced by a trusted run.
    with open(os.path.join(result_dir, "preds.pkl"), "rb") as f:
        results = pickle.load(f)

    dataset = CrossDocked(
        cfg=Config.from_dict(
            path="./datasets/CrossDocked",
            debug=True,
        ),
        featurizer=None,
    )
    # Only the third split is used; its pockets are indexed with the same `i`
    # as `results` (assumed to be aligned — confirm against the sampler).
    _, _, dataset = dataset.split()

    strain, clash = [], []
    for i, pocket_results in enumerate(tqdm(results)):
        protein_file = dataset.pockets[i].save_pdb()
        for j, mol in enumerate(pocket_results):
            # Skip failed samples and multi-fragment SMILES.
            if mol is None or "." in mol.smiles:
                continue
            mol.save_sdf(tmp_sdf, overwrite=True)
            try:
                pc = PoseCheck(clash_tolerance=0.4)
                pc.load_ligands_from_sdf(tmp_sdf)
                pc.load_protein_from_pdb(protein_file)
                ener = pc.calculate_strain_energy()
                cur_clash = pc.calculate_clashes()
                if len(ener) > 0 and not np.isnan(ener[0]):
                    strain.append(ener[0])
                if len(cur_clash) > 0 and not np.isnan(cur_clash[0]):
                    clash.append(cur_clash[0])
            except Exception as e:
                # Best-effort evaluation: report and move on to the next ligand.
                print(f"pocket {i}, ligand {j}: {e}")
                continue
        # Running statistics; skipped while empty to avoid nan + RuntimeWarning.
        if strain:
            print(f"running median strain: {np.median(strain)}")
        if clash:
            print(f"running mean clash: {np.mean(clash)}")

    print(f"median strain: {np.median(strain) if strain else float('nan')}")
    print(f"mean clash: {np.mean(clash) if clash else float('nan')}")
    with open(os.path.join(result_dir, "strain.pkl"), "wb") as f:
        pickle.dump(strain, f)
    with open(os.path.join(result_dir, "clash.pkl"), "wb") as f:
        pickle.dump(clash, f)