import rdkit
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from rdkit.Chem import Descriptors
import rdkit.Chem.QED as QED
import networkx as nx
import optimization.props.sascorer as sascorer
#import optimization.props.drd2_scorer as drd2_scorer
import numpy as np
import torch
from tdc import Oracle

def similarity(a, b):
    """Return the Tanimoto similarity between two molecules given as SMILES.

    Both molecules are encoded as 2048-bit Morgan fingerprints of radius 2
    with chirality ignored.

    Args:
        a: SMILES string of the first molecule, or None.
        b: SMILES string of the second molecule, or None.

    Returns:
        The Tanimoto similarity of the two fingerprints, or 0.0 when either
        input is None or cannot be parsed by RDKit.
    """
    if None in (a, b):
        return 0.0

    # Parse both inputs before checking, so either failure yields 0.0.
    mols = [Chem.MolFromSmiles(smi) for smi in (a, b)]
    if any(mol is None for mol in mols):
        return 0.0

    fingerprints = [
        AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048, useChirality=False)
        for mol in mols
    ]
    return DataStructs.TanimotoSimilarity(*fingerprints)

# Module-level TDC oracle named 'DRD2', built once when this module is imported
# and called by drd2() to score SMILES strings.
# NOTE(review): constructing the oracle may load or download model data at
# import time -- confirm against the TDC documentation.
oracle = Oracle(name = 'DRD2')
def drd2(s):
    """Return the DRD2 score of a SMILES string from the module-level oracle.

    Args:
        s: SMILES string, or None.

    Returns:
        The first element of the oracle's result for ``[s]``, or 0.0 when
        ``s`` is None or cannot be parsed by RDKit.
    """
    if s is None:
        return 0.0
    # Screen out unparseable SMILES before handing them to the oracle.
    if Chem.MolFromSmiles(s) is None:
        return 0.0
    # The oracle is called with a one-element batch; unwrap the single score.
    return oracle([s])[0]

def qed(s):
    """Return the QED score of a SMILES string as computed by RDKit.

    Args:
        s: SMILES string, or None.

    Returns:
        The value of ``QED.qed`` for the parsed molecule, or 0.0 when ``s``
        is None, cannot be parsed, or the QED computation raises.
    """
    if s is None:
        return 0.0
    mol = Chem.MolFromSmiles(s)
    if mol is None:
        return 0.0
    try:
        return QED.qed(mol)
    except Exception:
        # Best-effort scoring: report the failure and fall back to 0.0.
        # `Exception` (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate instead of being silently swallowed.
        print(f'smiles {s} is weird. not able to compute its qed')
        return 0.0

def tpsa(s):
    """Return the TPSA descriptor of a SMILES string as computed by RDKit.

    Args:
        s: SMILES string, or None.

    Returns:
        ``Descriptors.TPSA`` of the parsed molecule, or 0.0 when ``s`` is
        None or cannot be parsed.
    """
    mol = None if s is None else Chem.MolFromSmiles(s)
    return 0.0 if mol is None else Descriptors.TPSA(mol)

# Modified from https://github.com/bowenliu16/rl_graph_generation
# TODO: should we remove stereochemistry when computing such properties?
def penalized_logp(s, return_normalized=False):
    """Return the penalized logP of a SMILES string.

    The score is the sum of three terms: RDKit's MolLogP, the negated
    synthetic-accessibility score from ``sascorer``, and a cycle penalty
    equal to minus the number of atoms by which the largest cycle (taken
    from a networkx cycle basis of the adjacency matrix) exceeds six.

    Args:
        s: SMILES string, or None.
        return_normalized: When true, each term is standardized with the
            hard-coded mean/std constants below before summing. The default
            returns the raw sum; the original note here says the dataloader
            does the normalization in this project.

    Returns:
        The raw or normalized penalized logP, or -100.0 when ``s`` is None
        or cannot be parsed by RDKit.
    """
    mol = Chem.MolFromSmiles(s) if s is not None else None
    if mol is None:
        return -100.0

    log_p = Descriptors.MolLogP(mol)
    sa_term = -sascorer.calculateScore(mol)

    # Cycle term: only atoms beyond a six-membered cycle are penalized.
    atom_graph = nx.Graph(Chem.rdmolops.GetAdjacencyMatrix(mol))
    largest_cycle = max((len(cycle) for cycle in nx.cycle_basis(atom_graph)), default=0)
    cycle_score = -max(largest_cycle - 6, 0)

    if not return_normalized:
        return log_p + sa_term + cycle_score

    # Standardization constants (mean, std) for each of the three terms.
    logp_mean, logp_std = 2.4570953396190123, 1.434324401111988
    sa_mean, sa_std = -3.0525811293166134, 0.8335207024513095
    cycle_mean, cycle_std = -0.0485696876403053, 0.2860212110245455

    z_log_p = (log_p - logp_mean) / logp_std
    z_sa = (sa_term - sa_mean) / sa_std
    z_cycle = (cycle_score - cycle_mean) / cycle_std
    return z_log_p + z_sa + z_cycle

