from rdkit.Chem import MolFromSmiles, MolToSmiles
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

from time import time
import pandas as pd

from scorer.scorer import get_scores
from utils.mol_utils import get_molecular_scores, get_novelty_in_df
from moses.utils import get_mol


def get_pds(protein, csv_dir, smiles, mols=None, thrs=(0.5, 5)):
    """Compute generation-quality and docking metrics for a set of SMILES.

    Every unique non-empty SMILES is scored with the ``protein`` docking score,
    QED and SA (via ``get_scores``). The scored table is written to
    ``f'{csv_dir}.csv'`` and summary statistics are returned.

    Args:
        protein: Target name, one of 'parp1', 'fa7', '5ht1b', 'jak2', 'braf',
            'tgfr1'. It selects the docking-score column and the hit threshold.
        csv_dir: Output path WITHOUT the '.csv' extension.
        smiles: Sequence of SMILES strings. Empty strings are ignored. The
            caller's sequence is not modified.
        mols: Optional molecules aligned index-by-index with ``smiles``. They
            are parsed with ``MolFromSmiles`` when omitted. The caller's
            sequence is not modified.
        thrs: ``(qed_thr, sa_thr)``. A molecule "passes" when
            ``qed > qed_thr`` and ``sa > (10 - sa_thr) / 9``.

    Returns:
        Dict with:

        - 'validity': non-empty SMILES / total SMILES.
        - 'uniqueness', 'novelty_02', 'novelty_03', 'novelty_04': measured
          over the non-empty SMILES, duplicates included.
        - 'novelty': whatever ``get_novelty_in_df`` returns.
        - 'pass_rate' and 'hit*': counts divided by the total number of input
          SMILES.
        - 'ds', 'top_ds', 'top_pass_ds*': ``(mean, std)`` tuples of the
          docking score.

    Raises:
        ValueError: if ``protein`` is not a supported target, or if there is
            no non-empty SMILES to evaluate.
    """
    # Docking-score cutoffs above which a molecule counts as a "hit".
    hit_thrs = {'parp1': 10., 'fa7': 8.5, '5ht1b': 8.7845,
                'jak2': 9.1, 'braf': 10.3, 'tgfr1': 10.5}
    # Validate the target up front, before the expensive docking step runs.
    if protein not in hit_thrs:
        raise ValueError('Wrong target protein')
    hit_thr = hit_thrs[protein]
    qed_thr, sa_thr = thrs[0], thrs[1]

    num_mols = len(smiles)

    # Drop empty SMILES together with their aligned mols. Build new lists so
    # the caller's sequences are left untouched and removal is O(n).
    keep = [i for i, s in enumerate(smiles) if s != '']
    smiles = [smiles[i] for i in keep]
    if mols is not None:
        mols = [mols[i] for i in keep]
    if not smiles:
        raise ValueError('No non-empty SMILES to evaluate')

    df = pd.DataFrame({'smiles': smiles})
    validity = len(df) / num_mols
    df['mol'] = [MolFromSmiles(s) for s in smiles] if mols is None else mols

    uniqueness = len(set(df['smiles'])) / len(df)

    novelty = get_novelty_in_df(df, "ZINC250k")
    # NOTE(review): the 'sim' column is presumably added to df by
    # get_novelty_in_df (nearest-neighbour similarity to ZINC250k) — confirm.
    novelty_02 = len(df[df['sim'] < 0.2]) / len(df)
    novelty_03 = len(df[df['sim'] < 0.3]) / len(df)
    novelty_04 = len(df[df['sim'] < 0.4]) / len(df)

    # Score each unique molecule only once.
    df = df.drop_duplicates(subset=['smiles'])
    df[protein] = get_scores(protein, df['mol'])
    df['qed'] = get_scores('qed', df['mol'])
    df['sa'] = get_scores('sa', df['mol'])

    del df['mol']
    df.to_csv(f'{csv_dir}.csv', index=False)

    # "Top 5%" is relative to the full input size, not the unique count.
    num_top5 = int(num_mols * 0.05)

    def top_stats(frame):
        # (mean, std) of the num_top5 best scores; frame must be sorted descending.
        top = frame.iloc[:num_top5][protein]
        return top.mean(), top.std()

    def hit_ratio(frame):
        # Fraction of ALL input SMILES whose score in `frame` beats hit_thr.
        return len(frame[frame[protein] > hit_thr]) / num_mols

    df = df.sort_values(by=[protein], ascending=False)
    ds = df[protein].mean(), df[protein].std()
    top_ds = top_stats(df)
    hit = hit_ratio(df)

    # Drug-likeness filter. NOTE(review): 'sa' appears to be normalised as
    # (10 - SA) / 9 (higher is better), so this keeps raw SA < sa_thr — confirm.
    df = df[(df['qed'] > qed_thr) & (df['sa'] > (10 - sa_thr) / 9)]
    pass_rate = len(df) / num_mols

    df = df.sort_values(by=[protein], ascending=False)
    top_pass_ds = top_stats(df)
    hit_pass = hit_ratio(df)

    # The novelty filters are nested (sim < 0.3 is a subset of sim < 0.4).
    # Boolean filtering keeps the descending sort order from above.
    df = df[df['sim'] < 0.4]
    top_pass_ds_novel_04 = top_stats(df)
    hit_novel_04 = hit_ratio(df)

    df = df[df['sim'] < 0.3]
    top_pass_ds_novel_03 = top_stats(df)
    hit_novel_03 = hit_ratio(df)

    return {'validity': validity, 'uniqueness': uniqueness,
            'novelty': novelty, 'novelty_02': novelty_02, 'novelty_03': novelty_03, 'novelty_04': novelty_04,
            'pass_rate': pass_rate, 'ds': ds, 'top_ds': top_ds, 'top_pass_ds': top_pass_ds,
            'top_pass_ds_novel_03': top_pass_ds_novel_03, 'top_pass_ds_novel_04': top_pass_ds_novel_04,
            'hit': hit, 'hit_pass': hit_pass, 'hit_novel_03': hit_novel_03, 'hit_novel_04': hit_novel_04}


