import pickle
from tqdm import tqdm
from open_biomed.data.molecule import molecule_fingerprint_similarity
from open_biomed.tasks.aidd_tasks.protein_molecule_docking import pbcheck_single, aggregate_pb_results
from open_biomed.datasets.molecule_protein_dataset import CrossDocked
from open_biomed.utils.config import Config
from posebusters import PoseBusters

# Build the CrossDocked dataset rooted at ./datasets/CrossDocked with
# debug=True and no featurizer. This runs at import time because it sits
# outside the ``if __name__ == "__main__"`` guard.
_crossdocked_cfg = Config.from_dict(
    path="./datasets/CrossDocked",
    debug=True,
)
dataset = CrossDocked(cfg=_crossdocked_cfg, featurizer=None)
# split() is unpacked into three parts; only the last one is kept, and it is
# the object the rest of the script reads `.molecules` / `.proteins` from.
_, _, dataset = dataset.split()

if __name__ == "__main__":
    # Merge the pickled predictions of several runs into one list per entry,
    # keep only samples that are present, have no "." in their SMILES, are
    # not too similar to an already-kept sample, and pass the PoseBusters
    # check, then pickle the filtered lists next to the raw predictions.
    result_dir = "./data/sample_results/train/molcraft_Mixed_CG_CFG_weighted_success"
    num_runs = 4  # sub-directories 0..3 of result_dir
    files_per_run = 50  # preds_0.pkl .. preds_49.pkl inside each run
    file_stride = 2000  # item k of preds_j.pkl maps to entry j * 2000 + k
    num_entries = files_per_run * file_stride  # 100000 entry slots in total
    similarity_cutoff = 0.6  # fingerprint similarity above this is a duplicate

    all_data = [[] for _ in range(num_entries)]
    for run in range(num_runs):
        merged = 0  # number of per-entry sample lists merged from this run
        for file_idx in range(files_per_run):
            # NOTE: unpickling can execute arbitrary code; only read
            # prediction files that come from a trusted pipeline.
            with open(f"{result_dir}/{run}/preds_{file_idx}.pkl", "rb") as fin:
                data = pickle.load(fin)
            offset = file_idx * file_stride
            for k, samples in enumerate(data):
                # Samples for the same entry from different runs are pooled.
                all_data[offset + k].extend(samples)
                merged += 1
        print(run, merged)

    buster = PoseBusters()
    num_mols = 0
    for i in tqdm(range(num_entries)):
        selected = []
        ligand = dataset.molecules[i]
        for candidate in all_data[i]:
            # Skip missing samples and SMILES containing "." (the SMILES
            # separator for disconnected components).
            if candidate is None or "." in candidate.smiles:
                continue
            # Greedy diversity filter: reject a candidate that is too similar
            # to any sample already kept for this entry. This runs before the
            # PoseBusters check so that check is not spent on candidates that
            # would be rejected anyway (presumably it is the costlier of the
            # two, and is assumed to have no side effects -- TODO confirm).
            if any(
                molecule_fingerprint_similarity(candidate, kept) > similarity_cutoff
                for kept in selected
            ):
                continue
            pb_results = pbcheck_single([candidate], ligand, dataset.proteins[i], buster)
            pb_results = aggregate_pb_results([pb_results])
            if pb_results["pbvalid"]:
                selected.append(candidate)
        # Replace the raw pool with the filtered samples for this entry.
        all_data[i] = selected
        num_mols += len(selected)
    print(num_mols)

    # Close the output file explicitly so the pickle is fully flushed to disk.
    with open(f"{result_dir}/filtered_preds.pkl", "wb") as fout:
        pickle.dump(all_data, fout)