def smiles2D(s):
    """Round-trip a SMILES string through RDKit and return the regenerated SMILES.

    Args:
        s: SMILES string, or None.

    Returns:
        The output of ``Chem.MolToSmiles`` for the parsed molecule, or None
        when ``s`` is None or cannot be parsed. The other helpers in this
        module also tolerate None/invalid SMILES instead of raising.

    NOTE(review): the name suggests stereochemistry is meant to be dropped,
    but ``MolToSmiles`` is called with its defaults, so whether stereo
    information survives depends on the installed RDKit version -- confirm.
    """
    if s is None:
        return None
    mol = Chem.MolFromSmiles(s)
    if mol is None:
        # Unparseable SMILES: return None rather than crash in MolToSmiles.
        return None
    return Chem.MolToSmiles(mol)

def get_morgan_fingerprint(smiles, n_bits):
    """Return the Morgan fingerprint of a SMILES string as a NumPy array.

    The fingerprint uses radius 2 and ignores chirality.

    Args:
        smiles: SMILES string of the molecule. It is passed to RDKit
            unchecked, so an unparseable string is not handled here.
        n_bits: Length of the fingerprint bit vector.

    Returns:
        A NumPy array built from the bit vector's ``ToList()`` output.
    """
    bit_vect = AllChem.GetMorganFingerprintAsBitVect(
        Chem.MolFromSmiles(smiles), 2, nBits=n_bits, useChirality=False
    )
    return np.array(bit_vect.ToList())

def get_fp_from_binary_tensor(n_bits, t):
    """Build an RDKit ExplicitBitVect from a binary torch tensor.

    Args:
        n_bits: Size of the bit vector to create.
        t: Tensor whose entries equal to 1 mark the bits to set.
            Assumes a 1-D tensor of length ``n_bits`` -- TODO confirm.

    Returns:
        An ``ExplicitBitVect`` of size ``n_bits`` with those bits set.
    """
    on_bits = torch.argwhere(t == 1).squeeze().int().tolist()
    # With exactly one set bit, squeeze() leaves a 0-d tensor and tolist()
    # returns a bare int; wrap it so SetBitsFromList always gets a list.
    bit_indices = [on_bits] if isinstance(on_bits, int) else on_bits

    bit_vect = rdkit.DataStructs.cDataStructs.ExplicitBitVect(n_bits)
    bit_vect.SetBitsFromList(bit_indices)
    return bit_vect

def similarity_tensors(t1, t2, n_bits):
    """Return the Tanimoto similarity of two fingerprints given as binary tensors.

    Args:
        t1: Binary tensor of the first fingerprint.
        t2: Binary tensor of the second fingerprint.
        n_bits: Size of the bit vectors built from the tensors.

    Returns:
        The Tanimoto similarity of the two reconstructed bit vectors.
    """
    bit_vects = [get_fp_from_binary_tensor(n_bits, tensor) for tensor in (t1, t2)]
    return DataStructs.TanimotoSimilarity(*bit_vects)


if __name__ == "__main__":
    # Manual check: print the rounded penalized logP of one example molecule
    # next to 5.30 (presumably the expected reference value -- confirm).
    example_smiles = 'ClC1=CC=C2C(C=C(C(C)=O)C(C(NC3=CC(NC(NC4=CC(C5=C(C)C=CC=C5)=CC=C4)=O)=CC=C3)=O)=C2)=C1'
    reference_value = 5.30
    print(round(penalized_logp(example_smiles), 2), reference_value)
