from open_biomed.datasets.molecule_protein_dataset import CrossDocked
from open_biomed.utils.config import Config
from open_biomed.data import Molecule
from open_biomed.utils.config import Config
from open_biomed.models.molecule.molcraft import MolCRAFTMoleculeFeaturizer
from open_biomed.tasks.aidd_tasks.structure_based_drug_design import calc_vina_molecule_metrics
import torch
import pickle
import os
import numpy as np

# Each run directory holds a pickled `trajs.pkl` written by the sampler.
# Reference model.
path1 = "../data/sample_results/test/molcraft_ClassifierFreeGuidance_weighted_success/101"
# Guided model, guidance weight w=1.
path2 = "../data/sample_results/test/molcraft_ClassifierFreeGuidance_weighted_success/100"
# Guided model, guidance weight w=1.5 (disabled; uncomment together with `traj3` below).
# path3 = "../data/sample_results/test/molcraft_ClassifierFreeGuidance_weighted_success/1"


def _load_trajs(run_dir):
    """Unpickle and return the contents of ``<run_dir>/trajs.pkl``.

    The file is opened in a ``with`` block so the handle is closed as soon
    as loading finishes instead of lingering until garbage collection.

    NOTE: ``pickle.load`` can execute arbitrary code; only point this at
    sampler outputs you produced yourself.
    """
    with open(os.path.join(run_dir, "trajs.pkl"), "rb") as f:
        return pickle.load(f)


traj1 = _load_trajs(path1)
traj2 = _load_trajs(path2)
# traj3 = _load_trajs(path3)

featurizer = MolCRAFTMoleculeFeaturizer(pos_norm=2)


def decode_molecule(mol, pocket_center, num_samples=10, sample_idx=4):
    """Decode one molecule out of a stacked sampler output.

    ``mol["mu_pos"]`` and ``mol["theta_h"]`` are treated as ``num_samples``
    equally sized molecules concatenated along dim 0 (assumption inferred
    from the slicing below -- TODO confirm against the sampler). The
    ``sample_idx``-th slice is extracted and decoded with the module-level
    ``featurizer``.

    Args:
        mol: dict with tensor entries ``"mu_pos"`` (atom positions) and
            ``"theta_h"`` (presumably per-atom type scores; the argmax over
            the last dim is taken as the atom type).
        pocket_center: forwarded unchanged to ``featurizer.decode``.
        num_samples: number of molecules stacked in ``mol`` (default 10,
            the value previously hard-coded).
        sample_idx: which stacked molecule to decode (default 4, the value
            previously hard-coded).

    Returns:
        Whatever ``featurizer.decode`` returns. If it returns ``None``
        (decoding failed), the raw ``{"pos", "atom_type"}`` dict is returned
        instead, so callers can detect failure via an ``isinstance`` check.

    Raises:
        ValueError: if ``sample_idx`` is not in ``[0, num_samples)``.
    """
    if not 0 <= sample_idx < num_samples:
        raise ValueError(
            f"sample_idx must be in [0, {num_samples}), got {sample_idx}"
        )
    # Floor division: any trailing rows that do not fill a whole sample are ignored.
    num_atoms = mol["mu_pos"].shape[0] // num_samples
    start, stop = sample_idx * num_atoms, (sample_idx + 1) * num_atoms
    cur_mol = {
        "pos": mol["mu_pos"][start:stop].cpu(),
        "atom_type": mol["theta_h"][start:stop].cpu().argmax(dim=-1),
    }
    decoded = featurizer.decode(cur_mol, pocket_center)
    if decoded is None:
        return cur_mol
    return decoded

# Build the CrossDocked dataset (debug=True, no featurizer) and keep only the
# third split returned by `split()`; the analysis below indexes complexes from it.
crossdocked_cfg = Config.from_dict(path="./datasets/CrossDocked", debug=True)
dataset = CrossDocked(cfg=crossdocked_cfg, featurizer=None)
_, _, dataset = dataset.split()

# For each run, decode 10 samples for complex `i` of the split and collect the
# Vina scores. `metrics[k]` holds the scores of run k, in the order below.
# Samples that fail to decode, or have no "vina_score", are skipped, so
# per-run lists may differ in length.
i = 0  # index of the complex (pocket/protein) being analysed
# The pocket center depends only on `i`; compute it once, not per sample.
pocket_center = np.mean(dataset.pockets[i].conformer, axis=0)
metrics = []
for traj in (traj1, traj2):  # runs: "Ref", "Guide w=1"
    run_scores = []
    metrics.append(run_scores)
    for col in range(10):
        # NOTE(review): entry col*10 + 9 is the last of each block of 10
        # consecutive entries in traj[i][1] -- presumably the final step of
        # each sample's trajectory; confirm against the sampler's layout.
        mol = decode_molecule(traj[i][1][col * 10 + 9], pocket_center)
        if not isinstance(mol, Molecule):
            # decode_molecule returned the raw dict: decoding failed.
            continue
        print(mol)
        metric = calc_vina_molecule_metrics(mol, dataset.proteins[i], calculate_vina_dock=False)
        if "vina_score" in metric:
            run_scores.append(metric["vina_score"])
print(metrics)