import torch
import numpy as np
import pandas as pd
import multiprocessing as mp
from torch_geometric.data import Data
from functools import partial 
from easydict import EasyDict
from tqdm.auto import tqdm
from rdkit import Chem
from rdkit.Chem.rdForceFieldHelpers import MMFFOptimizeMolecule

# from ..chem import set_rdmol_positions, get_best_rmsd
from rdkit.Chem.rdmolops import RemoveHs
from copy import deepcopy
from rdkit.Chem import rdMolAlign as MA


def set_rdmol_positions(rdkit_mol, pos):
    """Return a deep copy of ``rdkit_mol`` with conformer 0 set to ``pos``.

    The molecule passed in is not modified; coordinates are written into a
    fresh ``deepcopy`` of it.

    Args:
        rdkit_mol: An `rdkit.Chem.rdchem.Mol` object.
        pos: (N_atoms, 3) coordinates; row ``i`` is written to atom ``i``.

    Returns:
        The copied molecule carrying the new coordinates.
    """
    new_mol = deepcopy(rdkit_mol)
    return set_rdmol_positions_(new_mol, pos)


def set_rdmol_positions_(mol, pos):
    """Overwrite the coordinates of conformer 0 of ``mol`` in place.

    Args:
        mol: An `rdkit.Chem.rdchem.Mol` object; it is modified in place.
        pos: (N_atoms, 3) coordinates (e.g. a numpy array or torch tensor);
            row ``i`` is written to atom ``i``.

    Returns:
        ``mol`` itself, for convenience.
    """
    num_atoms = pos.shape[0]
    if num_atoms == 0:
        # Nothing to write; do not touch the conformer at all.
        return mol
    # Fetch the conformer once rather than once per atom.
    conf = mol.GetConformer(0)
    for i in range(num_atoms):
        conf.SetAtomPosition(i, pos[i].tolist())
    return mol
def get_best_rmsd(probe, ref):
    """Return the best-alignment RMSD between ``probe`` and ``ref``.

    Hydrogens are removed from both molecules with ``RemoveHs`` before the
    comparison, which is delegated to ``rdMolAlign.GetBestRMS``.
    """
    probe_heavy = RemoveHs(probe)
    ref_heavy = RemoveHs(ref)
    return MA.GetBestRMS(probe_heavy, ref_heavy)


def get_rmsd_confusion_matrix(data: Data, useFF=False):
    """Compute the pairwise RMSD matrix between reference and generated conformers.

    ``data['pos_ref']`` and ``data['pos_gen']`` are reshaped in place to
    ``(num_confs, num_atoms, 3)``, where ``num_atoms`` is taken from
    ``data['rdmol'].GetNumAtoms()``.

    Args:
        data: Indexable object providing ``'rdmol'``, ``'pos_ref'`` and
            ``'pos_gen'``.
        useFF: If True, each generated conformer is optimized with
            ``MMFFOptimizeMolecule`` before its RMSDs are computed.

    Returns:
        Float ndarray of shape ``(num_ref, num_gen)``; entry ``[j, i]`` is
        ``get_best_rmsd(generated_i, reference_j)``.
    """
    num_atoms = data['rdmol'].GetNumAtoms()
    data['pos_ref'] = data['pos_ref'].reshape(-1, num_atoms, 3)
    data['pos_gen'] = data['pos_gen'].reshape(-1, num_atoms, 3)
    num_gen = data['pos_gen'].shape[0]
    num_ref = data['pos_ref'].shape[0]

    # `np.float` was removed in NumPy 1.24; the builtin `float` maps to float64.
    rmsd_confusion_mat = np.full((num_ref, num_gen), -1.0, dtype=float)

    # Build every reference molecule once instead of once per generated
    # conformer (GetBestRMS only moves the probe, never the reference).
    ref_mols = [
        set_rdmol_positions(data['rdmol'], data['pos_ref'][j])
        for j in range(num_ref)
    ]

    if useFF:
        print('Applying FF on generated molecules...')
    for i in range(num_gen):
        gen_mol = set_rdmol_positions(data['rdmol'], data['pos_gen'][i])
        if useFF:
            MMFFOptimizeMolecule(gen_mol)
        for j, ref_mol in enumerate(ref_mols):
            rmsd_confusion_mat[j, i] = get_best_rmsd(gen_mol, ref_mol)

    return rmsd_confusion_mat
    

