# Copyright (c) DP Technology, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import RDLogger

# Silence every RDKit log channel matching "rdApp.*" for the whole process.
RDLogger.DisableLog("rdApp.*")
import warnings

# Ignore all Python warnings process-wide (no category or module filter is given).
warnings.filterwarnings(action="ignore")
from rdkit.Chem import rdMolTransforms
import copy
import lmdb
import pickle
import pandas as pd


def get_torsions(m, removeHs=True):
    """Return one heavy-atom torsion per rotatable bond of ``m``.

    Rotatable bonds are the matches of the SMARTS
    ``[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]``: single, non-ring bonds whose two end
    atoms are neither part of a triple bond nor of degree one.

    For each matched bond (j, k) only the first bond on atom j that is not
    the central bond is tried as the i-j arm; the first acceptable bond on
    atom k then supplies atom l. At most one torsion is therefore recorded
    per match.

    Args:
        m: RDKit molecule to analyse.
        removeHs: when true, hydrogens are stripped with ``Chem.RemoveHs``
            before matching, so returned indices refer to the stripped
            molecule.

    Returns:
        List of 4-tuples of atom indices. The tuple is ``(i, j, k, l)``
        unless atom l is in a ring, in which case it is stored reversed as
        ``(l, k, j, i)``.
    """
    if removeHs:
        m = Chem.RemoveHs(m)

    rotatable_query = Chem.MolFromSmarts("[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]")
    torsions = []

    for match in m.GetSubstructMatches(rotatable_query):
        j_idx, k_idx = match[0], match[1]
        central_id = m.GetBondBetweenAtoms(j_idx, k_idx).GetIdx()

        # Only the first non-central bond on the j side is ever considered.
        j_arm = next(
            (
                bond
                for bond in m.GetAtomWithIdx(j_idx).GetBonds()
                if bond.GetIdx() != central_id
            ),
            None,
        )
        if j_arm is None:
            continue
        j_arm_id = j_arm.GetIdx()
        i_idx = j_arm.GetOtherAtomIdx(j_idx)

        for k_arm in m.GetAtomWithIdx(k_idx).GetBonds():
            if k_arm.GetIdx() in (central_id, j_arm_id):
                continue
            l_idx = k_arm.GetOtherAtomIdx(k_idx)
            # skip 3-membered rings
            if l_idx == i_idx:
                continue
            # skip torsions that include hydrogens
            if (
                m.GetAtomWithIdx(i_idx).GetAtomicNum() == 1
                or m.GetAtomWithIdx(l_idx).GetAtomicNum() == 1
            ):
                continue
            if m.GetAtomWithIdx(l_idx).IsInRing():
                quad = (l_idx, k_idx, j_idx, i_idx)
            else:
                quad = (i_idx, j_idx, k_idx, l_idx)
            torsions.append(quad)
            # One torsion per rotatable bond is enough.
            break

    return torsions


def SetDihedral(conf, atom_idx, new_vale):
    """Set the dihedral spanned by the first four indices in ``atom_idx``.

    Thin wrapper over ``rdMolTransforms.SetDihedralRad``; ``new_vale`` is
    forwarded unchanged as the target angle (radians, per the RDKit call
    used) and ``conf`` is modified in place.
    """
    first, second, third, fourth = (
        atom_idx[0],
        atom_idx[1],
        atom_idx[2],
        atom_idx[3],
    )
    rdMolTransforms.SetDihedralRad(conf, first, second, third, fourth, new_vale)


def single_conf_gen_bonds(tgt_mol, num_confs=1000, seed=42, removeHs=True):
    """Embed conformers of ``tgt_mol`` and randomise their rotatable torsions.

    A deep copy of ``tgt_mol`` is hydrogenated, embedded with
    ``AllChem.EmbedMultipleConfs`` (existing conformers are cleared), and
    optionally stripped of hydrogens again. Every torsion returned by
    ``get_torsions`` is then set to a pseudo-random angle in each conformer,
    after which the conformer is canonicalised with
    ``rdMolTransforms.CanonicalizeConformer``.

    Args:
        tgt_mol: RDKit molecule; it is not modified.
        num_confs: number of conformers requested from the embedder.
        seed: random seed forwarded to the embedder only. The torsion angles
            of conformer ``i`` are always drawn from a generator seeded
            with ``i``.
        removeHs: strip hydrogens after embedding and before torsion search.

    Returns:
        The copied molecule carrying the perturbed conformers. It may hold
        fewer than ``num_confs`` conformers (possibly none) if embedding
        did not produce them all.
    """
    mol = Chem.AddHs(copy.deepcopy(tgt_mol))
    AllChem.EmbedMultipleConfs(
        mol, numConfs=num_confs, randomSeed=seed, clearConfs=True
    )
    if removeHs:
        mol = Chem.RemoveHs(mol)
    rotable_bonds = get_torsions(mol, removeHs=removeHs)

    # clearConfs=True means the molecule holds exactly the embedded
    # conformers, so their position doubles as the per-conformer seed.
    for conf_pos, conformer in enumerate(mol.GetConformers()):
        # A private RandomState produces the same numbers as
        # np.random.seed(conf_pos) followed by np.random.rand(...), but
        # leaves the process-wide NumPy RNG untouched for the caller.
        rng = np.random.RandomState(conf_pos)
        # The truncated pi constant is kept so generated angles stay
        # bit-identical to earlier runs.
        angles = 3.1415926 * 2 * rng.rand(len(rotable_bonds))
        for torsion, angle in zip(rotable_bonds, angles):
            SetDihedral(conformer, torsion, angle)
        rdMolTransforms.CanonicalizeConformer(conformer)
    return mol


