import pickle
import numpy as np
from tqdm import tqdm
import lmdb
import torch
import torch.nn.functional as F
from typing import List
import os
from torch_geometric.data import Data
from open_biomed.data import Molecule
from open_biomed.models.molecule.molcraft import MolCRAFTPocketFeaturizer
from open_biomed.utils.misc import safe_index

def featurize_molecule(molecule: Molecule, pocket_center: torch.Tensor, pos_norm: float = 2.0) -> Data:
    """Convert a molecule into a graph sample with atom types and centred coordinates.

    Each atom is classified by its ``(atomic number, is aromatic)`` pair and
    encoded as a one-hot row. Coordinates are taken from ``molecule.conformer``,
    shifted by ``pocket_center`` and divided by ``pos_norm``.

    Args:
        molecule: Molecule to featurize. Its RDKit representation is created
            via ``molecule._add_rdmol()`` and its ``conformer`` supplies the
            atom coordinates.
        pocket_center: Offset subtracted from every atom position
            (the caller in this file passes the mean of the pocket conformer).
        pos_norm: Divisor applied to the centred coordinates. Defaults to 2.0,
            the value that was previously hard-coded.

    Returns:
        A ``Data`` object with two fields: ``atom_feature`` (float one-hot
        matrix, one row per atom, 13 columns) and ``pos`` (float32 centred and
        scaled coordinates).
    """
    MAP_ATOM_TYPE_AROMATIC_TO_INDEX = {
        (1, False): 0,
        (6, False): 1,
        (6, True): 2,
        (7, False): 3,
        (7, True): 4,
        (8, False): 5,
        (8, True): 6,
        (9, False): 7,
        (15, False): 8,
        (15, True): 9,
        (16, False): 10,
        (16, True): 11,
        (17, False): 12,
    }
    molecule._add_rdmol()
    rdmol = molecule.rdmol

    # NOTE(review): atom types missing from the table fall back to index 0,
    # which is also the class of non-aromatic hydrogen. This mirrors the
    # original behaviour; confirm it is intended rather than a dedicated
    # "unknown" class.
    node_feat_list = [
        MAP_ATOM_TYPE_AROMATIC_TO_INDEX.get(
            (int(atom.GetAtomicNum()), bool(atom.GetIsAromatic())), 0
        )
        for atom in rdmol.GetAtoms()
    ]
    node_feat = F.one_hot(
        torch.LongTensor(node_feat_list),
        num_classes=len(MAP_ATOM_TYPE_AROMATIC_TO_INDEX),
    ).float()

    # torch.tensor copies the conformer, so the in-place ops below never
    # modify the molecule's own coordinates.
    pos = torch.tensor(molecule.conformer, dtype=torch.float32)
    pos -= pocket_center
    pos /= pos_norm

    return Data(atom_feature=node_feat, pos=pos)

def normalize_labels(labels: List[float]) -> List[float]:
    """Rescale three raw scores into the unit interval.

    The first three entries of ``labels`` are each clipped to a fixed range
    and mapped linearly onto [0, 1]:

    * ``labels[0]`` is clipped to [-16, 0] and inverted, so -16 maps to 1
      and 0 maps to 0 (the script in this file passes a vina score here).
    * ``labels[1]`` is clipped to [0.01, 0.95] (passed a QED value).
    * ``labels[2]`` is clipped to [0.17, 1.0] (passed an SA value).

    Args:
        labels: Sequence whose first three items are the raw scores.

    Returns:
        A list with the three rescaled values, in the same order.
    """
    # vina scores range from 0 to -16
    vina_ceil, vina_floor = 0, -16
    qed_ceil, qed_floor = 0.95, 0.01
    sa_ceil, sa_floor = 1.0, 0.17

    clipped_vina = np.clip(labels[0], vina_floor, vina_ceil)
    # Lower (more negative) vina is better, hence the flipped numerator.
    scaled_vina = (vina_ceil - clipped_vina) / (vina_ceil - vina_floor)

    clipped_qed = np.clip(labels[1], qed_floor, qed_ceil)
    scaled_qed = (clipped_qed - qed_floor) / (qed_ceil - qed_floor)

    clipped_sa = np.clip(labels[2], sa_floor, sa_ceil)
    scaled_sa = (clipped_sa - sa_floor) / (sa_ceil - sa_floor)

    return [scaled_vina, scaled_qed, scaled_sa]

def _load_pickle(path: str):
    """Load a single pickled object from ``path`` and close the file afterwards."""
    # NOTE: unpickling can execute arbitrary code; only use on trusted files.
    with open(path, "rb") as f:
        return pickle.load(f)


