# modified from: https://github.com/wengong-jin/hgraph2graph/blob/master/props/properties.py

import os
from shutil import rmtree
import math
from rdkit import Chem
from rdkit.Chem import Descriptors
import rdkit.Chem.QED as QED
import networkx as nx
import numpy as np

from scorer import sa_scorer
from scorer.docking_simple import get_dockingvina, make_docking_dir

from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')


def standardize_mols(mol):
    """Round-trip a molecule through its SMILES string.

    The molecule is written out with ``Chem.MolToSmiles`` and parsed back
    with ``Chem.MolFromSmiles``.

    Returns:
        The re-parsed molecule, or ``None`` when either step raises (an
        error line is printed in that case). Whatever ``MolFromSmiles``
        returns is passed through unchanged.
    """
    try:
        return Chem.MolFromSmiles(Chem.MolToSmiles(mol))
    except Exception:
        # Best-effort: callers treat None as "invalid molecule".
        print('standardize_mols error')
    return None


def standardize_smiles(mol):
    """Convert a molecule to the SMILES string produced by ``Chem.MolToSmiles``.

    Returns:
        The SMILES string, or ``None`` when the conversion raises (an error
        line is printed in that case).
    """
    try:
        return Chem.MolToSmiles(mol)
    except Exception:
        # Best-effort: callers treat None as "invalid molecule".
        print('standardize_smiles error')
    return None


def get_docking_scores(target, mols, verbose=False):
    """Dock a batch of molecules against ``target`` and return their scores.

    A working directory is obtained from ``make_docking_dir`` and is always
    removed before this function exits, including when docking raises.

    Args:
        target: Docking target identifier forwarded to ``get_dockingvina``.
        mols: Iterable of molecules; each is converted with
            ``standardize_smiles``.
        verbose: When true, print how many negated scores fall below -99.

    Returns:
        A list with one score per input molecule, in input order. Scores are
        the negated predictions clipped below at 0; molecules whose SMILES
        conversion failed get ``0.``.
    """
    docking_dir = make_docking_dir()
    try:
        dockingvina = get_dockingvina(target, docking_dir)

        smiles = [standardize_smiles(mol) for mol in mols]
        smiles_valid = [smi for smi in smiles if smi is not None]

        # Negate so that larger values are better.
        scores = -np.array(dockingvina.predict(smiles_valid))
        if verbose:
            # NOTE(review): values below -99 are presumably the backend's
            # failure sentinel after negation - confirm in scorer.docking_simple.
            print(f'Number of docking errors: {sum(scores < -99)} / {len(scores)}')
        scores = list(np.clip(scores, 0, None))

        if None in smiles:
            # Re-align the scores of the valid SMILES with the original input
            # order, filling the failed conversions with 0. An iterator keeps
            # this linear (list.pop(0) would be quadratic).
            valid_scores = iter(scores)
            scores = [next(valid_scores) if smi is not None else 0. for smi in smiles]
    finally:
        # Clean up even when docking fails so temporary directories
        # do not accumulate across calls.
        if os.path.exists(docking_dir):
            rmtree(docking_dir)
            print(f'{docking_dir} removed.')

    return scores


def get_scores(objective, mols, standardize=True):
    """Score a batch of molecules for the given objective.

    Args:
        objective: Either one of the docking targets listed below (handled
            by ``get_docking_scores`` with verbose output enabled) or a
            property name understood by ``get_score``.
        mols: List of molecules; entries may be ``None``.
        standardize: For property objectives, pass every molecule through
            ``standardize_mols`` first.

    Returns:
        A list with one score per input molecule, in input order. For
        property objectives, molecules that are ``None`` (on input or after
        standardization) score ``0.``.
    """
    if objective in ('2oh4A', 'tgfr1', 'jak2', 'braf', 'fa7', 'drd3', 'parp1', '5ht1b'):
        return get_docking_scores(objective, mols, True)

    if standardize:
        mols = [standardize_mols(mol) for mol in mols]

    # Single pass: invalid molecules get 0. in place, valid ones are scored
    # in input order (no filter-then-realign with list.pop(0)).
    return [0. if mol is None else get_score(objective, mol) for mol in mols]