def load_lmdb_data(lmdb_path, key):
    """Collect ``record[key]`` from every pickled record in an LMDB file.

    The database is opened read-only as a single file (``subdir=False``).
    Records are looked up under the ASCII keys ``"0"``, ``"1"``, ... up to
    the number of keys in the database, so the file is expected to use
    exactly that consecutive numbering.

    Args:
        lmdb_path: path to the LMDB file.
        key: entry to pull out of each unpickled record.

    Returns:
        List with one ``record[key]`` value per record, in index order.

    NOTE(security): records are decoded with ``pickle.loads``, which can
    execute arbitrary code. Only use this on LMDB files from a trusted
    source.
    """
    env = lmdb.open(
        lmdb_path,
        subdir=False,
        readonly=True,
        lock=False,
        readahead=False,
        meminit=False,
        max_readers=256,
    )
    try:
        # The transaction is released by the context manager and the
        # environment by the finally clause, so neither leaks when a record
        # is missing or fails to unpickle.
        with env.begin() as txn:
            # Count keys without materialising them in a list.
            n_records = sum(1 for _ in txn.cursor().iternext(values=False))
            collects = []
            for idx in range(n_records):
                datapoint_pickled = txn.get(f"{idx}".encode("ascii"))
                data = pickle.loads(datapoint_pickled)
                collects.append(data[key])
    finally:
        env.close()
    return collects


def docking_data_pre(raw_data_path, predict_path):
    """Flatten raw molecules and pickled prediction batches into parallel lists.

    Molecules are read from the ``"mol_list"`` entry of every record in the
    LMDB file at ``raw_data_path``, flattened, and stripped of hydrogens.
    Prediction batches are read from ``predict_path`` with
    ``pd.read_pickle``; each sample is masked down to the tokens whose id is
    greater than 2 (presumably the non-special tokens; confirm against the
    dictionary used for training) and converted to ``float32`` NumPy arrays.

    The path being read and the size-weighted sum of the batch losses are
    printed as a side effect.

    Returns:
        Tuple ``(mol_list, smi_list, pocket_list, pocket_coords_list,
        distance_predict_list, holo_distance_predict_list, holo_coords_list,
        holo_center_coords_list)``. The holo centre entries are the first
        three elements of the stored value and are not converted to NumPy.
    """
    mol_list = []
    for group in load_lmdb_data(raw_data_path, "mol_list"):
        for mol in group:
            mol_list.append(Chem.RemoveHs(mol))

    print(predict_path)
    predict = pd.read_pickle(predict_path)

    smi_list = []
    pocket_list = []
    pocket_coords_list = []
    distance_predict_list = []
    holo_distance_predict_list = []
    holo_coords_list = []
    holo_center_coords_list = []

    loss = 0
    from tqdm import tqdm

    for batch in tqdm(predict):
        sz = batch["atoms"].size(0)
        loss += batch["loss"].item() * sz
        for sample in range(sz):
            smi_list.append(batch["smi_name"][sample])
            pocket_list.append(batch["pocket_name"][sample])

            cross_dist = batch["cross_distance_predict"][sample]
            ligand_mask = batch["atoms"][sample] > 2
            pocket_mask = batch["pocket_atoms"][sample] > 2

            # Ligand rows first, then pocket columns.
            cross_dist = cross_dist[ligand_mask][:, pocket_mask]

            # Masks are truncated to the coordinate tensor length before
            # being used to select rows.
            pocket_xyz = batch["pocket_coordinates"][sample]
            pocket_xyz = pocket_xyz[pocket_mask[: pocket_xyz.size(0)], :]

            holo_dist = batch["holo_distance_predict"][sample]
            holo_dist = holo_dist[ligand_mask][:, ligand_mask]

            holo_xyz = batch["holo_coordinates"][sample]
            holo_xyz = holo_xyz[ligand_mask[: holo_xyz.size(0)], :]
            holo_center = batch["holo_center_coordinates"][sample][:3]

            pocket_coords_list.append(pocket_xyz.numpy().astype(np.float32))
            distance_predict_list.append(cross_dist.numpy().astype(np.float32))
            holo_distance_predict_list.append(
                holo_dist.numpy().astype(np.float32)
            )
            holo_coords_list.append(holo_xyz.numpy().astype(np.float32))
            holo_center_coords_list.append(holo_center)

    print(loss)
    return (
        mol_list,
        smi_list,
        pocket_list,
        pocket_coords_list,
        distance_predict_list,
        holo_distance_predict_list,
        holo_coords_list,
        holo_center_coords_list,
    )


