import pickle
import numpy as np
from tqdm import tqdm
import lmdb
import torch
import torch.nn.functional as F
from typing import List
import os
from torch_geometric.data import Data
from open_biomed.data import Molecule, Protein
from open_biomed.models.molecule.molcraft import MolCRAFTPocketFeaturizer
from open_biomed.utils.config import Config
from open_biomed.utils.misc import safe_index
from open_biomed.datasets.molecule_protein_dataset import CrossDocked
from open_biomed.tasks.aidd_tasks.protein_molecule_docking import VinaDockTask

# Maps (atomic number, is_aromatic) to the one-hot class index of an atom.
# Built once at import time instead of on every featurization call.
_ATOM_TYPE_AROMATIC_TO_INDEX = {
    (1, False): 0,
    (6, False): 1,
    (6, True): 2,
    (7, False): 3,
    (7, True): 4,
    (8, False): 5,
    (8, True): 6,
    (9, False): 7,
    (15, False): 8,
    (15, True): 9,
    (16, False): 10,
    (16, True): 11,
    (17, False): 12,
}


def featurize_molecule(molecule: Molecule, pocket_center: torch.Tensor, pos_norm: float = 2.0) -> Data:
    """Convert a molecule into a graph sample with one-hot atom types and centered coordinates.

    Args:
        molecule: Molecule to featurize. Its RDKit representation is created
            via ``molecule._add_rdmol()`` and its ``conformer`` supplies the
            atom coordinates.
        pocket_center: Point subtracted from every atom position
            (must broadcast against the conformer coordinates).
        pos_norm: Divisor applied to the centered positions. Defaults to 2.0,
            the same value this script passes to ``MolCRAFTPocketFeaturizer``.

    Returns:
        A ``Data`` object with two fields:
            ``atom_feature``: float one-hot matrix with one row per atom and
                one column per entry of the (atomic number, aromatic) table.
            ``pos``: float32 positions, centered and divided by ``pos_norm``.
    """
    molecule._add_rdmol()
    # NOTE: (element, aromaticity) pairs that are missing from the table fall
    # back to class 0, i.e. they share the slot of non-aromatic hydrogen.
    atom_classes = [
        _ATOM_TYPE_AROMATIC_TO_INDEX.get(
            (int(atom.GetAtomicNum()), bool(atom.GetIsAromatic())), 0
        )
        for atom in molecule.rdmol.GetAtoms()
    ]
    node_feat = F.one_hot(
        torch.LongTensor(atom_classes),
        num_classes=len(_ATOM_TYPE_AROMATIC_TO_INDEX),
    ).float()

    # torch.tensor copies the conformer, so the in-place arithmetic below
    # never modifies the molecule's own coordinates; in-place ops also keep
    # the result in float32 regardless of the dtype of pocket_center.
    pos = torch.tensor(molecule.conformer, dtype=torch.float32)
    pos -= pocket_center
    pos /= pos_norm

    return Data(atom_feature=node_feat, pos=pos)

def normalize_labels(labels: List[float]) -> List[float]:
    """Rescale raw [vina, qed, sa] labels into the unit interval.

    Each value is first clipped to a fixed range and then mapped linearly
    onto [0, 1]:
        labels[0]: clipped to [-16, 0] and flipped, so -16 maps to 1 and 0 maps to 0.
        labels[1]: clipped to [0.01, 0.95], with 0.01 mapping to 0 and 0.95 to 1.
        labels[2]: clipped to [0.17, 1.0], with 0.17 mapping to 0 and 1.0 to 1.

    Args:
        labels: Sequence whose first three entries are the raw labels, in the
            order listed above.

    Returns:
        A list with the three normalized values, in the same order.
    """
    # vina scores range from 0 to -16
    vina_hi, vina_lo = 0, -16
    qed_hi, qed_lo = 0.95, 0.01
    sa_hi, sa_lo = 1.0, 0.17

    vina_clipped = np.clip(labels[0], vina_lo, vina_hi)
    vina_norm = (vina_hi - vina_clipped) / (vina_hi - vina_lo)

    qed_clipped = np.clip(labels[1], qed_lo, qed_hi)
    qed_norm = (qed_clipped - qed_lo) / (qed_hi - qed_lo)

    sa_clipped = np.clip(labels[2], sa_lo, sa_hi)
    sa_norm = (sa_clipped - sa_lo) / (sa_hi - sa_lo)

    return [vina_norm, qed_norm, sa_norm]

def main() -> None:
    """Build ``./data/sample_dataset.lmdb`` from the CrossDocked pocket dataset.

    For every entry the script scores the molecule against its protein with
    AutoDock Vina, computes QED and SA, normalizes the three labels, featurizes
    the molecule and the pocket, and stores the pickled sample under the key
    ``str(index)``. Any existing database file at the target path is deleted
    first.
    """
    dataset = CrossDocked(
        Config.from_dict(
            path="./data",
            pocket_only=True,
        ),
        featurizer=None
    )
    pocket_featurizer = MolCRAFTPocketFeaturizer(pos_norm=2.0)

    path = "./data/sample_dataset.lmdb"
    if os.path.exists(path):
        os.remove(path)
    db = lmdb.open(path, map_size=100*1024*1024*1024, create=True, subdir=False, readonly=False)

    # Close the environment even when featurization or docking fails part-way,
    # so the LMDB handle is never leaked.
    try:
        # NOTE: maybe we should perform posecheck
        # All samples are written in one transaction: the context manager
        # commits on normal exit and aborts (discarding every write) if an
        # exception escapes the loop.
        with db.begin(write=True) as txn:
            for i in tqdm(range(len(dataset))):
                pocket = dataset.pockets[i]
                protein = Protein.from_pdb_file(dataset.proteins[i])
                molecule = dataset.molecules[i]
                # The pocket centroid is the origin for the molecule coordinates.
                pocket_center = torch.tensor(pocket.conformer, dtype=torch.float32).mean(dim=0)

                # NOTE(review): the task is rebuilt per sample; it could be
                # hoisted out of the loop if VinaDockTask keeps no per-run state.
                vina_task = VinaDockTask(
                    docking_tool="autodock_vina",
                    mode="score"
                )
                vina_score = vina_task.run(molecule, protein)[0][0]
                qed = molecule.calc_qed()
                sa = molecule.calc_sa()
                labels = torch.tensor(normalize_labels([vina_score, qed, sa]), dtype=torch.float32)
                mol_data = featurize_molecule(molecule, pocket_center)
                pocket_data = pocket_featurizer(pocket)
                txn.put(
                    key=str(i).encode(),
                    value=pickle.dumps({
                        "molecule": mol_data,
                        "pocket": pocket_data,
                        "labels": labels,
                    })
                )
    finally:
        db.close()


if __name__ == "__main__":
    main()

