import lmdb
import numpy as np
import torch
from rdkit import Chem
import random
import pickle
from torch.utils.data import Dataset
from torch.nn import functional as F
from torch_geometric.data import Data
import os
from tqdm import tqdm
from open_biomed.data import Molecule
from open_biomed.utils.config import Config

class DBReader:
    """Lazily-opened accessor for an LMDB file whose values are pickled records.

    The LMDB environment is opened on first use (``len()`` or indexing) and
    every key is cached in memory so records can be fetched by integer index.

    Args:
        path: Path to the LMDB data file (opened with ``subdir=False``).
        affinity_path: Optional path to a pickled dict of affinity entries.
            Its keys are matched against ``ligand_path[:-4]`` in
            ``_inject_affinity``.
    """

    def __init__(self, path, affinity_path=None) -> None:
        self.path = path
        self.affinity_path = affinity_path
        self.db = None  # lmdb environment, opened lazily by _connect_db
        self.keys = None  # list of every record key, filled by _connect_db
        self.affinity_info = None  # cached affinity dict, filled by _load_affinity_info

    def _connect_db(self):
        """
            Establish read-only database connection and cache all record keys.
        """
        assert self.db is None, 'A connection has already been opened.'
        self.db = lmdb.open(
            self.path,
            map_size=10*(1024*1024*1024),   # 10GB
            create=False,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        with self.db.begin() as txn:
            self.keys = list(txn.cursor().iternext(values=False))

    def _close_db(self):
        """Close the environment if it is open and drop the cached key list."""
        if self.db is not None:
            self.db.close()
        self.db = None
        self.keys = None

    def __del__(self):
        # getattr guards against __init__ having failed before self.db was set.
        if getattr(self, 'db', None) is not None:
            self._close_db()

    def __len__(self):
        if self.db is None:
            self._connect_db()
        return len(self.keys)

    def __getitem__(self, idx):
        """Return the unpickled record stored under the ``idx``-th key."""
        if self.db is None:
            self._connect_db()
        key = self.keys[idx]
        # Scope the read transaction so it is released right away instead of
        # lingering until garbage collection.
        with self.db.begin() as txn:
            raw = txn.get(key)
        # NOTE: affinity injection (_load_affinity_info / _inject_affinity) is
        # not applied here; records are returned exactly as stored.
        # SECURITY: pickle.loads can execute arbitrary code -- only read trusted files.
        return pickle.loads(raw)

    def _update(self, sid, affinity):
        """Merge affinity fields into the record stored under key ``sid``.

        Stores ``affinity['vina']`` as ``'affinity'`` together with
        ``affinity['rmsd']`` and ``affinity['pk']``.

        NOTE(review): _connect_db opens the environment with readonly=True, so
        this write transaction is expected to be rejected by lmdb unless the
        environment was opened writable some other way -- confirm before use.
        """
        if self.db is None:
            self._connect_db()
        # The context manager commits on normal exit and aborts on an
        # exception, so a failed update never leaves a dangling write txn.
        with self.db.begin(write=True) as txn:
            data = pickle.loads(txn.get(sid))
            data.update({
                'affinity': affinity['vina'],
                'rmsd': affinity['rmsd'],
                'pk': affinity['pk'],
            })
            txn.put(
                key=sid,
                value=pickle.dumps(data)
            )

    def _load_affinity_info(self):
        """Load and cache the pickled affinity dict from ``self.affinity_path``.

        Does nothing if the dict is already loaded. The pickle must already
        exist; there is no fallback that builds it from raw data.

        Raises:
            FileNotFoundError: if ``affinity_path`` is unset or missing on disk.
        """
        if self.affinity_info is not None:
            return
        if self.affinity_path is None or not os.path.exists(self.affinity_path):
            raise FileNotFoundError(f'Affinity info not found at {self.affinity_path}')
        with open(self.affinity_path, 'rb') as f:
            self.affinity_info = pickle.load(f)

    def _inject_affinity(self, sid, ligand_path):
        """Write the affinity entry for ``ligand_path`` into record ``sid``.

        The lookup key is ``ligand_path`` without its last 4 characters
        (presumably a file extension such as ``.sdf`` -- confirm). Requires
        ``_load_affinity_info`` to have been called first.

        Raises:
            AttributeError: if ``affinity_info`` has no entry for the ligand.
        """
        ligand_key = ligand_path[:-4]
        if ligand_key not in self.affinity_info:
            raise AttributeError(f'affinity_info has no {ligand_key}')
        self._update(sid, self.affinity_info[ligand_key])

class CSDBySample(Dataset):
    """Dataset that returns pre-processed records straight from an LMDB file.

    When ``exp_name`` contains ``"cfg"``, each item is reduced to ``pocket``,
    ``molecule`` and a ``classifier_input`` taken from the record's ``labels``.
    If those labels are ``torch.long``, they are replaced with probability
    ``CFG_DROP_PROB`` by a length-3 tensor filled with ``CFG_NULL_LABEL``
    (presumably the unconditional token for classifier-free guidance --
    confirm). Otherwise, including when ``exp_name`` is ``None``, the raw
    record is returned unchanged.
    """

    CFG_DROP_PROB = 0.1
    CFG_NULL_LABEL = 9

    def __init__(self, config: Config, exp_name: str = None):
        self.reader = DBReader(config.path)
        self.exp_name = exp_name

    def __len__(self):
        return len(self.reader)

    def __getitem__(self, idx):
        data = self.reader[idx]
        # exp_name defaults to None; guard so the `in` test cannot raise TypeError.
        if self.exp_name is None or "cfg" not in self.exp_name:
            return data
        classifier_input = data["labels"]
        if classifier_input.dtype == torch.long and random.random() < self.CFG_DROP_PROB:
            classifier_input = torch.ones(3, dtype=torch.long) * self.CFG_NULL_LABEL
        return {
            "pocket": data["pocket"],
            "molecule": data["molecule"],
            "classifier_input": classifier_input,
        }
    
# Ligand atom-type vocabulary: (atomic number, aromatic flag) -> class index.
# Indices follow the order of the tuple below. Elements covered are
# H, C, N, O, F, P, S and Cl; only C, N, O, P and S have an aromatic entry.
MAP_ATOM_TYPE_AROMATIC_TO_INDEX = {
    atom_type: index
    for index, atom_type in enumerate((
        (1, False),   # H
        (6, False),   # C
        (6, True),    # C, aromatic
        (7, False),   # N
        (7, True),    # N, aromatic
        (8, False),   # O
        (8, True),    # O, aromatic
        (9, False),   # F
        (15, False),  # P
        (15, True),   # P, aromatic
        (16, False),  # S
        (16, True),   # S, aromatic
        (17, False),  # Cl
    ))
}
        
class CSDFromMolJo(Dataset):
    """Dataset that featurises raw pocket/ligand records read from an LMDB file.

    Each item is a dict with:
      * ``"molecule"``: ligand ``Data`` whose ``atom_feature`` is a one-hot over
        ``MAP_ATOM_TYPE_AROMATIC_TO_INDEX``;
      * ``"pocket"``: pocket ``Data`` whose ``atom_feature`` concatenates an
        element one-hot, an amino-acid one-hot and a backbone flag (as float);
      * ``"labels"``: 3-element tensor of rescaled affinity, QED and SA.
    Both position sets are shifted by the mean protein position and divided by 2.0.
    """

    def __init__(self, config: Config):
        self.reader = DBReader(config.path)
        # Protein element vocabulary for the element one-hot feature.
        self.atomic_numbers = torch.LongTensor([1, 6, 7, 8, 16, 34])  # H, C, N, O, S, Se
        # Number of classes for the amino-acid one-hot feature.
        self.max_num_aa = 20

    def __len__(self):
        return len(self.reader)

    @staticmethod
    def _ligand_atom_type_indices(elements, atom_features):
        """Map each ligand atom to an index of ``MAP_ATOM_TYPE_AROMATIC_TO_INDEX``.

        The lookup key is ``(int(element), bool(atom_features[i][2]))``; column
        2 is presumably the aromatic flag (TODO confirm against preprocessing).
        Atoms whose key is not in the table fall back to index 0.
        """
        return [
            MAP_ATOM_TYPE_AROMATIC_TO_INDEX.get((int(element), bool(atom_features[i][2])), 0)
            for i, element in enumerate(elements)
        ]

    @staticmethod
    def _labels(affinity, qed, sa):
        """Clip and rescale (affinity, QED, SA) into a 3-element tensor.

        Affinity is clipped to [-16, 0], negated and divided by 16. QED is
        clipped to [0.01, 0.95] and SA to [0.17, 1.0], and each is min-max
        scaled by its clipping range.
        """
        return torch.tensor([
            -np.clip(affinity, -16, 0) / 16,
            (np.clip(qed, 0.01, 0.95) - 0.01) / (0.95 - 0.01),
            (np.clip(sa, 0.17, 1.0) - 0.17) / (1.0 - 0.17),
        ])

    def __getitem__(self, idx):
        data = self.reader[idx]
        # Boolean (N_atoms, N_elements) one-hot of each protein atom's element.
        element = data["protein_element"].view(-1, 1) == self.atomic_numbers.view(1, -1)
        amino_acid = F.one_hot(data["protein_atom_to_aa_type"], num_classes=self.max_num_aa)
        is_backbone = data["protein_is_backbone"].view(-1, 1).long()
        center = torch.mean(data["protein_pos"], dim=0)

        atom_types = self._ligand_atom_type_indices(data["ligand_element"], data["ligand_atom_feature"])
        molecule = Data(**{
            "atom_feature": F.one_hot(torch.tensor(atom_types, dtype=torch.long), num_classes=len(MAP_ATOM_TYPE_AROMATIC_TO_INDEX)),
            "pos": (data["ligand_pos"] - center) / 2.0,
        })
        pocket = Data(**{
            "atom_feature": torch.cat([element, amino_acid, is_backbone], dim=-1).float(),
            "pos": (data["protein_pos"] - center) / 2.0,
        })

        mol = Molecule.from_smiles(data["ligand_smiles"])
        qed = mol.calc_qed()
        sa = mol.calc_sa()

        return {
            "molecule": molecule,
            "pocket": pocket,
            "labels": self._labels(data["affinity"], qed, sa),
        }
        