def get_pds_no_ds(csv_dir, smiles, mols=None, thrs=(0.5, 5)):
    """Compute generation-quality metrics for SMILES without docking.

    This is like ``get_pds`` but skips the docking score. Unique non-empty
    molecules are scored with QED and SA (via ``get_scores``), and the table
    is written to ``f'{csv_dir}.csv'``.

    Args:
        csv_dir: Output path WITHOUT the '.csv' extension.
        smiles: Sequence of SMILES strings. Empty strings are ignored. The
            caller's sequence is not modified.
        mols: Optional molecules aligned index-by-index with ``smiles``. They
            are parsed with ``MolFromSmiles`` when omitted. The caller's
            sequence is not modified.
        thrs: ``(qed_thr, sa_thr)``. A molecule "passes" when
            ``qed > qed_thr`` and ``sa > (10 - sa_thr) / 9``.

    Returns:
        Dict with:

        - 'validity': non-empty SMILES / total SMILES.
        - 'uniqueness', 'novelty_02', 'novelty_03', 'novelty_04': measured
          over the non-empty SMILES, duplicates included.
        - 'novelty': whatever ``get_novelty_in_df`` returns.
        - 'pass_rate': unique passing molecules / total SMILES.

    Raises:
        ValueError: if there is no non-empty SMILES to evaluate.
    """
    qed_thr, sa_thr = thrs[0], thrs[1]
    num_mols = len(smiles)

    # Drop empty SMILES together with their aligned mols. Build new lists so
    # the caller's sequences are left untouched and removal is O(n).
    keep = [i for i, s in enumerate(smiles) if s != '']
    smiles = [smiles[i] for i in keep]
    if mols is not None:
        mols = [mols[i] for i in keep]
    if not smiles:
        raise ValueError('No non-empty SMILES to evaluate')

    df = pd.DataFrame({'smiles': smiles})
    validity = len(df) / num_mols
    df['mol'] = [MolFromSmiles(s) for s in smiles] if mols is None else mols

    uniqueness = len(set(df['smiles'])) / len(df)

    novelty = get_novelty_in_df(df, "ZINC250k")
    # NOTE(review): the 'sim' column is presumably added to df by
    # get_novelty_in_df — confirm.
    novelty_02 = len(df[df['sim'] < 0.2]) / len(df)
    novelty_03 = len(df[df['sim'] < 0.3]) / len(df)
    novelty_04 = len(df[df['sim'] < 0.4]) / len(df)

    # Score each unique molecule only once.
    df = df.drop_duplicates(subset=['smiles'])
    df['qed'] = get_scores('qed', df['mol'])
    df['sa'] = get_scores('sa', df['mol'])

    del df['mol']
    df.to_csv(f'{csv_dir}.csv', index=False)

    # NOTE(review): 'sa' appears to be normalised as (10 - SA) / 9, so this
    # keeps raw SA < sa_thr — confirm against get_scores.
    df = df[(df['qed'] > qed_thr) & (df['sa'] > (10 - sa_thr) / 9)]
    pass_rate = len(df) / num_mols

    return {'validity': validity, 'uniqueness': uniqueness, 'pass_rate': pass_rate,
            'novelty': novelty, 'novelty_02': novelty_02, 'novelty_03': novelty_03, 'novelty_04': novelty_04}