def evaluate_conf(data: Data, useFF=False, threshold=0.5):
    """Compute recall-style coverage and matching for a single molecule.

    Args:
        data: Passed to ``get_rmsd_confusion_matrix`` (reshaped in place there).
        useFF: Forwarded to ``get_rmsd_confusion_matrix``.
        threshold: RMSD cutoff used for coverage.

    Returns:
        ``(coverage, matching)``: the fraction of reference conformers whose
        closest generated conformer lies within ``threshold``, and the mean of
        those per-reference minimum RMSDs.
    """
    confusion = get_rmsd_confusion_matrix(data, useFF=useFF)
    best_per_ref = np.min(confusion, axis=-1)
    is_covered = best_per_ref <= threshold
    return is_covered.mean(), best_per_ref.mean()


def print_covmat_results(results, print_fn=print):
    """Summarize COV/MAT results as two DataFrames and print them.

    Args:
        results: Object with attributes ``CoverageR`` / ``CoverageP``
            (shape ``(num_mols, num_thresholds)``), ``MatchingR`` /
            ``MatchingP`` (shape ``(num_mols,)``) and ``thresholds``.
        print_fn: Callable that receives each formatted table as a string.

    Returns:
        ``(df_cov, df_mat)``:
        ``df_cov`` holds the per-threshold mean/median/std of the recall (R) and
        precision (P) coverage across molecules, indexed by threshold.
        ``df_mat`` is a single-row table with the mean/median/std of the
        matching scores.
    """
    df_cov = pd.DataFrame({
        'COV-R_mean': np.mean(results.CoverageR, 0),
        'COV-R_median': np.median(results.CoverageR, 0),
        'COV-R_std': np.std(results.CoverageR, 0),
        'COV-P_mean': np.mean(results.CoverageP, 0),
        'COV-P_median': np.median(results.CoverageP, 0),
        'COV-P_std': np.std(results.CoverageP, 0),
    }, index=results.thresholds)
    print_fn('\n' + str(df_cov))

    df_mat = pd.DataFrame({
        'MAT-R_mean': np.mean(results.MatchingR),
        'MAT-R_median': np.median(results.MatchingR),
        'MAT-R_std': np.std(results.MatchingR),
        'MAT-P_mean': np.mean(results.MatchingP),
        'MAT-P_median': np.median(results.MatchingP),
        'MAT-P_std': np.std(results.MatchingP),
    }, index=[0])
    print_fn('\n' + str(df_mat))

    return df_cov, df_mat



