from open_biomed.tasks.aidd_tasks.protein_molecule_docking import VinaDockTask
from open_biomed.datasets.molecule_protein_dataset import CrossDocked
from open_biomed.utils.config import Config
from open_biomed.data import calc_mol_rmsd, Molecule, mol_array_to_conformer
from tqdm import tqdm
import pickle
import os
import copy
from openbabel import pybel

def get_mol_from_pdbqt(pose_file: str) -> Molecule:
    """Load the first molecule of a PDBQT file as a ``Molecule``.

    The molecule is converted to SDF with Open Babel (pybel) and written next
    to the input file, with the same stem and a ``.sdf`` extension. If that
    SDF file already exists, it is reused as-is rather than regenerated.

    Args:
        pose_file: Path to a PDBQT file, e.g. a pose written by a Vina run.

    Returns:
        The molecule parsed from the SDF file by ``Molecule.from_sdf_file``.

    Raises:
        ValueError: If Open Babel reads no molecule from ``pose_file``, so no
            SDF file could be produced.
    """
    # Swap only the trailing extension. str.replace would also rewrite any
    # ".pdbqt" inside a directory name. It would also leave the path unchanged
    # (pointing back at the PDBQT itself) if the suffix were different.
    sdf_file = os.path.splitext(pose_file)[0] + ".sdf"
    if not os.path.exists(sdf_file):
        # Only the first molecule in the file is kept; no need to parse the rest.
        first = next(iter(pybel.readfile("pdbqt", pose_file)), None)
        if first is None:
            raise ValueError(f"no molecule could be read from {pose_file}")
        first.write("sdf", sdf_file, overwrite=True)
    return Molecule.from_sdf_file(sdf_file)
            
if __name__ == "__main__":
    # Re-dock every generated molecule with Vina and measure how far the
    # docked pose lands from the molecule's own pose. Prints the fraction of
    # poses with RMSD < 2 and pickles the list of all computed RMSDs.
    import glob
    from itertools import islice

    def _clear_tmp() -> None:
        """Delete temporary ligand/pocket files left under ./tmp."""
        # Equivalent to `rm ./tmp/mol_* ./tmp/pocket_*` without spawning a
        # shell, and without "No such file" noise when nothing matches.
        for pattern in ("./tmp/mol_*", "./tmp/pocket_*"):
            for path in glob.glob(pattern):
                if os.path.isdir(path):  # plain `rm` skips directories too
                    continue
                try:
                    os.remove(path)
                except FileNotFoundError:
                    pass

    _clear_tmp()
    result_dir = "../data/sample_results/test/molcraft_Mixed_CG_CFG_weighted_success/0/"
    # NOTE: pickle.load can execute arbitrary code; only point this at
    # trusted, locally generated result files.
    with open(os.path.join(result_dir, "preds.pkl"), "rb") as f:
        results = pickle.load(f)
    dataset = CrossDocked(
        cfg=Config.from_dict(
            path="./datasets/CrossDocked",
            debug=True,
        ),
        featurizer=None,
    )
    # Only the third split returned by split() is used. results[i] is assumed
    # to belong to its i-th pocket — TODO confirm against how preds.pkl was made.
    _, _, dataset = dataset.split()

    dock_task = VinaDockTask(mode="dock")
    score_task = VinaDockTask(mode="score")
    n_success, n_total = 0, 0
    rmsds = []
    for i, preds in enumerate(tqdm(results)):
        pocket = dataset.proteins[i]
        # Evaluate at most the first 200 candidates per pocket to bound runtime.
        for j, mol in enumerate(islice(preds, 200)):
            if mol is None:
                continue
            n_total += 1
            _, dock_poses = dock_task.run(mol, pocket, save_pose=True)
            _, score_poses = score_task.run(mol, pocket, save_pose=True)
            docked_file, orig_file = dock_poses[0], score_poses[0]
            # A missing pose file still counts toward n_total (i.e. as a failure).
            if not (os.path.exists(docked_file) and os.path.exists(orig_file)):
                continue
            try:
                # get_mol_from_pdbqt converts each PDBQT pose to SDF via Open Babel.
                orig = get_mol_from_pdbqt(orig_file)
                docked = get_mol_from_pdbqt(docked_file)
                # Give the docked pose the original molecule's RDKit topology
                # with the docked coordinates, so the RMSD compares the same
                # atoms (assumes both PDBQT files list atoms in the same order).
                docked.rdmol = copy.deepcopy(orig.rdmol)
                docked.rdmol.RemoveAllConformers()
                docked.rdmol.AddConformer(mol_array_to_conformer(docked.conformer))
                rmsd = calc_mol_rmsd(orig, docked)
                rmsds.append(rmsd)
                if i <= 2:
                    print(rmsd)
                if rmsd < 2:
                    n_success += 1
            except Exception as e:
                # Best-effort: one malformed pose must not abort the whole run.
                print(e)
                print(f"failure at {i} {j}")
        _clear_tmp()
        if n_total:
            # Running success rate after each pocket.
            print(n_success / n_total)
    print(n_success, n_total)
    print(n_success / n_total if n_total else float("nan"))
    with open(os.path.join(result_dir, "recon_rmsd.pkl"), "wb") as f:
        pickle.dump(rmsds, f)


            