def ensemble_iterations(
    mol_list,
    smi_list,
    pocket_list,
    pocket_coords_list,
    distance_predict_list,
    holo_distance_predict_list,
    holo_coords_list,
    holo_center_coords_list,
    tta_times=10,
):
    """Yield one docking work item per group of ``tta_times`` consecutive samples.

    The inputs are parallel lists in which every run of ``tta_times``
    entries is treated as one group; a trailing partial group is ignored.
    For each group the first molecule is deep-copied and passed to
    ``single_conf_gen_bonds`` to create up to ``tta_times`` starting
    conformers.

    Yields:
        A list ``[initial_coords_list, mol, smi, pocket, pocket_coords,
        distance_predict_tta, holo_distance_predict_tta, holo_coords,
        holo_center_coords]``. The two ``*_tta`` items are the group's slices
        of the prediction lists; every other per-sample item comes from the
        first entry of the group. ``initial_coords_list`` holds one
        ``float32`` coordinate array per generated conformer.
    """
    group_count = len(mol_list) // tta_times
    for group in range(group_count):
        first = group * tta_times
        last = (group + 1) * tta_times
        distance_predict_tta = distance_predict_list[first:last]
        holo_distance_predict_tta = holo_distance_predict_list[first:last]

        mol = copy.deepcopy(mol_list[first])
        rdkit_mol = single_conf_gen_bonds(
            mol, num_confs=tta_times, seed=42, removeHs=True
        )
        initial_coords_list = [
            conformer.GetPositions().astype(np.float32)
            for conformer in rdkit_mol.GetConformers()
        ]

        yield [
            initial_coords_list,
            mol,
            smi_list[first],
            pocket_list[first],
            pocket_coords_list[first],
            distance_predict_tta,
            holo_distance_predict_tta,
            holo_coords_list[first],
            holo_center_coords_list[first],
        ]


def rmsd_func(holo_coords, predict_coords):
    """Return the RMSD between reference and predicted coordinates.

    The squared differences are summed over all elements and divided by the
    number of rows of ``holo_coords`` (its first dimension) before the
    square root is taken.

    Args:
        holo_coords: reference coordinate array.
        predict_coords: predicted coordinate array of matching shape, or a
            "no prediction" marker: ``None`` or any scalar float NaN
            (``np.nan``, ``float("nan")``, ``np.float32("nan")``, ...).

    Returns:
        The RMSD, or the sentinel ``1000.0`` when no prediction is given.
    """
    # An identity test against np.nan only recognises that exact object;
    # check the value instead so every NaN scalar is treated as "missing".
    no_prediction = predict_coords is None or (
        isinstance(predict_coords, (float, np.floating))
        and np.isnan(predict_coords)
    )
    if no_prediction:
        return 1000.0

    n_rows = holo_coords.shape[0]
    return np.sqrt(np.sum((predict_coords - holo_coords) ** 2) / n_rows)


def print_results(rmsd_results):
    """Print success rates at fixed RMSD cutoffs and the mean RMSD.

    Args:
        rmsd_results: RMSD values as a NumPy array or any sequence that
            ``np.asarray`` can convert (for example a plain list).

    For each cutoff in 1.0, 1.5, 2.0, 3.0 and 5.0 the fraction of values
    strictly below the cutoff is printed, followed by the mean of all values.
    """
    # A plain list cannot be compared element-wise with a float, so coerce
    # to an array first; arrays pass through without a dtype change.
    rmsd = np.asarray(rmsd_results)
    for cutoff in (1.0, 1.5, 2.0, 3.0, 5.0):
        print(f"RMSD < {cutoff} : ", np.mean(rmsd < cutoff))
    print("avg RMSD : ", np.mean(rmsd))
