import gzip
import pickle
import rdkit.DataStructs
from rdkit import Chem
import numpy as np
from time import time


def get_tanimoto_pairwise(mols):
    """Return Tanimoto similarities for every unordered pair of molecules.

    Each element of ``mols`` must expose an RDKit molecule as ``.mol``. The
    RDKit topological fingerprint of each molecule is computed once. Then, for
    each index ``i``, the similarities of fingerprint ``i`` against every
    fingerprint with a larger index are appended. Every pair (i, j) with
    i < j therefore appears exactly once, in row-major order.

    For n molecules the result has n * (n - 1) / 2 entries. An empty or
    single-element input yields ``[]``.
    """
    fingerprints = list(map(lambda entry: Chem.RDKFingerprint(entry.mol), mols))
    similarities = []
    for idx, query in enumerate(fingerprints):
        # Compare only against later fingerprints: skips self-pairs and
        # avoids counting (i, j) and (j, i) twice.
        similarities += rdkit.DataStructs.BulkTanimotoSimilarity(query, fingerprints[idx + 1:])
    return similarities

class NumModes:
    """Running counter of distinct high-reward "modes" among sampled molecules.

    Each call consumes a batch of records whose element 0 is a raw reward and
    element 1 is an object exposing an RDKit molecule as ``.mol``. Raw rewards
    are rescaled with ``reward ** (1 / reward_exp) * reward_norm``.

    A molecule whose rescaled reward exceeds ``reward_thr`` is a candidate.
    A candidate becomes a new mode when its RDKit fingerprint has Tanimoto
    similarity below ``tanimoto_thr`` to every mode found so far.

    State persists across calls, so repeated calls accumulate modes and the
    running maximum reward.
    """

    def __init__(self, reward_exp, reward_norm, reward_thr=8, tanimoto_thr=0.7):
        self.reward_exp = reward_exp
        self.reward_norm = reward_norm
        self.reward_thr = reward_thr
        self.tanimoto_thr = tanimoto_thr
        # Fingerprints of the accepted modes.
        self.modes = []
        # Sentinel; returned unchanged if no reward has been seen yet.
        self.max_reward = -1000

    def __call__(self, batch):
        """Update the running max reward and mode set with ``batch``.

        Args:
            batch: iterable of records; ``record[0]`` is the raw reward and
                ``record[1]`` exposes an RDKit molecule as ``.mol``. Extra
                trailing fields are ignored.

        Returns:
            ``(max_reward, num_modes)``: the largest rescaled reward seen across
            all calls so far, and the current number of modes.
        """
        candidates = []
        for record in batch:
            reward, mol = record[0], record[1]
            reward = (reward ** (1 / self.reward_exp)) * self.reward_norm
            if reward > self.max_reward:
                self.max_reward = reward
            if reward > self.reward_thr:
                candidates.append(mol)

        for mol in candidates:
            fp = Chem.RDKFingerprint(mol.mol)
            if self.modes:
                sims = np.asarray(rdkit.DataStructs.BulkTanimotoSimilarity(fp, self.modes))
                is_new_mode = bool(np.all(sims < self.tanimoto_thr))
            else:
                # The first candidate seeds the mode set. It is not re-tested
                # against its own fingerprint, so it can never be counted twice.
                is_new_mode = True
            if is_new_mode:
                self.modes.append(fp)
        return self.max_reward, len(self.modes)

# def compute_diversity(data, reward_exp, reward_norm, reward_thr, skip=100):
#     # with gzip.open(run_path, 'rb') as f:
#         # data = pickle.load(f)
#     numModes = NumModes(reward_exp, reward_norm, reward_thr=reward_thr)
#     diversity = []
#     mean_tanimoto_sim = []
#     num_modes = []
#     for i in range(len(data)//256):  # 4000
#         if not i % skip:
#             batch = data[i * 256: ((i + 1) * 256)]
#             mols = [x[1] for x in batch]
#             tm = get_tanimoto_pairwise(mols) # length = 256 * 255 / 2 (each unordered pair once)
#             diversity.append(np.mean(np.array(tm) < 0.75))
#             mean_tanimoto_sim.append(np.mean(tm))
#             max_reward, nm = numModes(batch)
#             num_modes.append(nm)
#             print(f"i={i} max_reward={max_reward:.3f} num of modes={nm}")
#     # x = np.linspace(start=0, stop=1000000, num=len(num_modes)) # ?
#     # return x, diversity, mean_tanimoto_sim, num_modes
#     return diversity, mean_tanimoto_sim, num_modes