def get_score(objective, mol):
    """Score a single molecule for a property objective.

    Args:
        objective: One of ``'qed'``, ``'sa'``, ``'mw'``, ``'logp'`` or
            ``'plogp'``.
        mol: Molecule to score.

    Returns:
        The score for the requested objective, or ``0.`` when the underlying
        calculation raises ``ValueError`` or ``ZeroDivisionError``.

    Raises:
        NotImplementedError: If ``objective`` is not one of the supported
            names (this includes the ``'rand*'`` objectives, which have no
            scorer in this module).
    """
    try:
        if objective == 'qed':
            return QED.qed(mol)
        if objective == 'sa':
            # Assuming calculateScore lies in [1, 10], this maps it onto
            # [0, 1] with larger meaning better.
            return (10. - sa_scorer.calculateScore(mol)) / 9.
        if objective == 'mw':       # molecular weight
            return mw(mol)
        if objective == 'logp':     # real number
            return Descriptors.MolLogP(mol)
        if objective == 'plogp':
            return calculate_min_plogp(mol)
    except (ValueError, ZeroDivisionError):
        # Best-effort: a molecule the scorer cannot handle counts as 0.
        return 0.

    raise NotImplementedError(f'objective {objective!r} is not supported')

    
### molecular properties
def mw(mol):
    '''
    Molecular-weight score derived from ``Descriptors.MolWt``.

    The weight is pushed through a product of a rising and a falling
    sigmoid and divided by a fixed constant. Per the original note this
    estimate comes from QED - presumably its desirability function for
    molecular weight; confirm the constants against rdkit.Chem.QED.
    '''
    weight = Descriptors.MolWt(mol)

    offset, amplitude, centre, width, rise_scale, fall_scale = (
        2.817, 392.575, 290.749, 2.420, 49.223, 65.371
    )
    rising = 1 + math.exp(-(weight - centre + width/2) / rise_scale)
    falling = 1 + math.exp(-(weight - centre - width/2) / fall_scale)
    raw = offset + amplitude / rising * (1 - 1 / falling)

    # NOTE(review): 104.981 looks like the normalising maximum - confirm.
    return raw / 104.981


def calculate_min_plogp(mol):
    """
    Return the smallest penalized logP over three views of the molecule.

    ``penalized_logp`` is evaluated on ``mol`` itself and on the molecules
    obtained by re-parsing its isomeric and its non-isomeric SMILES; the
    minimum of the three values is returned.
    """
    lowest = penalized_logp(mol)

    smiles_variants = [
        Chem.MolToSmiles(mol, isomericSmiles=keep_stereo)
        for keep_stereo in (True, False)
    ]
    reparsed = [Chem.MolFromSmiles(smi) for smi in smiles_variants]

    for candidate in reparsed:
        lowest = min(lowest, penalized_logp(candidate))
    return lowest


def penalized_logp(mol):
    """
    Sum of three standardized terms: logP, negated synthetic-accessibility
    score, and a penalty for rings with more than six atoms.

    Each term is shifted and scaled by a hard-coded (mean, std) pair before
    the three are added together.
    """
    # Modified from https://github.com/bowenliu16/rl_graph_generation
    logp_mu, logp_sigma = 2.4570953396190123, 1.434324401111988
    sa_mu, sa_sigma = -3.0525811293166134, 0.8335207024513095
    ring_mu, ring_sigma = -0.0485696876403053, 0.2860212110245455

    raw_logp = Descriptors.MolLogP(mol)
    raw_sa = -sa_scorer.calculateScore(mol)

    # Ring penalty: number of atoms beyond six in the largest cycle of the
    # cycle basis of the molecular graph (0 when there is no such cycle).
    adjacency = Chem.rdmolops.GetAdjacencyMatrix(mol)
    ring_sizes = [len(ring) for ring in nx.cycle_basis(nx.Graph(adjacency))]
    largest_ring = max(ring_sizes, default=0)
    ring_penalty = -max(largest_ring - 6, 0)

    z_logp = (raw_logp - logp_mu) / logp_sigma
    z_sa = (raw_sa - sa_mu) / sa_sigma
    z_ring = (ring_penalty - ring_mu) / ring_sigma
    return z_logp + z_sa + z_ring