def _put_sample(txn, index: int, molecule_data, pocket_data, raw_scores: List[float]) -> None:
    """Write one (molecule, pocket, labels) sample to LMDB under key ``str(index)``.

    ``raw_scores`` is ``[vina, qed, sa]``; it is normalised with
    ``normalize_labels`` and stored as a float32 tensor.
    """
    labels = torch.tensor(normalize_labels(raw_scores), dtype=torch.float32)
    txn.put(
        key=str(index).encode(),
        value=pickle.dumps({
            "molecule": molecule_data,
            "pocket": pocket_data,
            "labels": labels,
        }),
    )


if __name__ == "__main__":
    # Builds an LMDB dataset of (molecule, pocket, labels) samples. For every
    # pocket whose reference ligand passes the vina filter, the reference
    # molecule is stored first, followed by the sampled molecules from two
    # prediction sets that pass the same filter.
    print("Loading preds")
    preds1 = _load_pickle("./data/filtered_preds_molcraft_train.pkl")
    preds2 = _load_pickle("./data/sample_results/train/molcraft_Mixed_CG_CFG_weighted_success/filtered_preds.pkl")
    print("Loading metrics")
    metrics1 = _load_pickle("./data/metrics_molcraft_train.pkl")
    metrics2 = _load_pickle("./data/sample_results/train/molcraft_Mixed_CG_CFG_weighted_success/metrics.pkl")
    print("Loading dataset")
    dataset = _load_pickle("./data/csd_train.pkl")
    pocket_featurizer = MolCRAFTPocketFeaturizer(pos_norm=2.0)

    path = "./data/csd_train_sample/sample_dataset_round2.lmdb"
    # With subdir=False LMDB creates a single file, so the parent directory
    # has to exist beforehand.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    if os.path.exists(path):
        os.remove(path)
    db = lmdb.open(path, map_size=100 * 1024 * 1024 * 1024, create=True, subdir=False, readonly=False)
    cnt = 0
    num_pockets = 0

    # NOTE: maybe we should perform posecheck
    # NOTE: QED/SA consistency checks against the stored metrics are currently disabled.
    try:
        with db.begin(write=True) as txn:
            for i in tqdm(range(len(preds1))):
                pocket = dataset.pockets[i]
                pocket_center = torch.tensor(pocket.conformer, dtype=torch.float32).mean(dim=0)

                # Explicit checks instead of asserts so they survive `python -O`.
                if metrics1[i] is not None and len(metrics1[i]) != len(preds1[i]):
                    raise ValueError(
                        f"Pocket {i}: {len(metrics1[i])} metrics vs {len(preds1[i])} preds in set 1"
                    )
                if metrics2[i] is not None and len(metrics2[i]) != len(preds2[i]):
                    raise ValueError(
                        f"Pocket {i}: {len(metrics2[i])} metrics vs {len(preds2[i])} preds in set 2"
                    )

                # The reference scores live in the first metrics entry of
                # set 1; without it the pocket cannot be filtered or labelled.
                if metrics1[i] is None or len(metrics1[i]) == 0:
                    continue

                # Add reference molecule first
                ref_metrics = metrics1[i][0]
                if ref_metrics["ref_vina_score"] > -1:
                    continue
                num_pockets += 1
                ref_mol = dataset.molecules[i]
                if i <= 5:
                    print(ref_mol, ref_metrics["ref_qed"])
                ref_mol_data = featurize_molecule(ref_mol, pocket_center)
                pocket_data = pocket_featurizer(pocket)
                _put_sample(
                    txn, cnt, ref_mol_data, pocket_data,
                    [ref_metrics["ref_vina_score"], ref_metrics["ref_qed"], ref_metrics["ref_sa"]],
                )
                cnt += 1

                # Add the sampled molecules from both prediction sets
                for preds, metrics in ((preds1, metrics1), (preds2, metrics2)):
                    if metrics[i] is None:
                        continue
                    for pred, metric in zip(preds[i], metrics[i]):
                        if metric["vina_score"] > -1:
                            continue
                        mol_data = featurize_molecule(pred, pocket_center)
                        _put_sample(
                            txn, cnt, mol_data, pocket_data,
                            [metric["vina_score"], metric["qed"], metric["sa"]],
                        )
                        cnt += 1

                # Report the sample index at which the pocket-level
                # train/val boundaries fall.
                if num_pockets == 98200:
                    print("Train cutoff: ", cnt)
                elif num_pockets == 98400:
                    print("Val cutoff: ", cnt)
    finally:
        # Always release the environment, even if featurization fails midway.
        db.close()

    print(f"Number of pockets: {num_pockets}")
    print(f"Number of samples: {cnt}")