def eval_mols(mols, reward_norm=8, reward_exp=10, algo="gfn", top_ks=(10, 100, 1000)):
    """Summarize a collection of sampled molecules.

    Args:
        mols: sequence of records. For ``algo == "gfn"`` each record is
            ``(reward, mol, trajectory_stats, inflow)``; otherwise it is
            ``(reward, mol)``. ``mol`` must expose an RDKit molecule as
            ``.mol``. It is also used as a dict key to de-duplicate records,
            and the last record for a given key wins. NOTE(review): uniqueness
            therefore depends on how ``mol`` hashes (identity unless its class
            defines ``__hash__``); confirm against the mol class.
        reward_norm: scale applied when mapping raw rewards back via
            ``r ** (1 / reward_exp) * reward_norm``.
        reward_exp: exponent used in that same mapping.
        algo: record-layout selector; see ``mols``.
        top_ks: the k values for the top-k statistics.

    Returns:
        Tuple ``(avg_topk_rs, avg_topk_tanimoto, num_modes_above_7_5,
        num_modes_above_8_0, num_above_7_5, num_above_8_0)``:

        - ``avg_topk_rs``: dict keyed by k, the mean rescaled reward of the k
          best de-duplicated molecules.
        - ``avg_topk_tanimoto``: dict keyed by k, the mean pairwise Tanimoto
          similarity of those same molecules. Values are ``nan`` when there is
          nothing to average.
        - ``num_modes_above_7_5`` / ``num_modes_above_8_0``: ``NumModes``
          counts over all records at reward thresholds 7.5 and 8.0.
        - ``num_above_7_5`` / ``num_above_8_0``: counts of de-duplicated
          molecules above those thresholds.
    """
    def r2r_back(r):
        return r ** (1. / reward_exp) * reward_norm

    def _mean_or_nan(values):
        # np.mean of an empty sequence is nan but also emits a RuntimeWarning.
        return np.mean(values) if len(values) else float("nan")

    _, num_modes_above_7_5 = NumModes(reward_exp=reward_exp, reward_norm=reward_norm, reward_thr=7.5)(mols)
    _, num_modes_above_8_0 = NumModes(reward_exp=reward_exp, reward_norm=reward_norm, reward_thr=8.)(mols)

    mol_r_map = {}
    for record in mols:
        if algo == 'gfn':
            r, m, _trajectory_stats, _inflow = record
        else:
            r, m = record
        mol_r_map[m] = r2r_back(r)

    # Sort once by descending reward; the reward array and the molecule ranking
    # below stay aligned because both come from this single ordering.
    ranked = sorted(mol_r_map.items(), key=lambda kv: kv[1], reverse=True)
    unique_rs = np.array([r for _, r in ranked])
    num_above_7_5 = np.sum(unique_rs > 7.5)  # count of de-duplicated molecules
    num_above_8_0 = np.sum(unique_rs > 8.0)

    avg_topk_rs = {}
    avg_topk_tanimoto = {}
    for top_k in top_ks:
        avg_topk_rs[top_k] = _mean_or_nan(unique_rs[:top_k])
        topk_mols = [mol for mol, _ in ranked[:top_k]]
        avg_topk_tanimoto[top_k] = _mean_or_nan(get_tanimoto_pairwise(topk_mols))

    return avg_topk_rs, avg_topk_tanimoto, num_modes_above_7_5, num_modes_above_8_0, num_above_7_5, num_above_8_0


if __name__ == '__main__':
    # Ad-hoc evaluation / exploration script. h5py and ipdb are imported only
    # here, so importing this module does not require them.
    import h5py
    with h5py.File("./data/docked_mols.h5", "r") as f:
        print(list(f.keys()))
        dset = f['df']
        # 'axis0', 'axis1', 'block0_items', 'block0_values', 'block1_items', 'block1_values'
        print(dset['axis0'])
        # Interactive breakpoint while the HDF5 file is still open.
        import ipdb; ipdb.set_trace()

    mol_path = "./results/105000_sampled_mols.pkl.gz"
    # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
    with gzip.open(mol_path, 'rb') as fh:
        mols = pickle.load(fh)
    # diversity, mean_tanimoto_sim, num_modes = \
    #     compute_diversity(mols, reward_exp=10, reward_norm=8,reward_thr=8.)

    t1 = time()
    avg_topk_rs, avg_topk_tanimoto, num_modes_above_7_5, num_modes_above_8_0, num_above_7_5, num_above_8_0 = eval_mols(mols)
    print(f"{time() - t1:.2f} sec")
    print(avg_topk_rs, avg_topk_tanimoto, num_modes_above_7_5, num_modes_above_8_0, num_above_7_5, num_above_8_0)
    # quit() is a site-module helper for the interactive shell; SystemExit is
    # what it raises and works regardless of how Python was started.
    raise SystemExit

    # Unreachable scratch exploration below; kept for manual use if the exit
    # above is removed.
    numModes = NumModes(reward_exp=10, reward_norm=8, reward_thr=7.5)
    t1 = time()
    max_reward, num_modes = numModes(mols)
    print(f"{time() - t1:.2f} sec")
    import ipdb; ipdb.set_trace()

    m1 = mols[0][1]
    m2 = mols[1][1]
    # compute similarity between 1 and 10000 mols needs 15 sec
    rdkit.DataStructs.BulkTanimotoSimilarity(Chem.RDKFingerprint(mols[0][1].mol), [Chem.RDKFingerprint(m[1].mol) for m in mols[1:1000]])
    rdkit.DataStructs.BulkTanimotoSimilarity(Chem.RDKFingerprint(m1.mol), [Chem.RDKFingerprint(m2.mol),])
    rdkit.DataStructs.TanimotoSimilarity(Chem.RDKFingerprint(m1.mol), Chem.RDKFingerprint(m2.mol)) # same
    from rdkit.Chem import AllChem
    rdkit.DataStructs.TanimotoSimilarity(AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(m1.smiles), 3, nBits=2048), AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(m2.smiles), 3, nBits=2048))
    rdkit.DataStructs.TanimotoSimilarity(AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(m1.smiles), 10, nBits=2048), AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(m2.smiles), 10, nBits=2048))

    import ipdb; ipdb.set_trace()