import os
import sys
import copy
import numpy as np
import random
import torch
import pandas as pd
import pdb

from rdkit import Chem
from rdkit.Chem import AllChem, QED, RDConfig, rdFingerprintGenerator
from rdkit.Chem.rdmolops import FastFindRings
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer
from rdkit.Chem import Crippen

from . import scorer


def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch RNGs with ``seed``.

    If a CUDA device is available, the CUDA generators are seeded too. cuDNN
    is also switched to deterministic kernels with benchmarking (autotuning)
    disabled, trading speed for run-to-run reproducibility.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.random.manual_seed)
    for seeder in seeders:
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def reward_residue_interaction(smis, predictor):
    """Convert docking predictions for ``smis`` into two reward arrays.

    ``predictor.predict_with_interactions(smis)`` must return a 3-tuple; its
    first element (affinities) and third element (interaction counts) are
    used, the second is ignored.

    Returns:
        (affinity_reward, interaction_reward) as NumPy arrays. The affinity
        reward is the negated affinity clipped below at 0, so only negative
        predicted affinities produce a positive reward. The interaction reward
        is the interaction counts unchanged.
    """
    affinities, _unused, interaction_counts = predictor.predict_with_interactions(smis)
    affinity_reward = np.clip(-np.array(affinities), 0, None)
    interaction_reward = np.array(interaction_counts)
    return affinity_reward, interaction_reward

def reward_vina(smis, predictor):
    """Return a NumPy array of docking rewards for ``smis``.

    Each value is the negation of ``predictor.predict(smis)`` clipped below
    at 0: negative predicted scores become positive rewards, and everything
    else becomes 0.
    """
    predicted = np.array(predictor.predict(smis))
    return np.clip(-predicted, 0, None)


def reward_qed(mols):
    """Return the RDKit QED drug-likeness score of each molecule, as a list."""
    return list(map(QED.qed, mols))


def reward_sa(mols):
    """Return a rescaled synthetic-accessibility reward for each molecule.

    ``sascorer.calculateScore`` is mapped through ``(10 - s) / 9``. A score of
    1 becomes 1.0 and a score of 10 becomes 0.0, so lower (easier) SA scores
    give higher rewards.
    """
    raw_scores = map(sascorer.calculateScore, mols)
    return [(10 - s) / 9 for s in raw_scores]


def reward_logp(mols):
    """Return the Crippen logP of each molecule, standardised to a z-score.

    The mean/std are hard-coded constants; presumably statistics of some
    reference dataset (TODO confirm source).
    """
    mean, std = 2.4570953396190123, 1.434324401111988

    def _zscore(mol):
        return (Crippen.MolLogP(mol) - mean) / std

    return list(map(_zscore, mols))

def reward_ring(mols):
    """Return a z-scored penalty on large rings for each molecule.

    For every molecule, count the symmetrised SSSR rings with more than six
    members. Negate the count, then standardise it with the hard-coded
    mean/std (presumably dataset statistics; TODO confirm). Molecules with
    more large rings therefore score lower.
    """
    mean, std = -0.0347, 0.03509591
    rewards = []
    for mol in mols:
        n_large = sum(1 for ring in Chem.GetSymmSSSR(mol) if len(ring) > 6)
        rewards.append((-n_large - mean) / std)
    return rewards

def reward_guacamol_mpo(mols, name='osim', scale=True):
    """Score each molecule with the GuacaMol MPO objective ``scorer.<name>``.

    Intended MPO task names: 'osimertinib', 'fexofenadine', 'ranolazine',
    'perindopril', 'amlodipine', 'zaleplon', 'sitagliptin'. NOTE(review): the
    default ``'osim'`` must be an attribute of ``scorer``; confirm it exists
    there.

    Each objective is called as ``func(mol, scale)``.

    Raises:
        ValueError: if ``scorer`` has no attribute called ``name``.
    """
    if not hasattr(scorer, name):
        raise ValueError(f"Task {name} is not supported!")
    objective = getattr(scorer, name)
    return [objective(mol, scale) for mol in mols]


def get_att_points(mol):
    """Return the indices of all dummy (``'*'``) atoms in ``mol``, in atom order.

    In this module, dummy atoms mark fragment attachment points.
    """
    return [atom.GetIdx() for atom in mol.GetAtoms() if atom.GetSymbol() == '*']


def delete_multiple_element(list_object, indices):
    """Remove the elements at ``indices`` from ``list_object`` in place.

    Indices refer to positions in the list *before* any removal:
      - each position is removed at most once, even if repeated in ``indices``
        or given both as a positive and an equivalent negative index;
      - negative indices count from the end of the original list, as in Python;
      - indices outside the list are ignored.

    Args:
        list_object: the list to modify; mutated in place.
        indices: iterable of integer positions.

    Returns:
        None.
    """
    size = len(list_object)
    # Normalise negatives against the ORIGINAL length and de-duplicate, so
    # e.g. [1, 1] or [4, -1] (on 5 items) removes one element, not two.
    targets = {idx + size if idx < 0 else idx for idx in indices}
    # Pop from the highest position down so earlier pops never shift the
    # positions that are still waiting to be removed.
    for idx in sorted(targets, reverse=True):
        if 0 <= idx < size:
            list_object.pop(idx)