def get_pds_post(df, csv_dir):
    """Evaluate a table of previously generated and already docked molecules.

    The function prints validity, uniqueness and novelty, two top-3 molecule
    listings, and docking-score statistics. It writes the de-duplicated,
    canonicalised table to ``f'{csv_dir}.csv'``.

    Args:
        df: DataFrame with a 'smiles' column and a docking-score column named
            after a supported target ('parp1', 'fa7', '5ht1b', 'jak2', 'braf',
            'tgfr1'); the first one present, in that order, is used. 'qed' and
            'sa' are computed when missing. The index may be arbitrary (e.g.
            after filtering). The caller's frame is not modified.
        csv_dir: Output path WITHOUT the '.csv' extension.

    Returns:
        Dict with 'pass_rate', 'ds', 'top_ds', 'top_pass_ds',
        'top_pass_ds_novel_03', 'top_pass_ds_novel_04', 'hit', 'hit_pass',
        'hit_novel_03', 'hit_novel_04' — the same values that are printed.

    Raises:
        ValueError: if no supported target column exists, or if there are no
            (valid) molecules to evaluate.
    """
    qed_thr = 0.5
    sa_thr = 5
    # Docking-score cutoffs above which a molecule counts as a "hit".
    hit_thrs = {'parp1': 10., 'fa7': 8.5, '5ht1b': 8.7845,
                'jak2': 9.1, 'braf': 10.3, 'tgfr1': 10.5}

    # Identify the target first so bad input fails before any scoring work.
    protein = next((p for p in hit_thrs if p in df.columns), None)
    if protein is None:
        raise ValueError('Wrong target protein')
    hit_thr = hit_thrs[protein]

    num_mols = len(df)
    print(f'{num_mols} molecules')
    if num_mols == 0:
        raise ValueError('No molecules to evaluate')

    # Parse SMILES and keep valid rows via a positional boolean mask. The
    # previous df.drop(positions) dropped by *label*, which breaks on a
    # filtered frame whose index is not 0..n-1.
    mols = [get_mol(s) for s in df['smiles']]
    is_valid = [m is not None for m in mols]
    df = df.loc[is_valid].copy()
    df['mol'] = [m for m in mols if m is not None]
    print(f'{len(df)} valid molecules | Validity: {len(df) / num_mols}')
    if len(df) == 0:
        raise ValueError('No valid molecules to evaluate')

    df['smiles'] = [MolToSmiles(m) for m in df['mol']]      # canonicalize

    print(f'Uniqueness: {len(set(df["smiles"])) / len(df)}')

    novelty = get_novelty_in_df(df, "ZINC250k")
    # NOTE(review): the 'sim' column is presumably added to df by
    # get_novelty_in_df — confirm.
    print(f"Novelty: {novelty} | Novelty (<0.2): {len(df[df['sim'] < 0.2]) / len(df)} | "
          f"Novelty (<0.3): {len(df[df['sim'] < 0.3]) / len(df)} | Novelty (<0.4): {len(df[df['sim'] < 0.4]) / len(df)}")

    df = df.drop_duplicates(subset=['smiles'])

    if 'qed' not in df.columns:
        df['qed'] = get_scores('qed', df['mol'])
    if 'sa' not in df.columns:
        df['sa'] = get_scores('sa', df['mol'])

    del df['mol']
    df.to_csv(f'{csv_dir}.csv', index=False)

    # "Top 5%" is relative to the number of input rows, valid or not.
    num_top5 = int(num_mols * 0.05)

    def top_stats(frame):
        # (mean, std) of the num_top5 best scores; frame must be sorted descending.
        top = frame.iloc[:num_top5][protein]
        return top.mean(), top.std()

    def hit_ratio(frame):
        # Fraction of ALL input rows whose score in `frame` beats hit_thr.
        return len(frame[frame[protein] > hit_thr]) / num_mols

    def print_top3(frame, cols):
        # Print up to three leading rows as '\t<col> | <col> | ...'.
        for k in range(min(3, len(frame))):
            print('\t' + ' | '.join(f'{frame[c].iloc[k]}' for c in cols))

    df = df.sort_values(by=[protein], ascending=False)
    ds = df[protein].mean(), df[protein].std()
    top_ds = top_stats(df)
    hit = hit_ratio(df)

    # No similarity filter has been applied yet, so this ranks all molecules.
    df['tot'] = df[protein] * df['qed'] * df['sa']
    df = df.sort_values(by=['tot'], ascending=False)
    print('Top 3 DS x QED x SA molecules:')
    print_top3(df, ['smiles', protein, 'qed', 'sa', 'tot'])

    # Drug-likeness filter. NOTE(review): 'sa' appears to be normalised as
    # (10 - SA) / 9, so this keeps raw SA < sa_thr (matching the printed
    # labels) — confirm against get_scores.
    df = df[(df['qed'] > qed_thr) & (df['sa'] > (10 - sa_thr) / 9)]
    pass_rate = len(df) / num_mols

    df = df.sort_values(by=[protein], ascending=False)
    top_pass_ds = top_stats(df)
    hit_pass = hit_ratio(df)

    # The novelty filters are nested and keep the descending docking order.
    df = df[df['sim'] < 0.4]
    top_pass_ds_novel_04 = top_stats(df)
    hit_novel_04 = hit_ratio(df)
    print(f'Top 3 DS molecules (QED > {qed_thr}, SA < {sa_thr}, sim. < 0.4):')
    print_top3(df, ['smiles', protein, 'qed', 'sa'])

    df = df[df['sim'] < 0.3]
    top_pass_ds_novel_03 = top_stats(df)
    hit_novel_03 = hit_ratio(df)

    result = {'pass_rate': pass_rate, 'ds': ds, 'top_ds': top_ds, 'top_pass_ds': top_pass_ds,
              'top_pass_ds_novel_03': top_pass_ds_novel_03, 'top_pass_ds_novel_04': top_pass_ds_novel_04,
              'hit': hit, 'hit_pass': hit_pass, 'hit_novel_03': hit_novel_03, 'hit_novel_04': hit_novel_04}

    print(f'DS: {result["ds"][0]:.4f} ± {result["ds"][1]:.4f}')
    print(f'top 5% DS: {result["top_ds"][0]:.4f} ± {result["top_ds"][1]:.4f}')
    print(f'pass rate (QED > {qed_thr}, SA < {sa_thr}): {result["pass_rate"]}')
    print(f'top 5% DS (QED > {qed_thr}, SA < {sa_thr}): '
          f'{result["top_pass_ds"][0]:.4f} ± {result["top_pass_ds"][1]:.4f}')
    print(f'novel top 5% DS (QED > {qed_thr}, SA < {sa_thr}, sim. < 0.3): '
          f'{result["top_pass_ds_novel_03"][0]:.4f} ± {result["top_pass_ds_novel_03"][1]:.4f}')
    print(f'novel top 5% DS (QED > {qed_thr}, SA < {sa_thr}, sim. < 0.4): '
          f'{result["top_pass_ds_novel_04"][0]:.4f} ± {result["top_pass_ds_novel_04"][1]:.4f}')
    print(f'hit ratio: {result["hit"] * 100:.4f} %')
    print(f'hit ratio (QED > {qed_thr}, SA < {sa_thr}): {result["hit_pass"] * 100:.4f} %')
    print(f'novel hit ratio (QED > {qed_thr}, SA < {sa_thr}, sim. < 0.3): {result["hit_novel_03"] * 100:.4f} %')
    print(f'novel hit ratio (QED > {qed_thr}, SA < {sa_thr}, sim. < 0.4): {result["hit_novel_04"] * 100:.4f} %')

    return result
    