class CovMatEvaluator(object):
    """Compute COV / MAT conformer-generation metrics over many molecules.

    For each molecule, a ``(num_ref, num_gen)`` RMSD confusion matrix is
    computed in a process pool via ``get_rmsd_confusion_matrix``. From it:

    * COV-R / MAT-R: per reference conformer, the minimum RMSD to any
      generated conformer; coverage is the fraction at or below each
      threshold, matching is the mean of those minima.
    * COV-P / MAT-P: the same with roles swapped (per generated conformer,
      minimum RMSD to any reference).
    """

    def __init__(self, 
        num_workers=1, 
        use_force_field=False, 
        thresholds=np.arange(0.05, 3.05, 0.05),
        ratio=2,
        filter_disconnected=True,
        print_fn=print,
    ):
        """
        Args:
            num_workers: Number of worker processes used per ``__call__``.
            use_force_field: Forwarded as ``useFF`` to
                ``get_rmsd_confusion_matrix``.
            thresholds: RMSD thresholds for coverage; copied and flattened.
            ratio: Generated conformers kept per reference conformer. Molecules
                with fewer than ``num_ref * ratio`` generated conformers are
                skipped.
            filter_disconnected: If True, skip molecules whose ``'smiles'``
                contains ``'.'``.
            print_fn: Callable used for the filtering summary message.
        """
        super().__init__()
        self.num_workers = num_workers
        self.use_force_field = use_force_field
        self.thresholds = np.array(thresholds).flatten()

        self.ratio = ratio
        self.filter_disconnected = filter_disconnected

        # The worker pool is created per call (see __call__) so its processes
        # are always shut down instead of living as long as this object.
        self.print_fn = print_fn

    def _filter_data(self, packed_data_list):
        """Return the eligible molecules; positions are reshaped/truncated in place."""
        filtered = []
        for data in packed_data_list:
            if 'pos_gen' not in data or 'pos_ref' not in data:
                continue
            if self.filter_disconnected and ('.' in data['smiles']):
                continue

            num_atoms = data['rdmol'].GetNumAtoms()
            data['pos_ref'] = data['pos_ref'].reshape(-1, num_atoms, 3)
            data['pos_gen'] = data['pos_gen'].reshape(-1, num_atoms, 3)

            num_gen = data['pos_ref'].shape[0] * self.ratio
            if data['pos_gen'].shape[0] < num_gen:
                continue
            data['pos_gen'] = data['pos_gen'][:num_gen]

            filtered.append(data)
        return filtered

    @staticmethod
    def _score_confusion(confusion_mat, thresholds):
        """Return (MAT-R, COV-R row, MAT-P, COV-P row) for one confusion matrix."""
        thres = thresholds.reshape(1, -1)
        rmsd_ref_min = confusion_mat.min(-1)    # (num_ref, )
        rmsd_gen_min = confusion_mat.min(0)     # (num_gen, )
        covr = (rmsd_ref_min.reshape(-1, 1) <= thres).mean(0, keepdims=True)  # (1, num_thres)
        covp = (rmsd_gen_min.reshape(-1, 1) <= thres).mean(0, keepdims=True)  # (1, num_thres)
        return rmsd_ref_min.mean(), covr, rmsd_gen_min.mean(), covp

    def __call__(self, packed_data_list, start_idx=0):
        """Evaluate ``packed_data_list`` and return an ``EasyDict`` of scores.

        Args:
            packed_data_list: Molecules to evaluate; entries are mutated in
                place (positions reshaped, ``pos_gen`` truncated).
            start_idx: Number of leading *filtered* molecules to skip.

        Returns:
            EasyDict with ``CoverageR``/``CoverageP`` of shape
            ``(num_mols, num_thres)``, ``MatchingR``/``MatchingP`` of shape
            ``(num_mols,)`` and ``thresholds``. With no eligible molecules
            these arrays are empty instead of raising.
        """
        func = partial(get_rmsd_confusion_matrix, useFF=self.use_force_field)

        filtered_data_list = self._filter_data(packed_data_list)[start_idx:]
        self.print_fn('Filtered: %d / %d' % (len(filtered_data_list), len(packed_data_list)))

        covr_scores = []
        matr_scores = []
        covp_scores = []
        matp_scores = []
        # Pool.__exit__ terminates the workers, so imap is fully consumed
        # inside the with-block.
        with mp.Pool(self.num_workers) as pool:
            for confusion_mat in tqdm(pool.imap(func, filtered_data_list), total=len(filtered_data_list)):
                # confusion_mat: (num_ref, num_gen)
                matr, covr, matp, covp = self._score_confusion(confusion_mat, self.thresholds)
                matr_scores.append(matr)
                covr_scores.append(covr)
                matp_scores.append(matp)
                covp_scores.append(covp)

        num_thres = self.thresholds.shape[0]
        # np.vstack([]) raises, so produce correctly-shaped empty arrays instead.
        covr_scores = np.vstack(covr_scores) if covr_scores else np.empty((0, num_thres))  # (num_mols, num_thres)
        covp_scores = np.vstack(covp_scores) if covp_scores else np.empty((0, num_thres))  # (num_mols, num_thres)
        matr_scores = np.array(matr_scores, dtype=float)   # (num_mols, )
        matp_scores = np.array(matp_scores, dtype=float)   # (num_mols, )

        results = EasyDict({
            'CoverageR': covr_scores,
            'MatchingR': matr_scores,
            'thresholds': self.thresholds,
            'CoverageP': covp_scores,
            'MatchingP': matp_scores
        })
        return results