def get_vocab(vocab_path):
    """Load a fragment vocabulary and derive per-fragment RDKit data.

    ``vocab_path`` is read with ``pd.read_csv`` using the column names
    ``frag`` (fragment SMILES) and ``score``; presumably the file has no
    header row (a header would be read as data).

    Side effect: rebinds the module-level global ``ATOM`` to the atom symbol
    list below.

    Returns:
        dict with keys
          'ATOM'       - list of atom symbols,
          'FRAG'       - fragment SMILES in file order,
          'FRAG_QUEUE' - list of (smiles, score) tuples,
          'FRAG_MOL'   - RDKit Mol for each fragment,
          'FRAG_ATT'   - for each fragment, indices of its '*' atoms.

    Raises:
        ValueError: if any fragment SMILES cannot be parsed by RDKit. This
            replaces a later, opaque failure on a ``None`` molecule.
    """
    global ATOM

    ATOM = ['C', 'N', 'O', 'S', 'P', 'F', 'I', 'Cl', 'Br', '*']
    df = pd.read_csv(vocab_path, names=['frag', 'score'])
    FRAG = df['frag'].tolist()
    FRAG_QUEUE = list(zip(df['frag'], df['score']))
    FRAG_MOL = [Chem.MolFromSmiles(s) for s in FRAG]

    # MolFromSmiles returns None (instead of raising) on unparsable input.
    invalid = [smi for smi, mol in zip(FRAG, FRAG_MOL) if mol is None]
    if invalid:
        raise ValueError(
            f"Vocabulary {vocab_path!r} contains {len(invalid)} fragment "
            f"SMILES that RDKit could not parse, e.g. {invalid[:5]}"
        )

    FRAG_ATT = [get_att_points(m) for m in FRAG_MOL]
    return {'ATOM': ATOM, 'FRAG': FRAG, 'FRAG_QUEUE': FRAG_QUEUE, 'FRAG_MOL': FRAG_MOL, 'FRAG_ATT': FRAG_ATT}


def ecfp(molecule):
    """Return the Morgan (radius 2, 1024-bit) fingerprint of ``molecule`` as a list.

    Dummy (``*``) atoms are stripped before fingerprinting. The input
    molecule itself is not modified, because ``DeleteSubstructs`` returns a
    new molecule.
    """
    dummy_pattern = Chem.MolFromSmiles("*")
    stripped = Chem.DeleteSubstructs(molecule, dummy_pattern)
    # The property cache has to be refreshed once atoms have been removed.
    stripped.UpdatePropertyCache()
    # Cheap ring perception, sufficient for fingerprinting in batch use.
    FastFindRings(stripped)
    generator = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=1024)
    return list(generator.GetFingerprint(stripped))


def one_of_k_encoding_unk(x, allowable_set):
    """Return a one-hot list of floats encoding ``x`` over ``allowable_set``.

    A value not present in ``allowable_set`` is encoded as its last element,
    which therefore doubles as the "unknown" bucket.
    """
    key = x if x in allowable_set else allowable_set[-1]
    return [float(key == choice) for choice in allowable_set]


def atom_feature(atom, use_atom_meta):
    """Build a NumPy feature vector for one RDKit atom.

    The vector always starts with a one-hot encoding of the element symbol
    over ['C', 'N', 'O', 'S', 'P', 'F', 'I', 'Cl', 'Br', '*']; other symbols
    fall into the last slot.

    Unless ``use_atom_meta == False``, the vector is extended with one-hot
    encodings of degree (0-5), total H count (0-4) and implicit valence
    (0-5), plus an aromaticity flag. That gives 10 + 6 + 5 + 6 + 1 = 28
    entries.
    """
    symbols = ['C', 'N', 'O', 'S', 'P', 'F', 'I', 'Cl', 'Br', '*']
    features = one_of_k_encoding_unk(atom.GetSymbol(), symbols)
    # Compare with `== False` rather than `not` to keep the original
    # semantics: e.g. None still selects the extended feature set.
    if use_atom_meta == False:  # noqa: E712
        return np.asarray(features)

    # NOTE(review): newer RDKit releases may deprecate
    # Atom.GetImplicitValence; confirm against the pinned RDKit version.
    features = (
        features
        + one_of_k_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
        + one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4])
        + one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5])
        + [atom.GetIsAromatic()]
    )
    return np.asarray(features)


def convert_radical_electrons_to_hydrogens(mol):
    """Return a copy of ``mol`` in which every radical electron becomes an explicit H.

    For each atom with N > 0 radical electrons, the radical count is set to 0
    and N is added to the atom's explicit hydrogen count. The input molecule
    is not modified.

    Fixes relative to the earlier version:
      - Explicit H counts are incremented rather than overwritten, so a
        bracket atom such as ``[CH2]`` carrying one radical becomes CH3, not CH.
      - ``Chem.Descriptors.NumRadicalElectrons`` is no longer used.
        ``from rdkit import Chem`` does not import the ``Descriptors``
        submodule itself, so that call only worked if some other import had
        loaded it. Atoms without radicals are simply left untouched, so no
        separate "is this a radical?" pre-check is needed.
    """
    m = copy.deepcopy(mol)
    for a in m.GetAtoms():
        num_radical_e = a.GetNumRadicalElectrons()
        if num_radical_e > 0:
            a.SetNumRadicalElectrons(0)
            a.SetNumExplicitHs(a.GetNumExplicitHs() + num_radical_e)
    return m