if __name__ == '__main__':
    # Result files to evaluate (under baseline_results/, without extension).
    results = ['freed_jak2_3', 'freed_jak2_4']
    # Supported targets, checked in this order against the file name.
    proteins = ('parp1', 'fa7', '5ht1b', 'jak2', 'braf', 'tgfr1')
    # Baseline methods whose output files have no header row.
    baselines = ('freed', 'reinvent', 'morld', 'hier')

    for result in results:
        start_time = time()

        print(f'Evaluating {result}...')
        filename = f'baseline_results/{result}'
        # filename = f'logs_sample/ZINC250k/ood_prop/parp1_qed_sa_20_3000mols_sr/{result}'

        if any(b in filename for b in baselines):
            # Headerless baseline output: infer the target from the file name.
            # Fail loudly instead of leaving `protein` unbound.
            protein = next((p for p in proteins if p in filename), None)
            if protein is None:
                raise ValueError(f'Cannot infer target protein from {filename!r}')

            col_names = ['smiles', protein, 'step']
            # col_names = ['smiles', protein, 'sa', 'qed']    # morld
            # NOTE(review): these are substring tests on the whole path ('sa'
            # also matches e.g. 'sample'), and insert(-2, ...) puts the new
            # column before the protein column — confirm against the file layouts.
            if 'qed' in filename: col_names.insert(-2, 'qed')
            if 'sa' in filename: col_names.insert(-2, 'sa')
            if 'freed' in filename:
                df = pd.read_csv(f'{filename}.csv', header=None, names=col_names)
                df = df[df['step'] > 4000]
            elif 'hier' in filename:
                col_names = ['smiles', protein, 'qed', 'sa']
                df = pd.read_csv(f'{filename}.csv', header=None, sep=' ', names=col_names)
            else:
                df = pd.read_csv(f'{filename}.txt', header=None, names=col_names)
            if 'morld' in filename:
                # MORLD stores scores with the opposite sign.
                df[protein] = -df[protein]
        else:
            df = pd.read_csv(f'{filename}.csv')

        df = df[:3000]
        get_pds_post(df, filename)

        print(f'{time() - start_time:.2f} sec elapsed')
        print('-' * 100)
